use std::rc::Rc; use super::{EvalResult, Memory, MemoryValue, Primitive, State}; use crate::{ builtin::Builtin, reduced_ir::{ Alternative, Callable, Expression, FunctionDefinition, Literal, Lookup, Pattern, ReducedIR, Statement, }, type_inference::TypeContext, util::ScopeStack, }; #[derive(Debug)] enum StatementOutput { Primitive(Primitive), Nothing, } #[derive(Debug, Clone, Copy)] enum LoopControlFlow { Break, Continue, } pub struct Evaluator<'a, 'b> { type_context: &'b TypeContext, state: &'b mut State<'a>, early_returning: bool, loop_control: Option, } impl<'a, 'b> Evaluator<'a, 'b> { pub(crate) fn new(state: &'b mut State<'a>, type_context: &'b TypeContext) -> Self { Self { state, type_context, early_returning: false, loop_control: None } } pub fn evaluate(&mut self, reduced: ReducedIR, repl: bool) -> Vec> { let mut acc = vec![]; for (def_id, function) in reduced.functions.into_iter() { let mem = (&def_id).into(); self.state.environments.insert(mem, MemoryValue::Function(function)); } for statement in reduced.entrypoint.into_iter() { match self.statement(statement) { Ok(StatementOutput::Primitive(output)) if repl => acc.push(Ok(output.to_repl(self.type_context))), Ok(_) => (), Err(error) => { acc.push(Err(error.msg)); return acc; } } } acc } fn block(&mut self, statements: Vec) -> EvalResult { let mut retval = None; for stmt in statements.into_iter() { match self.statement(stmt)? { StatementOutput::Nothing => (), StatementOutput::Primitive(prim) => { retval = Some(prim); } }; if self.early_returning { break; } if self.loop_control.is_some() { break; } } Ok(if let Some(ret) = retval { ret } else { self.expression(Expression::unit())? }) } fn statement(&mut self, stmt: Statement) -> EvalResult { match stmt { Statement::Binding { ref id, expr, constant: _ } => { let evaluated = self.expression(expr)?; self.state.environments.insert(id.into(), evaluated.into()); Ok(StatementOutput::Nothing) } Statement::Expression(expr) => { let evaluated = self.expression(expr)?; Ok(StatementOutput::Primitive(evaluated)) } Statement::Return(expr) => { let evaluated = self.expression(expr)?; self.early_returning = true; Ok(StatementOutput::Primitive(evaluated)) } Statement::Break => { self.loop_control = Some(LoopControlFlow::Break); Ok(StatementOutput::Nothing) } Statement::Continue => { self.loop_control = Some(LoopControlFlow::Continue); Ok(StatementOutput::Nothing) } } } fn expression(&mut self, expression: Expression) -> EvalResult { Ok(match expression { Expression::Literal(lit) => Primitive::Literal(lit), Expression::Tuple(items) => Primitive::Tuple( items .into_iter() .map(|expr| self.expression(expr)) .collect::>>()?, ), Expression::List(items) => Primitive::List( items .into_iter() .map(|expr| self.expression(expr)) .collect::>>()?, ), Expression::Lookup(kind) => match kind { Lookup::Function(ref id) => { let mem = id.into(); match self.state.environments.lookup(&mem) { // This just checks that the function exists in "memory" by ID, we don't // actually retrieve it until `apply_function()` Some(MemoryValue::Function(_)) => Primitive::Callable(Callable::UserDefined(*id)), x => return Err(format!("Function not found for id: {} : {:?}", id, x).into()), } } Lookup::Param(n) => { let mem = n.into(); match self.state.environments.lookup(&mem) { Some(MemoryValue::Primitive(prim)) => prim.clone(), e => return Err(format!("Param lookup error, got {:?}", e).into()), } } Lookup::LocalVar(ref id) | Lookup::GlobalVar(ref id) => { let mem = id.into(); match self.state.environments.lookup(&mem) { Some(MemoryValue::Primitive(expr)) => expr.clone(), _ => return Err( format!("Nothing found for local/gloval variable lookup {}", id).into() ), } } }, Expression::Assign { ref lval, box rval } => { let mem = lval.into(); let evaluated = self.expression(rval)?; println!("Inserting {:?} into {:?}", evaluated, mem); self.state.environments.insert(mem, MemoryValue::Primitive(evaluated)); Primitive::unit() } Expression::Call { box f, args } => self.call_expression(f, args)?, Expression::Callable(Callable::DataConstructor { type_id, tag }) => { let arity = self.type_context.lookup_variant_arity(&type_id, tag).unwrap(); if arity == 0 { Primitive::Object { type_id, tag, items: vec![], ordered_fields: None } } else { Primitive::Callable(Callable::DataConstructor { type_id, tag }) } } Expression::Callable(func) => Primitive::Callable(func), Expression::Conditional { box cond, then_clause, else_clause } => { let cond = self.expression(cond)?; match cond { Primitive::Literal(Literal::Bool(true)) => self.block(then_clause)?, Primitive::Literal(Literal::Bool(false)) => self.block(else_clause)?, v => return Err(format!("Non-boolean value {:?} in if-statement", v).into()), } } Expression::CaseMatch { box cond, alternatives } => self.case_match_expression(cond, alternatives)?, Expression::Index { box indexee, box indexer } => { let indexee = self.expression(indexee)?; let indexer = self.expression(indexer)?; match (indexee, indexer) { (Primitive::List(items), Primitive::Literal(Literal::Nat(n))) => match items.get(n as usize) { Some(item) => item.clone(), None => return Err(format!("Invalid index {} for this value", n).into()), }, _ => return Err("Invalid index type".to_string().into()), } } Expression::Loop { box cond, statements } => self.loop_expression(cond, statements)?, Expression::ReductionError(e) => return Err(e.into()), Expression::Access { name, box expr } => { let expr = self.expression(expr)?; match expr { Primitive::Object { items, ordered_fields: Some(ordered_fields), .. } => { let idx = match ordered_fields.iter().position(|s| s == &name) { Some(idx) => idx, None => return Err(format!("Field `{}` not found", name).into()), }; let item = match items.get(idx) { Some(item) => item, None => return Err(format!("Field lookup `{}` failed", name).into()), }; item.clone() } e => return Err( format!("Trying to do a field lookup on a non-object value: {:?}", e).into() ), } } }) } fn loop_expression(&mut self, cond: Expression, statements: Vec) -> EvalResult { let existing = self.loop_control; let output = self.loop_expression_inner(cond, statements); self.loop_control = existing; output } fn loop_expression_inner( &mut self, cond: Expression, statements: Vec, ) -> EvalResult { loop { let cond = self.expression(cond.clone())?; println!("COND: {:?}", cond); match cond { Primitive::Literal(Literal::Bool(true)) => (), Primitive::Literal(Literal::Bool(false)) => break, e => return Err(format!("Loop condition evaluates to non-boolean: {:?}", e).into()), }; //TODO eventually loops shoudl be able to return something let _output = self.block(statements.clone())?; match self.loop_control { None => (), Some(LoopControlFlow::Continue) => { self.loop_control = None; } Some(LoopControlFlow::Break) => { break; } } } Ok(Primitive::unit()) } fn case_match_expression( &mut self, cond: Expression, alternatives: Vec, ) -> EvalResult { fn matches(scrut: &Primitive, pat: &Pattern, scope: &mut ScopeStack) -> bool { match pat { Pattern::Ignored => true, Pattern::Binding(ref def_id) => { let mem = def_id.into(); scope.insert(mem, MemoryValue::Primitive(scrut.clone())); //TODO make sure this doesn't cause problems with nesting true } Pattern::Literal(pat_literal) => if let Primitive::Literal(scrut_literal) = scrut { pat_literal == scrut_literal } else { false }, Pattern::Tuple { subpatterns, tag } => match tag { None => match scrut { Primitive::Tuple(items) if items.len() == subpatterns.len() => items .iter() .zip(subpatterns.iter()) .all(|(item, subpat)| matches(item, subpat, scope)), _ => false, //TODO should be a type error }, Some(pattern_tag) => match scrut { //TODO should test type_ids for runtime type checking, once those work Primitive::Object { tag, items, .. } if tag == pattern_tag && items.len() == subpatterns.len() => items .iter() .zip(subpatterns.iter()) .all(|(item, subpat)| matches(item, subpat, scope)), _ => false, }, }, Pattern::Record { tag: pattern_tag, subpatterns } => match scrut { //TODO several types of possible error here Primitive::Object { tag, items, ordered_fields: Some(ordered_fields), .. } if tag == pattern_tag => subpatterns.iter().all(|(field_name, subpat)| { let idx = ordered_fields .iter() .position(|field| field.as_str() == field_name.as_ref()) .unwrap(); let item = &items[idx]; matches(item, subpat, scope) }), _ => false, }, } } let cond = self.expression(cond)?; for alt in alternatives.into_iter() { let mut new_scope = self.state.environments.new_scope(None); if matches(&cond, &alt.pattern, &mut new_scope) { let mut new_state = State { environments: new_scope }; let mut evaluator = Evaluator::new(&mut new_state, self.type_context); let output = evaluator.block(alt.item); self.early_returning = evaluator.early_returning; return output; } } Err("No valid match in match expression".into()) } fn call_expression(&mut self, f: Expression, args: Vec) -> EvalResult { let func = match self.expression(f)? { Primitive::Callable(func) => func, other => return Err(format!("Trying to call non-function value: {:?}", other).into()), }; match func { Callable::Builtin(builtin) => self.apply_builtin(builtin, args), Callable::UserDefined(def_id) => { let mem = (&def_id).into(); match self.state.environments.lookup(&mem) { Some(MemoryValue::Function(FunctionDefinition { body })) => { let body = body.clone(); //TODO ideally this clone would not happen self.apply_function(body, args) } e => Err(format!("Error looking up function with id {}: {:?}", def_id, e).into()), } } Callable::Lambda { arity, body } => { if arity as usize != args.len() { return Err(format!( "Lambda expression requries {} arguments, only {} provided", arity, args.len() ) .into()); } self.apply_function(body, args) } Callable::DataConstructor { type_id, tag } => { let arity = self.type_context.lookup_variant_arity(&type_id, tag).unwrap(); if arity as usize != args.len() { return Err(format!( "Constructor expression requries {} arguments, only {} provided", arity, args.len() ) .into()); } let mut items: Vec = vec![]; for arg in args.into_iter() { items.push(self.expression(arg)?); } Ok(Primitive::Object { type_id, tag, items, ordered_fields: None }) } Callable::RecordConstructor { type_id, tag, field_order } => { //TODO maybe I'll want to do a runtime check of the evaluated fields /* let record_members = self.type_context.lookup_record_members(type_id, tag) .ok_or(format!("Runtime record lookup for: {} {} not found", type_id, tag).into())?; */ let mut items: Vec = vec![]; for arg in args.into_iter() { items.push(self.expression(arg)?); } Ok(Primitive::Object { type_id, tag, items, ordered_fields: Some(field_order) }) } } } fn apply_builtin(&mut self, builtin: Builtin, args: Vec) -> EvalResult { use Builtin::*; use Literal::*; use Primitive::Literal as Lit; let evaled_args: EvalResult> = args.into_iter().map(|arg| self.expression(arg)).collect(); let evaled_args = evaled_args?; Ok(match (builtin, evaled_args.as_slice()) { /* builtin functions */ (IOPrint, &[ref anything]) => { print!("{}", anything.to_repl(self.type_context)); Primitive::Tuple(vec![]) } (IOPrintLn, &[ref anything]) => { println!("{}", anything.to_repl(self.type_context)); Primitive::Tuple(vec![]) } (IOGetLine, &[]) => { let mut buf = String::new(); std::io::stdin().read_line(&mut buf).expect("Error readling line in 'getline'"); StringLit(Rc::new(buf.trim().to_string())).into() } /* Binops */ (binop, &[ref lhs, ref rhs]) => match (binop, lhs, rhs) { // TODO need a better way of handling these literals (Add, Lit(Nat(l)), Lit(Nat(r))) => Nat(l + r).into(), (Add, Lit(Int(l)), Lit(Int(r))) => Int(l + r).into(), (Add, Lit(Nat(l)), Lit(Int(r))) => Int((*l as i64) + (*r as i64)).into(), (Add, Lit(Int(l)), Lit(Nat(r))) => Int((*l as i64) + (*r as i64)).into(), (Concatenate, Lit(StringLit(ref s1)), Lit(StringLit(ref s2))) => StringLit(Rc::new(format!("{}{}", s1, s2))).into(), (Subtract, Lit(Nat(l)), Lit(Nat(r))) => Nat(l - r).into(), (Multiply, Lit(Nat(l)), Lit(Nat(r))) => Nat(l * r).into(), (Divide, Lit(Nat(l)), Lit(Nat(r))) => Float((*l as f64) / (*r as f64)).into(), (Quotient, Lit(Nat(l)), Lit(Nat(r))) => if *r == 0 { return Err("Divide-by-zero error".into()); } else { Nat(l / r).into() }, (Modulo, Lit(Nat(l)), Lit(Nat(r))) => Nat(l % r).into(), (Exponentiation, Lit(Nat(l)), Lit(Nat(r))) => Nat(l ^ r).into(), (BitwiseAnd, Lit(Nat(l)), Lit(Nat(r))) => Nat(l & r).into(), (BitwiseOr, Lit(Nat(l)), Lit(Nat(r))) => Nat(l | r).into(), /* comparisons */ (Equality, Lit(Nat(l)), Lit(Nat(r))) => Bool(l == r).into(), (Equality, Lit(Int(l)), Lit(Int(r))) => Bool(l == r).into(), (Equality, Lit(Float(l)), Lit(Float(r))) => Bool(l == r).into(), (Equality, Lit(Bool(l)), Lit(Bool(r))) => Bool(l == r).into(), (Equality, Lit(StringLit(ref l)), Lit(StringLit(ref r))) => Bool(l == r).into(), (NotEqual, Lit(Nat(l)), Lit(Nat(r))) => Bool(l != r).into(), (NotEqual, Lit(Int(l)), Lit(Int(r))) => Bool(l != r).into(), (NotEqual, Lit(Float(l)), Lit(Float(r))) => Bool(l != r).into(), (NotEqual, Lit(Bool(l)), Lit(Bool(r))) => Bool(l != r).into(), (NotEqual, Lit(StringLit(ref l)), Lit(StringLit(ref r))) => Bool(l != r).into(), (LessThan, Lit(Nat(l)), Lit(Nat(r))) => Bool(l < r).into(), (LessThan, Lit(Int(l)), Lit(Int(r))) => Bool(l < r).into(), (LessThan, Lit(Float(l)), Lit(Float(r))) => Bool(l < r).into(), (LessThanOrEqual, Lit(Nat(l)), Lit(Nat(r))) => Bool(l <= r).into(), (LessThanOrEqual, Lit(Int(l)), Lit(Int(r))) => Bool(l <= r).into(), (LessThanOrEqual, Lit(Float(l)), Lit(Float(r))) => Bool(l <= r).into(), (GreaterThan, Lit(Nat(l)), Lit(Nat(r))) => Bool(l > r).into(), (GreaterThan, Lit(Int(l)), Lit(Int(r))) => Bool(l > r).into(), (GreaterThan, Lit(Float(l)), Lit(Float(r))) => Bool(l > r).into(), (GreaterThanOrEqual, Lit(Nat(l)), Lit(Nat(r))) => Bool(l >= r).into(), (GreaterThanOrEqual, Lit(Int(l)), Lit(Int(r))) => Bool(l >= r).into(), (GreaterThanOrEqual, Lit(Float(l)), Lit(Float(r))) => Bool(l >= r).into(), (binop, lhs, rhs) => return Err(format!("Invalid binop expression {:?} {:?} {:?}", lhs, binop, rhs).into()), }, (prefix, &[ref arg]) => match (prefix, arg) { (BooleanNot, Lit(Bool(true))) => Bool(false), (BooleanNot, Lit(Bool(false))) => Bool(true), (Negate, Lit(Nat(n))) => Int(-(*n as i64)), (Negate, Lit(Int(n))) => Int(-(*n as i64)), (Negate, Lit(Float(f))) => Float(-(*f as f64)), (Increment, Lit(Int(n))) => Int(*n), (Increment, Lit(Nat(n))) => Nat(*n), _ => return Err("No valid prefix op".into()), } .into(), (x, args) => return Err(format!("bad or unimplemented builtin {:?} | {:?}", x, args).into()), }) } fn apply_function(&mut self, body: Vec, args: Vec) -> EvalResult { let mut evaluated_args: Vec = vec![]; for arg in args.into_iter() { evaluated_args.push(self.expression(arg)?); } let mut frame_state = State { environments: self.state.environments.new_scope(None) }; let mut evaluator = Evaluator::new(&mut frame_state, self.type_context); for (n, evaled) in evaluated_args.into_iter().enumerate() { let n = n as u8; let mem = n.into(); evaluator.state.environments.insert(mem, MemoryValue::Primitive(evaled)); } evaluator.block(body) } }