// schala/schala-lang/language/src/ast_visitor.rs
//
// (page metadata: 158 lines, 6.0 KiB, Rust)

use std::rc::Rc;
use builtin::{PrefixOp, BinOp};
use ast::*;
/// Walks the top-level statements of `ast` in order and feeds them to `visitor`.
///
/// For every statement the kind-specific hooks run first (through
/// `dispatch_expression` / `dispatch_declaration`), then the general
/// `expression` or `declaration` hook, and finally `statement`. After the last
/// statement, `visitor.ast(ast)` is called exactly once.
pub fn dispatch<V: ASTVisitor>(visitor: &mut V, ast: &AST) {
    for stmt in ast.0.iter() {
        match *stmt {
            Statement::ExpressionStatement(ref expression) => {
                dispatch_expression(visitor, expression);
                visitor.expression(expression);
            }
            Statement::Declaration(ref declaration) => {
                dispatch_declaration(visitor, declaration);
                visitor.declaration(declaration);
            }
        }
        // Generic per-statement hook runs after the specific ones.
        visitor.statement(stmt);
    }
    visitor.ast(ast);
}
/// Invokes the `ASTVisitor` hook matching the concrete kind of `expression`,
/// then `anno_expr` with the expression's optional type annotation, then
/// `expr_kind` with the raw `ExpressionType`.
///
/// Only this one node is visited: operands, arguments and bodies are passed to
/// the hook as-is and are not walked any further here.
fn dispatch_expression<V: ASTVisitor>(visitor: &mut V, expression: &Expression) {
    let Expression(ref kind, ref annotation) = *expression;
    match *kind {
        ExpressionType::NatLiteral(ref value) => visitor.nat_literal(value),
        ExpressionType::FloatLiteral(ref value) => visitor.float_literal(value),
        ExpressionType::StringLiteral(ref value) => visitor.string_literal(value),
        ExpressionType::BoolLiteral(ref value) => visitor.bool_literal(value),
        ExpressionType::BinExp(ref op, ref left, ref right) => visitor.binop(op, left, right),
        ExpressionType::PrefixExp(ref op, ref operand) => visitor.prefixop(op, operand),
        ExpressionType::TupleLiteral(ref members) => visitor.tuple_literal(members),
        ExpressionType::Value(ref identifier) => visitor.value(identifier),
        ExpressionType::NamedStruct { ref name, ref fields } => visitor.named_struct(name, fields),
        ExpressionType::Call { ref f, ref arguments } => visitor.call(f, arguments),
        ExpressionType::Index { ref indexee, ref indexers } => visitor.index(indexee, indexers),
        ExpressionType::IfExpression { ref discriminator, ref body } => {
            visitor.if_expression(discriminator, body)
        }
        ExpressionType::WhileExpression { ref condition, ref body } => {
            visitor.while_expresssion(condition, body)
        }
        ExpressionType::ForExpression { ref enumerators, ref body } => {
            visitor.for_expression(enumerators, body)
        }
        ExpressionType::Lambda { ref params, ref type_anno, ref body } => {
            visitor.lambda_expression(params, type_anno, body)
        }
        ExpressionType::ListLiteral(ref items) => visitor.list_literal(items),
    }
    // Generic hooks run after the kind-specific one.
    visitor.anno_expr(annotation);
    visitor.expr_kind(kind);
}
/// Invokes the single `ASTVisitor` hook that corresponds to the variant of
/// `declaration`, passing its fields by reference.
fn dispatch_declaration<V: ASTVisitor>(visitor: &mut V, declaration: &Declaration) {
    match *declaration {
        Declaration::FuncSig(ref signature) => visitor.func_signature(signature),
        Declaration::FuncDecl(ref signature, ref statements) => {
            visitor.func_declaration(signature, statements)
        }
        Declaration::TypeDecl { name: ref type_name, body: ref type_body, mutable: ref is_mutable } => {
            visitor.type_declaration(type_name, type_body, is_mutable)
        }
        Declaration::TypeAlias(ref alias, ref original) => visitor.type_alias(alias, original),
        Declaration::Binding { name: ref binding_name, constant: ref is_constant, expr: ref value } => {
            visitor.binding(binding_name, is_constant, value)
        }
        Declaration::Impl { type_name: ref target, interface_name: ref interface, block: ref members } => {
            visitor.impl_block(target, interface, members)
        }
        Declaration::Interface { name: ref interface_name, signatures: ref sigs } => {
            visitor.interface(interface_name, sigs)
        }
    }
}
/// Hooks invoked while walking a Schala `AST` with `dispatch`.
///
/// Every method has an empty default body, so an implementor overrides only
/// the nodes it cares about. `dispatch` visits top-level statements in order.
/// For each one, the kind-specific hook (e.g. `binop`, `binding`) runs first,
/// then `expression` or `declaration`, then `statement`; `ast` runs last.
///
/// Hooks receive sub-nodes (operands, arguments, bodies) unvisited. The
/// dispatch functions do not recurse into them, so a visitor that needs a
/// deeper walk must do it itself.
pub trait ASTVisitor {
    /// Called once, after every top-level statement has been visited.
    fn ast(&mut self, _ast: &AST) { }
    /// Called after all hooks for one top-level statement have run.
    fn statement(&mut self, _stmt: &Statement) { }
    /// Called for a top-level expression statement, after its kind-specific hook.
    fn expression(&mut self, _expr: &Expression) { }
    /// Called with the expression's optional type annotation.
    // The parameter is named: anonymous trait-method parameters are deprecated
    // and rejected by the 2018+ editions.
    fn anno_expr(&mut self, _anno: &Option<TypeIdentifier>) { }
    /// Called with the raw expression kind, after `anno_expr`.
    fn expr_kind(&mut self, _expr: &ExpressionType) { }

