use std::rc::Rc; use crate::ast::*; use crate::util::ScopeStack; pub type TypeName = Rc; pub struct TypeContext<'a> { variable_map: ScopeStack<'a, Rc, Type>, evar_count: u32 } /// `InferResult` is the monad in which type inference takes place. type InferResult = Result; #[derive(Debug, Clone)] struct TypeError { msg: String } impl TypeError { fn new(msg: &str) -> InferResult { Err(TypeError { msg: msg.to_string() }) } } /// `Type` is parameterized by whether the type variables can be just universal, or universal or /// existential. #[derive(Debug, Clone)] enum Type { Var(A), Const(TConst), Arrow(Box>, Box>), } #[derive(Debug, Clone)] enum TVar { Univ(UVar), Exist(ExistentialVar) } #[derive(Debug, Clone)] struct UVar(Rc); #[derive(Debug, Clone)] struct ExistentialVar(u32); impl Type { fn to_tvar(&self) -> Type { match self { Type::Var(UVar(name)) => Type::Var(TVar::Univ(UVar(name.clone()))), Type::Const(ref c) => Type::Const(c.clone()), Type::Arrow(a, b) => Type::Arrow( Box::new(a.to_tvar()), Box::new(b.to_tvar()) ) } } } impl Type { fn skolemize(&self) -> Type { match self { Type::Var(TVar::Univ(uvar)) => Type::Var(uvar.clone()), Type::Var(TVar::Exist(_)) => Type::Var(UVar(Rc::new(format!("sk")))), Type::Const(ref c) => Type::Const(c.clone()), Type::Arrow(a, b) => Type::Arrow( Box::new(a.skolemize()), Box::new(b.skolemize()) ) } } } impl TypeIdentifier { fn to_monotype(&self) -> Type { match self { TypeIdentifier::Tuple(_) => Type::Const(TConst::Nat), TypeIdentifier::Singleton(TypeSingletonName { name, .. }) => { match &name[..] { "Nat" => Type::Const(TConst::Nat), "Int" => Type::Const(TConst::Int), "Float" => Type::Const(TConst::Float), "Bool" => Type::Const(TConst::Bool), "String" => Type::Const(TConst::StringT), _ => Type::Const(TConst::Nat), } } } } } #[derive(Debug, Clone)] enum TConst { User(Rc), Unit, Nat, Int, Float, StringT, Bool, } impl TConst { fn user(name: &str) -> TConst { TConst::User(Rc::new(name.to_string())) } } impl<'a> TypeContext<'a> { pub fn new() -> TypeContext<'a> { TypeContext { variable_map: ScopeStack::new(None), evar_count: 0 } } pub fn typecheck(&mut self, ast: &AST) -> Result { match self.infer_ast(ast) { Ok(t) => Ok(format!("{:?}", t)), Err(err) => Err(format!("Type error: {:?}", err)) } } } impl<'a> TypeContext<'a> { fn infer_ast(&mut self, ast: &AST) -> InferResult> { self.infer_block(&ast.0) } fn infer_statement(&mut self, stmt: &Statement) -> InferResult> { match stmt { Statement::ExpressionStatement(ref expr) => self.infer_expr(expr.node()), Statement::Declaration(ref decl) => self.infer_decl(decl), } } fn infer_expr(&mut self, expr: &Expression) -> InferResult> { match expr { Expression(expr_type, Some(type_anno)) => { let tx = self.infer_expr_type(expr_type)?; let ty = type_anno.to_monotype(); self.unify(&ty.to_tvar(), &tx.to_tvar()).map(|x| x.skolemize()) }, Expression(expr_type, None) => self.infer_expr_type(expr_type) } } fn infer_decl(&mut self, _decl: &Declaration) -> InferResult> { Ok(Type::Const(TConst::user("unimplemented"))) } fn infer_expr_type(&mut self, expr_type: &ExpressionType) -> InferResult> { use self::ExpressionType::*; Ok(match expr_type { NatLiteral(_) => Type::Const(TConst::Nat), FloatLiteral(_) => Type::Const(TConst::Float), StringLiteral(_) => Type::Const(TConst::StringT), BoolLiteral(_) => Type::Const(TConst::Bool), Value(name) => { //TODO handle the distinction between 0-arg constructors and variables at some point // need symbol table for that match self.variable_map.lookup(name) { Some(ty) => ty.clone().skolemize(), None => return TypeError::new(&format!("Variable {} not found", name)) } }, IfExpression { discriminator, body } => self.infer_if_expr(discriminator, body)?, Call { f, arguments } => { let tf = self.infer_expr(f)?; //has to be an Arrow Type let targ = self.infer_expr(&arguments[0].node())?; // TODO make this work with functions with more than one arg match tf { Type::Arrow(t1, t2) => { self.unify(&t1.to_tvar(), &targ.to_tvar())?; *t2.clone() }, _ => return TypeError::new("not a function") } }, Lambda { params, .. } => { let _arg_type = match ¶ms[0] { (_, Some(type_anno)) => type_anno.to_monotype().to_tvar(), (_, None) => self.allocate_existential(), }; //let _result_type = unimplemented!(); return TypeError::new("Unimplemented"); //Type::Arrow(Box::new(arg_type), Box::new(result_type)) } _ => Type::Const(TConst::user("unimplemented")) }) } fn infer_if_expr(&mut self, discriminator: &Discriminator, body: &IfExpressionBody) -> InferResult> { let _test = match discriminator { Discriminator::Simple(expr) => expr, _ => return TypeError::new("Dame desu") }; let (_then_clause, _maybe_else_clause) = match body { IfExpressionBody::SimpleConditional(a, b) => (a, b), _ => return TypeError::new("Dont work") }; TypeError::new("Not implemented") } fn infer_block(&mut self, block: &Block) -> InferResult> { let mut output = Type::Const(TConst::Unit); for statement in block.iter() { output = self.infer_statement(statement.node())?; } Ok(output) } fn unify(&mut self, _t1: &Type, _t2: &Type) -> InferResult> { TypeError::new("not implemented") } fn allocate_existential(&mut self) -> Type { let n = self.evar_count; self.evar_count += 1; Type::Var(TVar::Exist(ExistentialVar(n))) } } #[cfg(test)] mod tests { use super::*; fn parse(input: &str) -> AST { let tokens: Vec = crate::tokenizing::tokenize(input); let mut parser = crate::parsing::Parser::new(tokens); parser.parse().unwrap() } macro_rules! type_test { ($input:expr, $correct:expr) => { { let mut tc = TypeContext::new(); let ast = parse($input); tc.add_symbols(&ast); assert_eq!($correct, tc.type_check(&ast).unwrap()) } } } #[test] fn basic_inference() { } }