From e947569100750076ed96b077ade11d49b4199435 Mon Sep 17 00:00:00 2001 From: Greg Shuflin Date: Sat, 23 Oct 2021 00:22:12 -0700 Subject: [PATCH] Rewrite Visitor And implement the scope resolver in terms of it --- schala-lang/language/src/ast/mod.rs | 14 +- schala-lang/language/src/ast/visitor.rs | 247 +++++++++++++--- schala-lang/language/src/ast/visitor_test.rs | 41 --- schala-lang/language/src/ast/walker.rs | 273 ------------------ .../language/src/symbol_table/resolver.rs | 42 ++- 5 files changed, 231 insertions(+), 386 deletions(-) delete mode 100644 schala-lang/language/src/ast/visitor_test.rs delete mode 100644 schala-lang/language/src/ast/walker.rs diff --git a/schala-lang/language/src/ast/mod.rs b/schala-lang/language/src/ast/mod.rs index 3f54d24..8e22b2a 100644 --- a/schala-lang/language/src/ast/mod.rs +++ b/schala-lang/language/src/ast/mod.rs @@ -1,16 +1,14 @@ #![allow(clippy::upper_case_acronyms)] #![allow(clippy::enum_variant_names)] +mod visitor; +mod operators; + +pub use operators::{PrefixOp, BinOp}; +pub use visitor::{walk_ast, ASTVisitor}; + use std::rc::Rc; use crate::derivative::Derivative; - -mod walker; -mod visitor; -mod visitor_test; -mod operators; -pub use operators::{PrefixOp, BinOp}; -pub use visitor::ASTVisitor; -pub use walker::walk_ast; use crate::tokenizing::Location; /// An abstract identifier for an AST node. Note that diff --git a/schala-lang/language/src/ast/visitor.rs b/schala-lang/language/src/ast/visitor.rs index 3514e98..8d16ed2 100644 --- a/schala-lang/language/src/ast/visitor.rs +++ b/schala-lang/language/src/ast/visitor.rs @@ -1,43 +1,212 @@ -use std::rc::Rc; use crate::ast::*; -//TODO maybe these functions should take closures that return a KeepRecursing | StopHere type, -//or a tuple of (T, ) - pub trait ASTVisitor: Sized { - fn ast(&mut self, _ast: &AST) {} - fn block(&mut self, _statements: &[ Statement ]) {} - fn statement(&mut self, _statement: &Statement) {} - fn declaration(&mut self, _declaration: &Declaration) {} - fn signature(&mut self, _signature: &Signature) {} - fn type_declaration(&mut self, _name: &TypeSingletonName, _body: &TypeBody, _mutable: bool) {} - fn type_alias(&mut self, _alias: &Rc, _original: &Rc) {} - fn binding(&mut self, _name: &Rc, _constant: bool, _type_anno: Option<&TypeIdentifier>, _expr: &Expression) {} - fn implemention(&mut self, _type_name: &TypeIdentifier, _interface_name: Option<&TypeSingletonName>, _block: &[ Declaration ]) {} - fn interface(&mut self, _name: &Rc, _signatures: &[ Signature ]) {} - fn expression(&mut self, _expression: &Expression) {} - fn expression_kind(&mut self, _kind: &ExpressionKind) {} - fn type_annotation(&mut self, _type_anno: Option<&TypeIdentifier>) {} - fn named_struct(&mut self, _name: &QualifiedName, _fields: &[ (Rc, Expression) ]) {} - fn call(&mut self, _f: &Expression, _arguments: &[ InvocationArgument ]) {} - fn index(&mut self, _indexee: &Expression, _indexers: &[ Expression ]) {} - fn if_expression(&mut self, _discrim: Option<&Expression>, _body: &IfExpressionBody) {} - fn condition_arm(&mut self, _arm: &ConditionArm) {} - #[allow(clippy::ptr_arg)] - fn while_expression(&mut self, _condition: Option<&Expression>, _body: &Block) {} - fn for_expression(&mut self, _enumerators: &[ Enumerator ], _body: &ForBody) {} - #[allow(clippy::ptr_arg)] - fn lambda(&mut self, _params: &[ FormalParam ], _type_anno: Option<&TypeIdentifier>, _body: &Block) {} - fn invocation_argument(&mut self, _arg: &InvocationArgument) {} - fn formal_param(&mut self, _param: &FormalParam) {} - fn import(&mut self, _import: &ImportSpecifier) {} - fn module(&mut self, _module: &ModuleSpecifier) {} - fn qualified_name(&mut self, _name: &QualifiedName) {} - fn nat_literal(&mut self, _n: u64) {} - fn float_literal(&mut self, _f: f64) {} - fn string_literal(&mut self, _s: &Rc) {} - fn bool_literal(&mut self, _b: bool) {} - fn binexp(&mut self, _op: &BinOp, _lhs: &Expression, _rhs: &Expression) {} - fn prefix_exp(&mut self, _op: &PrefixOp, _arg: &Expression) {} - fn pattern(&mut self, _pat: &Pattern) {} + fn expression(&mut self, _expression: &Expression) {} + fn expression_post(&mut self, _expression: &Expression) {} + + fn declaration(&mut self, _declaration: &Declaration) {} + fn declaration_post(&mut self, _declaration: &Declaration) {} + + fn import(&mut self, _import: &ImportSpecifier) {} + fn module(&mut self, _module: &ModuleSpecifier) {} + fn module_post(&mut self, _module: &ModuleSpecifier) {} + + fn pattern(&mut self, _pat: &Pattern) {} + fn pattern_post(&mut self, _pat: &Pattern) {} +} + +pub fn walk_ast(v: &mut V, ast: &AST) { + walk_block(v, &ast.statements); +} + +fn walk_block(v: &mut V, block: &Block) { + use StatementKind::*; + for statement in block.iter() { + match statement.kind { + StatementKind::Expression(ref expr) => { + walk_expression(v, expr); + } + Declaration(ref decl) => { + walk_declaration(v, decl); + } + Import(ref import_spec) => v.import(import_spec), + Module(ref module_spec) => { + v.module(module_spec); + walk_block(v, &module_spec.contents); + v.module_post(module_spec); + } + } + } +} + +fn walk_declaration(v: &mut V, decl: &Declaration) { + use Declaration::*; + + v.declaration(decl); + + match decl { + FuncDecl(_sig, block) => { + walk_block(v, block); + } + Binding { + name: _, + constant: _, + type_anno: _, + expr, + } => { + walk_expression(v, expr); + } + _ => (), + }; + v.declaration_post(decl); +} + +fn walk_expression(v: &mut V, expr: &Expression) { + use ExpressionKind::*; + + v.expression(expr); + + match &expr.kind { + NatLiteral(_) | FloatLiteral(_) | StringLiteral(_) | BoolLiteral(_) | Value(_) => (), + BinExp(_, lhs, rhs) => { + walk_expression(v, &lhs); + walk_expression(v, &rhs); + } + PrefixExp(_, arg) => { + walk_expression(v, &arg); + } + TupleLiteral(exprs) => { + for expr in exprs { + walk_expression(v, &expr); + } + } + NamedStruct { name: _, fields } => { + for (_, expr) in fields.iter() { + walk_expression(v, expr); + } + } + Call { f, arguments } => { + walk_expression(v, &f); + for arg in arguments.iter() { + match arg { + InvocationArgument::Positional(expr) => walk_expression(v, expr), + InvocationArgument::Keyword { expr, .. } => walk_expression(v, expr), //TODO maybe I can combine this pattern + _ => (), + } + } + } + Index { indexee, indexers } => { + walk_expression(v, &indexee); + for indexer in indexers.iter() { + walk_expression(v, indexer); + } + } + IfExpression { + discriminator, + body, + } => { + if let Some(d) = discriminator.as_ref() { + walk_expression(v, &d); + } + walk_if_expr_body(v, &body.as_ref()); + } + WhileExpression { condition, body } => { + if let Some(d) = condition.as_ref() { + walk_expression(v, d); + } + walk_block(v, &body); + } + ForExpression { enumerators, body } => { + for enumerator in enumerators { + walk_expression(v, &enumerator.generator); + } + match body.as_ref() { + ForBody::MonadicReturn(expr) => walk_expression(v, expr), + ForBody::StatementBlock(block) => walk_block(v, block), + }; + } + Lambda { + params: _, + type_anno: _, + body, + } => { + walk_block(v, &body); + } + ListLiteral(exprs) => { + for expr in exprs { + walk_expression(v, &expr); + } + } + }; + v.expression_post(expr); +} + +fn walk_if_expr_body(v: &mut V, body: &IfExpressionBody) { + use IfExpressionBody::*; + + match body { + SimpleConditional { + then_case, + else_case, + } => { + walk_block(v, then_case); + if let Some(block) = else_case.as_ref() { + walk_block(v, block) + } + } + SimplePatternMatch { + pattern, + then_case, + else_case, + } => { + walk_pattern(v, pattern); + walk_block(v, &then_case); + if let Some(ref block) = else_case.as_ref() { + walk_block(v, &block) + } + } + CondList(arms) => { + for arm in arms { + match arm.condition { + Condition::Pattern(ref pat) => { + walk_pattern(v, pat); + } + Condition::TruncatedOp(ref _binop, ref expr) => { + walk_expression(v, expr); + } + Condition::Expression(ref expr) => { + walk_expression(v, expr); + } + _ => (), + } + } + } + } +} + +fn walk_pattern(v: &mut V, pat: &Pattern) { + use Pattern::*; + + v.pattern(pat); + + match pat { + TuplePattern(patterns) => { + for pat in patterns { + walk_pattern(v, pat); + } + } + TupleStruct(_, patterns) => { + for pat in patterns { + walk_pattern(v, pat); + } + } + Record(_, name_and_patterns) => { + for (_, pat) in name_and_patterns { + walk_pattern(v, pat); + } + } + _ => (), + }; + + v.pattern_post(pat); } diff --git a/schala-lang/language/src/ast/visitor_test.rs b/schala-lang/language/src/ast/visitor_test.rs deleted file mode 100644 index 7c92a7a..0000000 --- a/schala-lang/language/src/ast/visitor_test.rs +++ /dev/null @@ -1,41 +0,0 @@ -#![cfg(test)] - -use crate::ast::visitor::ASTVisitor; -use crate::ast::walker; -use crate::util::quick_ast; - -struct Tester { - count: u64, - float_count: u64 -} - -impl ASTVisitor for Tester { - fn nat_literal(&mut self, _n: u64) { - self.count += 1; - } - fn float_literal(&mut self, _f: f64) { - self.float_count += 1; - } -} - - -#[test] -fn foo() { - let mut tester = Tester { count: 0, float_count: 0 }; - let ast = quick_ast(r#" -import gragh - -let a = 20 + 84 -let b = 28 + 1 + 2 + 2.0 -fn heh() { - let m = 9 - -} - -"#); - - walker::walk_ast(&mut tester, &ast); - - assert_eq!(tester.count, 6); - assert_eq!(tester.float_count, 1); -} diff --git a/schala-lang/language/src/ast/walker.rs b/schala-lang/language/src/ast/walker.rs deleted file mode 100644 index 4fc63b8..0000000 --- a/schala-lang/language/src/ast/walker.rs +++ /dev/null @@ -1,273 +0,0 @@ -#![allow(dead_code)] -use std::rc::Rc; -use crate::ast::*; -use crate::ast::visitor::ASTVisitor; -use crate::util::deref_optional_box; - -pub fn walk_ast(v: &mut V, ast: &AST) { - v.ast(ast); - walk_block(v, &ast.statements); -} - -#[allow(clippy::ptr_arg)] -fn walk_block(v: &mut V, block: &Vec) { - for s in block { - v.statement(s); - statement(v, s); - } -} - -fn statement(v: &mut V, statement: &Statement) { - use StatementKind::*; - match statement.kind { - Expression(ref expr) => { - v.expression(expr); - expression(v, expr); - }, - Declaration(ref decl) => { - v.declaration(decl); - declaration(v, decl); - }, - Import(ref import_spec) => v.import(import_spec), - Module(ref module_spec) => { - v.module(module_spec); - walk_block(v, &module_spec.contents); - } - } -} - -fn declaration(v: &mut V, decl: &Declaration) { - use Declaration::*; - match decl { - FuncSig(sig) => { - v.signature(sig); - signature(v, sig); - }, - FuncDecl(sig, block) => { - v.signature(sig); - v.block(block); - walk_block(v, block); - }, - TypeDecl { name, body, mutable } => v.type_declaration(name, body, *mutable), - TypeAlias { alias, original} => v.type_alias(alias, original), - Binding { name, constant, type_anno, expr } => { - v.binding(name, *constant, type_anno.as_ref(), expr); - v.type_annotation(type_anno.as_ref()); - v.expression(expr); - expression(v, expr); - }, - Impl { type_name, interface_name, block } => { - v.implemention(type_name, interface_name.as_ref(), block); - } - Interface { name, signatures } => v.interface(name, signatures), - //TODO fill this in - Annotation { .. } => () - } -} - -fn signature(v: &mut V, signature: &Signature) { - for p in signature.params.iter() { - v.formal_param(p); - } - v.type_annotation(signature.type_anno.as_ref()); - for p in signature.params.iter() { - formal_param(v, p); - } -} - -fn expression(v: &mut V, expression: &Expression) { - v.expression_kind(&expression.kind); - v.type_annotation(expression.type_anno.as_ref()); - expression_kind(v, &expression.kind); -} - - -fn call(v: &mut V, f: &Expression, args: &[ InvocationArgument ]) { - v.expression(f); - expression(v, f); - for arg in args.iter() { - v.invocation_argument(arg); - invocation_argument(v, arg); - } -} - -fn invocation_argument(v: &mut V, arg: &InvocationArgument) { - use InvocationArgument::*; - match arg { - Positional(expr) => { - v.expression(expr); - expression(v, expr); - }, - Keyword { expr, .. } => { - v.expression(expr); - expression(v, expr); - }, - Ignored => (), - } -} - -fn index(v: &mut V, indexee: &Expression, indexers: &[ Expression ]) { - v.expression(indexee); - for i in indexers.iter() { - v.expression(i); - } -} - -fn named_struct(v: &mut V, n: &QualifiedName, fields: &[ (Rc, Expression) ]) { - v.qualified_name(n); - for (_, expr) in fields.iter() { - v.expression(expr); - } -} - -#[allow(clippy::ptr_arg)] -fn lambda(v: &mut V, params: &Vec, type_anno: Option<&TypeIdentifier>, body: &Block) { - for param in params { - v.formal_param(param); - formal_param(v, param); - } - v.type_annotation(type_anno); - v.block(body); - walk_block(v, body); -} - -fn formal_param(v: &mut V, param: &FormalParam) { - if let Some(p) = param.default.as_ref() { - v.expression(p); - expression(v, p); - }; - v.type_annotation(param.anno.as_ref()); -} - -fn expression_kind(v: &mut V, expression_kind: &ExpressionKind) { - use ExpressionKind::*; - match expression_kind { - NatLiteral(n) => v.nat_literal(*n), - FloatLiteral(f) => v.float_literal(*f), - StringLiteral(s) => v.string_literal(s), - BoolLiteral(b) => v.bool_literal(*b), - BinExp(op, lhs, rhs) => { - v.binexp(op, lhs, rhs); - expression(v, lhs); - expression(v, rhs); - }, - PrefixExp(op, arg) => { - v.prefix_exp(op, arg); - expression(v, arg); - } - TupleLiteral(exprs) => { - for expr in exprs { - v.expression(expr); - expression(v, expr); - } - }, - Value(name) => v.qualified_name(name), - NamedStruct { name, fields } => { - v.named_struct(name, fields); - named_struct(v, name, fields); - } - Call { f, arguments } => { - v.call(f, arguments); - call(v, f, arguments); - }, - Index { indexee, indexers } => { - v.index(indexee, indexers); - index(v, indexee, indexers); - }, - IfExpression { discriminator, body } => { - v.if_expression(deref_optional_box(discriminator), body); - if let Some(d) = discriminator.as_ref() { expression(v, d) } - if_expression_body(v, body); - }, - WhileExpression { condition, body } => v.while_expression(deref_optional_box(condition), body), - ForExpression { enumerators, body } => v.for_expression(enumerators, body), - Lambda { params , type_anno, body } => { - v.lambda(params, type_anno.as_ref(), body); - lambda(v, params, type_anno.as_ref(), body); - }, - ListLiteral(exprs) => { - for expr in exprs { - v.expression(expr); - expression(v, expr); - } - }, - } -} - -fn if_expression_body(v: &mut V, body: &IfExpressionBody) { - use IfExpressionBody::*; - match body { - SimpleConditional { then_case, else_case } => { - walk_block(v, then_case); - if let Some(block) = else_case.as_ref() { walk_block(v, block) } - }, - SimplePatternMatch { pattern, then_case, else_case } => { - v.pattern(pattern); - walk_pattern(v, pattern); - walk_block(v, then_case); - if let Some(block) = else_case.as_ref() { walk_block(v, block) } - }, - CondList(arms) => { - for arm in arms { - v.condition_arm(arm); - condition_arm(v, arm); - } - } - } -} - -fn condition_arm(v: &mut V, arm: &ConditionArm) { - use Condition::*; - v.condition_arm(arm); - match arm.condition { - Pattern(ref pat) => { - v.pattern(pat); - walk_pattern(v, pat); - }, - TruncatedOp(ref _binop, ref expr) => { - v.expression(expr); - expression(v, expr); - }, - Expression(ref expr) => { - v.expression(expr); - expression(v, expr); - }, - _ => () - } - if let Some(guard) = arm.guard.as_ref() { - v.expression(guard); - expression(v, guard); - }; - v.block(&arm.body); - walk_block(v, &arm.body); -} - -fn walk_pattern(v: &mut V, pat: &Pattern) { - use Pattern::*; - match pat { - TuplePattern(patterns) => { - for pat in patterns { - v.pattern(pat); - walk_pattern(v, pat); - } - }, - TupleStruct(qualified_name, patterns) => { - v.qualified_name(qualified_name); - for pat in patterns { - v.pattern(pat); - walk_pattern(v, pat); - } - }, - Record(qualified_name, name_and_patterns) => { - v.qualified_name(qualified_name); - for (_, pat) in name_and_patterns { - v.pattern(pat); - walk_pattern(v, pat); - } - }, - VarOrName(qualified_name) => { - v.qualified_name(qualified_name); - }, - _ => () - } -} diff --git a/schala-lang/language/src/symbol_table/resolver.rs b/schala-lang/language/src/symbol_table/resolver.rs index 00111e0..ae27f35 100644 --- a/schala-lang/language/src/symbol_table/resolver.rs +++ b/schala-lang/language/src/symbol_table/resolver.rs @@ -46,12 +46,11 @@ impl<'a> Resolver<'a> { } } - // This might be a variable or a pattern, depending on whether this symbol - // already exists in the table. - fn qualified_name_in_pattern(&mut self, qualified_name: &QualifiedName) { - let maybe_sym = self.symbol_table.id_to_symbol.get(&qualified_name.id).cloned(); - if let Some(symbol) = maybe_sym { - self.symbol_table.id_to_symbol.insert(qualified_name.id.clone(), symbol); + fn qualified_name(&mut self, name: &QualifiedName) { + let fqsn = self.lookup_name_in_scope(name); + let symbol = self.symbol_table.fqsn_to_symbol.get(&fqsn); + if let Some(symbol) = symbol { + self.symbol_table.id_to_symbol.insert(name.id.clone(), symbol.clone()); } } } @@ -100,35 +99,28 @@ impl<'a> ASTVisitor for Resolver<'a> { }; } - fn qualified_name(&mut self, qualified_name: &QualifiedName) { - let fqsn = self.lookup_name_in_scope(qualified_name); - let symbol = self.symbol_table.fqsn_to_symbol.get(&fqsn); - if let Some(symbol) = symbol { - self.symbol_table.id_to_symbol.insert(qualified_name.id.clone(), symbol.clone()); - } - } - - fn named_struct( - &mut self, - qualified_name: &QualifiedName, - _fields: &[(Rc, Expression)], - ) { - let fqsn = self.lookup_name_in_scope(qualified_name); - - let symbol = self.symbol_table.fqsn_to_symbol.get(&fqsn); - if let Some(symbol) = symbol { - self.symbol_table.id_to_symbol.insert(qualified_name.id.clone(), symbol.clone()); + fn expression(&mut self, expression: &Expression) { + use ExpressionKind::*; + match &expression.kind { + Value(name) => { + self.qualified_name(name); + }, + NamedStruct { name, fields: _ } => { + self.qualified_name(name); + }, + _ => (), } } fn pattern(&mut self, pat: &Pattern) { use Pattern::*; + match pat { //TODO I think not handling TuplePattern is an oversight TuplePattern(_) => (), Literal(_) | Ignored => (), TupleStruct(name, _) | Record(name, _) | VarOrName(name) => { - self.qualified_name_in_pattern(name) + self.qualified_name(name); } }; }