    // --- literals ---
    fn nat_literal(&mut self, _n: &u64) { }
    fn float_literal(&mut self, _f: &f64) { }
    fn string_literal(&mut self, _s: &Rc<String>) { }
    fn bool_literal(&mut self, _bool: &bool) { }

    // --- operators and compound expressions ---
    fn binop(&mut self, _binop: &BinOp, _lhs: &Expression, _rhs: &Expression) { }
    fn prefixop(&mut self, _prefix: &PrefixOp, _expr: &Expression) { }
    fn tuple_literal(&mut self, _v: &Vec<Expression>) { }
    fn value(&mut self, _v: &Rc<String>) { }
    fn named_struct(&mut self, _name: &Rc<String>, _values: &Vec<(Rc<String>, Expression)>) { }
    fn call(&mut self, _f: &Box<Expression>, _arguments: &Vec<Expression>) { }
    fn index(&mut self, _indexee: &Box<Expression>, _indexers: &Vec<Expression>) { }
    fn if_expression(&mut self, _discriminator: &Discriminator, _body: &IfExpressionBody) { }
    // NOTE: "expresssion" is misspelled, but renaming the method would break
    // implementors that override it, so the name is kept.
    fn while_expresssion(&mut self, _condition: &Option<Box<Expression>>, _body: &Block) { }
    fn for_expression(&mut self, _enumerators: &Vec<Enumerator>, _body: &Box<ForBody>) { }
    fn lambda_expression(&mut self, _params: &Vec<FormalParam>, _type_anno: &Option<TypeIdentifier>, _body: &Block) { }
    fn list_literal(&mut self, _items: &Vec<Expression>) { }

    // --- declarations ---
    /// Called for a declaration statement, after its kind-specific hook.
    fn declaration(&mut self, _decl: &Declaration) { }
    fn func_signature(&mut self, _sig: &Signature) { }
    fn func_declaration(&mut self, _sig: &Signature, _block: &Vec<Statement>) { }
    fn type_declaration(&mut self, _name: &TypeSingletonName, _body: &TypeBody, _mutable: &bool) { }
    fn type_alias(&mut self, _alias: &Rc<String>, _name: &Rc<String>) { }
    fn binding(&mut self, _name: &Rc<String>, _constant: &bool, _expr: &Expression) { }
    fn impl_block(&mut self, _type_name: &TypeIdentifier, _interface_name: &Option<InterfaceName>, _block: &Vec<Declaration>) { }
    fn interface(&mut self, _name: &Rc<String>, _signatures: &Vec<Signature>) { }
}
/// Buffer-backed `ASTVisitor` that writes a textual rendering of the
/// visited AST into `s`.
///
/// `Debug` is derived so the printer's state can be inspected in test
/// failures and `dbg!` output.
#[derive(Clone, Debug)]
struct SchalaPrinter {
    /// Output accumulated so far.
    s: String,
}
impl SchalaPrinter {
    /// Creates a printer whose buffer starts with the `"Schala source code:\n"` header.
    fn new() -> SchalaPrinter {
        SchalaPrinter {
            // A plain literal needs no formatting machinery (clippy::useless_format).
            s: String::from("Schala source code:\n"),
        }
    }

    /// Consumes the printer and returns everything written to it.
    fn done(self) -> String {
        self.s
    }
}
/// Minimal printer hooks.
///
/// Each top-level statement ends with a newline. A top-level expression
/// statement is written as the placeholder `some_expr`. A binding is written
/// as `let NAME = some_expr` (or `let mut NAME = some_expr` when not constant).
/// All other hooks use the trait's empty defaults.
impl ASTVisitor for SchalaPrinter {
    fn statement(&mut self, _: &Statement) {
        self.s.push('\n');
    }

    fn expression(&mut self, _: &Expression) {
        // Expression rendering is not implemented; emit a placeholder.
        self.s += "some_expr";
    }

    fn binding(&mut self, name: &Rc<String>, constant: &bool, _expr: &Expression) {
        let mutability = if *constant { "" } else { " mut" };
        let rendered = format!("let{} {} = {}", mutability, name, "some_expr");
        self.s.push_str(&rendered);
    }
}
#[cfg(test)]
mod visitor_tests {
    use ::tokenizing::tokenize;
    use ::parsing::ParseResult;
    use ::ast::AST;
    use super::*;

    /// Tokenizes and parses `input` into an AST.
    fn parse(input: &str) -> ParseResult<AST> {
        let tokens = tokenize(input);
        let mut parser = ::parsing::Parser::new(tokens);
        parser.parse()
    }

    // NOTE(review): `SchalaPrinter` currently writes the placeholder `some_expr`
    // for every expression (see its `expression` and `binding` hooks). The
    // expected text below, with `1 + 2` and `foo()`, is therefore the target
    // rendering and cannot match until expression printing is implemented.
    #[test]
    fn test() {
        let ast = parse("let a = 1 + 2; let b = 2 + 44;foo()").unwrap();
        let mut pp = SchalaPrinter::new();
        dispatch(&mut pp, &ast);
        let result = pp.done();
        assert_eq!(result, r#"Schala source code:
let a = 1 + 2
let b = 2 + 44
foo()
"#);
    }
}