use std::{collections::HashMap, rc::Rc, str::FromStr}; use crate::{ ast, builtin::Builtin, symbol_table::{DefId, SymbolSpec, SymbolTable}, type_inference::{TypeContext, TypeId}, }; mod test; mod types; pub use types::*; pub fn reduce(ast: &ast::AST, symbol_table: &SymbolTable, type_context: &TypeContext) -> ReducedIR { let reducer = Reducer::new(symbol_table, type_context); reducer.reduce(ast) } struct Reducer<'a, 'b> { symbol_table: &'a SymbolTable, functions: HashMap, type_context: &'b TypeContext, } impl<'a, 'b> Reducer<'a, 'b> { fn new(symbol_table: &'a SymbolTable, type_context: &'b TypeContext) -> Self { Self { symbol_table, functions: HashMap::new(), type_context } } fn reduce(mut self, ast: &ast::AST) -> ReducedIR { // First reduce all functions // TODO once this works, maybe rewrite it using the Visitor for statement in ast.statements.statements.iter() { self.top_level_statement(statement); } // Then compute the entrypoint statements (which may reference previously-computed // functions by ID) let mut entrypoint = vec![]; for statement in ast.statements.statements.iter() { let ast::Statement { id: item_id, kind, .. } = statement; match &kind { ast::StatementKind::Expression(expr) => { entrypoint.push(Statement::Expression(self.expression(expr))); } ast::StatementKind::Declaration(ast::Declaration::Binding { name: _, constant, expr, .. }) => { let symbol = self.symbol_table.lookup_symbol(item_id).unwrap(); let def_id = symbol.def_id().unwrap(); entrypoint.push(Statement::Binding { id: def_id, constant: *constant, expr: self.expression(expr), }); } _ => (), } } ReducedIR { functions: self.functions, entrypoint } } fn top_level_statement(&mut self, statement: &ast::Statement) { let ast::Statement { id: item_id, kind, .. } = statement; match kind { ast::StatementKind::Expression(_expr) => { //TODO expressions can in principle contain definitions, but I won't worry //about it now } ast::StatementKind::Declaration(decl) => if let ast::Declaration::FuncDecl(_, statements) = decl { self.insert_function_definition(item_id, statements); }, // Imports should have already been processed by the symbol table and are irrelevant // for this representation. ast::StatementKind::Import(..) => (), ast::StatementKind::Flow(..) => { //TODO this should be an error } } } fn function_internal_statement(&mut self, statement: &ast::Statement) -> Option { let ast::Statement { id: item_id, kind, .. } = statement; match kind { ast::StatementKind::Expression(expr) => Some(Statement::Expression(self.expression(expr))), ast::StatementKind::Declaration(decl) => match decl { ast::Declaration::FuncDecl(_, statements) => { self.insert_function_definition(item_id, statements); None } ast::Declaration::Binding { constant, expr, .. } => { let symbol = self.symbol_table.lookup_symbol(item_id).unwrap(); let def_id = symbol.def_id().unwrap(); Some(Statement::Binding { id: def_id, constant: *constant, expr: self.expression(expr) }) } _ => None, }, ast::StatementKind::Import(_) => None, ast::StatementKind::Flow(ast::FlowControl::Return(expr)) => if let Some(expr) = expr { Some(Statement::Return(self.expression(expr))) } else { Some(Statement::Return(Expression::unit())) }, ast::StatementKind::Flow(ast::FlowControl::Break) => Some(Statement::Break), ast::StatementKind::Flow(ast::FlowControl::Continue) => Some(Statement::Continue), } } fn insert_function_definition(&mut self, item_id: &ast::ItemId, statements: &ast::Block) { let symbol = self.symbol_table.lookup_symbol(item_id).unwrap(); let def_id = symbol.def_id().unwrap(); let function_def = FunctionDefinition { body: self.function_internal_block(statements) }; self.functions.insert(def_id, function_def); } fn expression(&mut self, expr: &ast::Expression) -> Expression { use crate::ast::ExpressionKind::*; match &expr.kind { NatLiteral(n) => Expression::Literal(Literal::Nat(*n)), FloatLiteral(f) => Expression::Literal(Literal::Float(*f)), StringLiteral(s) => Expression::Literal(Literal::StringLit(s.clone())), BoolLiteral(b) => Expression::Literal(Literal::Bool(*b)), BinExp(binop, lhs, rhs) => self.binop(binop, lhs, rhs), PrefixExp(op, arg) => self.prefix(op, arg), Value(qualified_name) => self.value(qualified_name), Call { f, arguments } => Expression::Call { f: Box::new(self.expression(f)), args: arguments.iter().map(|arg| self.invocation_argument(arg)).collect(), }, TupleLiteral(exprs) => Expression::Tuple(exprs.iter().map(|e| self.expression(e)).collect()), IfExpression { discriminator, body } => self.reduce_if_expression(discriminator.as_ref().map(|x| x.as_ref()), body), Lambda { params, body, .. } => Expression::Callable(Callable::Lambda { arity: params.len() as u8, body: self.function_internal_block(body), }), NamedStruct { name, fields } => { let symbol = match self.symbol_table.lookup_symbol(&name.id) { Some(symbol) => symbol, None => return Expression::ReductionError(format!("No symbol found for {:?}", name)), }; let (tag, type_id) = match symbol.spec() { SymbolSpec::RecordConstructor { tag, type_id } => (tag, type_id), e => return Expression::ReductionError(format!("Bad symbol for NamedStruct: {:?}", e)), }; let field_order = compute_field_orderings(self.type_context, &type_id, tag).unwrap(); let mut field_map = HashMap::new(); for (name, expr) in fields.iter() { field_map.insert(name.as_ref(), expr); } let mut ordered_args = vec![]; for field in field_order.iter() { let expr = match field_map.get(&field) { Some(expr) => expr, None => return Expression::ReductionError(format!( "Field {} not specified for record {}", field, name )), }; ordered_args.push(self.expression(expr)); } let constructor = Expression::Callable(Callable::RecordConstructor { type_id, tag, field_order }); Expression::Call { f: Box::new(constructor), args: ordered_args } } Index { indexee, indexers } => self.reduce_index(indexee.as_ref(), indexers.as_slice()), WhileExpression { condition, body } => { let cond = Box::new(if let Some(condition) = condition { self.expression(condition) } else { Expression::Literal(Literal::Bool(true)) }); let statements = self.function_internal_block(body); Expression::Loop { cond, statements } } ForExpression { .. } => Expression::ReductionError("For expr not implemented".to_string()), ListLiteral(items) => Expression::List(items.iter().map(|item| self.expression(item)).collect()), Access { name, expr } => Expression::Access { name: name.as_ref().to_string(), expr: Box::new(self.expression(expr)) }, } } //TODO figure out the semantics of multiple indexers - for now, just ignore them fn reduce_index(&mut self, indexee: &ast::Expression, indexers: &[ast::Expression]) -> Expression { if indexers.len() != 1 { return Expression::ReductionError("Invalid index expression".to_string()); } let indexee = self.expression(indexee); let indexer = self.expression(&indexers[0]); Expression::Index { indexee: Box::new(indexee), indexer: Box::new(indexer) } } fn reduce_if_expression( &mut self, discriminator: Option<&ast::Expression>, body: &ast::IfExpressionBody, ) -> Expression { use ast::IfExpressionBody::*; let cond = Box::new(match discriminator { Some(expr) => self.expression(expr), None => return Expression::ReductionError("blank cond if-expr not supported".to_string()), }); match body { SimpleConditional { then_case, else_case } => { let then_clause = self.function_internal_block(then_case); let else_clause = match else_case.as_ref() { None => vec![], Some(stmts) => self.function_internal_block(stmts), }; Expression::Conditional { cond, then_clause, else_clause } } SimplePatternMatch { pattern, then_case, else_case } => { let alternatives = vec![ Alternative { pattern: match pattern.reduce(self.symbol_table) { Ok(p) => p, Err(e) => return Expression::ReductionError(format!("Bad pattern: {:?}", e)), }, item: self.function_internal_block(then_case), }, Alternative { pattern: Pattern::Ignored, item: match else_case.as_ref() { Some(else_case) => self.function_internal_block(else_case), None => vec![], }, }, ]; Expression::CaseMatch { cond, alternatives } } CondList(ref condition_arms) => { let mut alternatives = vec![]; for arm in condition_arms { match arm.condition { ast::Condition::Pattern(ref pat) => { let alt = Alternative { pattern: match pat.reduce(self.symbol_table) { Ok(p) => p, Err(e) => return Expression::ReductionError(format!("Bad pattern: {:?}", e)), }, item: self.function_internal_block(&arm.body), }; alternatives.push(alt); } ast::Condition::TruncatedOp(_, _) => return Expression::ReductionError("case-expression-trunc-op".to_string()), ast::Condition::Else => return Expression::ReductionError("case-expression-else".to_string()), } } Expression::CaseMatch { cond, alternatives } } } } fn invocation_argument(&mut self, invoc: &ast::InvocationArgument) -> Expression { use crate::ast::InvocationArgument::*; match invoc { Positional(ex) => self.expression(ex), Keyword { .. } => Expression::ReductionError("Keyword arguments not supported".to_string()), Ignored => Expression::ReductionError("Ignored arguments not supported".to_string()), } } fn function_internal_block(&mut self, block: &ast::Block) -> Vec { block.statements.iter().filter_map(|stmt| self.function_internal_statement(stmt)).collect() } fn prefix(&mut self, prefix: &ast::PrefixOp, arg: &ast::Expression) -> Expression { let builtin: Option = TryFrom::try_from(prefix).ok(); match builtin { Some(op) => Expression::Call { f: Box::new(Expression::Callable(Callable::Builtin(op))), args: vec![self.expression(arg)], }, None => { //TODO need this for custom prefix ops Expression::ReductionError("User-defined prefix ops not supported".to_string()) } } } fn binop(&mut self, binop: &ast::BinOp, lhs: &ast::Expression, rhs: &ast::Expression) -> Expression { use Expression::ReductionError; let operation = Builtin::from_str(binop.sigil()).ok(); match operation { Some(Builtin::Assignment) => { let lval = match &lhs.kind { ast::ExpressionKind::Value(qualified_name) => { if let Some(symbol) = self.symbol_table.lookup_symbol(&qualified_name.id) { symbol.def_id().unwrap() } else { return ReductionError(format!("Couldn't look up name: {:?}", qualified_name)); } } _ => return ReductionError("Trying to assign to a non-name".to_string()), }; Expression::Assign { lval, rval: Box::new(self.expression(rhs)) } } Some(op) => Expression::Call { f: Box::new(Expression::Callable(Callable::Builtin(op))), args: vec![self.expression(lhs), self.expression(rhs)], }, //TODO handle a user-defined operation None => ReductionError("User-defined operations not supported".to_string()), } } fn value(&mut self, qualified_name: &ast::QualifiedName) -> Expression { use SymbolSpec::*; let symbol = match self.symbol_table.lookup_symbol(&qualified_name.id) { Some(s) => s, None => return Expression::ReductionError(format!("No symbol found for name: {:?}", qualified_name)), }; let def_id = symbol.def_id(); match symbol.spec() { Builtin(b) => Expression::Callable(Callable::Builtin(b)), Func => Expression::Lookup(Lookup::Function(def_id.unwrap())), GlobalBinding => Expression::Lookup(Lookup::GlobalVar(def_id.unwrap())), LocalVariable => Expression::Lookup(Lookup::LocalVar(def_id.unwrap())), FunctionParam(n) => Expression::Lookup(Lookup::Param(n)), DataConstructor { tag, type_id } => Expression::Callable(Callable::DataConstructor { type_id, tag }), RecordConstructor { .. } => Expression::ReductionError(format!( "The symbol for value {:?} is unexpectdly a RecordConstructor", qualified_name )), } } } impl ast::Pattern { fn reduce(&self, symbol_table: &SymbolTable) -> Result { Ok(match self { ast::Pattern::Ignored => Pattern::Ignored, ast::Pattern::TuplePattern(subpatterns) => { let items: Result, PatternError> = subpatterns.iter().map(|pat| pat.reduce(symbol_table)).into_iter().collect(); let items = items?; Pattern::Tuple { tag: None, subpatterns: items } } ast::Pattern::Literal(lit) => Pattern::Literal(match lit { ast::PatternLiteral::NumPattern { neg, num } => match (neg, num) { (false, ast::ExpressionKind::NatLiteral(n)) => Literal::Nat(*n), (false, ast::ExpressionKind::FloatLiteral(f)) => Literal::Float(*f), (true, ast::ExpressionKind::NatLiteral(n)) => Literal::Int(-(*n as i64)), (true, ast::ExpressionKind::FloatLiteral(f)) => Literal::Float(-f), (_, e) => return Err(format!("Internal error, unexpected pattern literal: {:?}", e).into()), }, ast::PatternLiteral::StringPattern(s) => Literal::StringLit(s.clone()), ast::PatternLiteral::BoolPattern(b) => Literal::Bool(*b), }), ast::Pattern::TupleStruct(name, subpatterns) => { let symbol = symbol_table.lookup_symbol(&name.id).unwrap(); if let SymbolSpec::DataConstructor { tag, type_id: _ } = symbol.spec() { let items: Result, PatternError> = subpatterns.iter().map(|pat| pat.reduce(symbol_table)).into_iter().collect(); let items = items?; Pattern::Tuple { tag: Some(tag), subpatterns: items } } else { return Err( "Internal error, trying to match something that's not a DataConstructor".into() ); } } ast::Pattern::VarOrName(name) => { let symbol = symbol_table.lookup_symbol(&name.id).unwrap(); match symbol.spec() { SymbolSpec::DataConstructor { tag, type_id: _ } => Pattern::Tuple { tag: Some(tag), subpatterns: vec![] }, SymbolSpec::LocalVariable => { let def_id = symbol.def_id().unwrap(); Pattern::Binding(def_id) } spec => return Err(format!("Unexpected VarOrName symbol: {:?}", spec).into()), } } ast::Pattern::Record(name, specified_members) => { let symbol = symbol_table.lookup_symbol(&name.id).unwrap(); if let SymbolSpec::RecordConstructor { tag, type_id: _ } = symbol.spec() { //TODO do this computation from the type_id /* if specified_members.iter().any(|(member, _)| !members.contains_key(member)) { return Err(format!("Unknown key in record pattern").into()); } */ let subpatterns: Result, Pattern)>, PatternError> = specified_members .iter() .map(|(name, pat)| { pat.reduce(symbol_table).map(|reduced_pat| (name.clone(), reduced_pat)) }) .into_iter() .collect(); let subpatterns = subpatterns?; Pattern::Record { tag, subpatterns } } else { return Err(format!("Unexpected Record pattern symbol: {:?}", symbol.spec()).into()); } } }) } } /// Given the type context and a variant, compute what order the fields on it were stored. /// This needs to be public until type-checking is fully implemented because the type information /// is only available at runtime. pub fn compute_field_orderings( type_context: &TypeContext, type_id: &TypeId, tag: u32, ) -> Option> { // Eventually, the ReducedIR should decide what field ordering is optimal. // For now, just do it alphabetically. let record_members = type_context.lookup_record_members(type_id, tag)?; let mut field_order: Vec = record_members.iter().map(|(field, _type_id)| field).cloned().collect(); field_order.sort_unstable(); Some(field_order) }