use std::collections::HashMap; use std::rc::Rc; use parsing::{AST, Statement, Declaration, Signature, Expression, ExpressionType, Operation, Variant, TypeName, TypeSingletonName}; // from Niko's talk /* fn type_check(expression, expected_ty) -> Ty { let ty = bare_type_check(expression, expected_type); if ty icompatible with expected_ty { try_coerce(expression, ty, expected_ty) } else { ty } } fn bare_type_check(exprssion, expected_type) -> Ty { ... } */ /* H-M ALGO NOTES from https://www.youtube.com/watch?v=il3gD7XMdmA (also check out http://dev.stephendiehl.com/fun/006_hindley_milner.html) typeInfer :: Expr a -> Matching (Type a) unify :: Type a -> Type b -> Matching (Type c) (Matching a) is a monad in which unification is done ex: typeInfer (If e1 e2 e3) = do t1 <- typeInfer e1 t2 <- typeInfer e2 t3 <- typeInfer e3 _ <- unify t1 BoolType unify t2 t3 -- b/c t2 and t3 have to be the same type typeInfer (Const (ConstInt _)) = IntType -- same for other literals --function application typeInfer (Apply f x) = do tf <- typeInfer f tx <- typeInfer x case tf of FunctionType t1 t2 -> do _ <- unify t1 tx return t2 _ -> fail "Not a function" --type annotation typeInfer (Typed x t) = do tx <- typeInfer x unify tx t --variable and let expressions - need to pass around a map of variable names to types here typeInfer :: [ (Var, Type Var) ] -> Expr Var -> Matching (Type Var) typeInfer ctx (Var x) = case (lookup x ctx) of Just t -> return t Nothing -> fail "Unknown variable" --let x = e1 in e2 typeInfer ctx (Let x e1 e2) = do t1 <- typeInfer ctx e1 typeInfer ((x, t1) :: ctx) e2 --lambdas are complicated (this represents ʎx.e) typeInfer ctx (Lambda x e) = do t1 <- allocExistentialVariable t2 <- typeInfer ((x, t1) :: ctx) e return $ FunctionType t1 t2 -- ie. t1 -> t2 --to solve the problem of map :: (a -> b) -> [a] -> [b] when we use a variable whose type has universal tvars, convert those universal tvars to existential ones -and each distinct universal tvar needs to map to the same existential type -so we change typeinfer: typeInfer ctx (Var x) = do case (lookup x ctx) of Nothing -> ... Just t -> do let uvars = nub (toList t) -- nub removes duplicates, so this gets unique universally quantified variables evars <- mapM (const allocExistentialVariable) uvars let varMap = zip uvars evars let vixVar varMap v = fromJust $ lookup v varMap return (fmap (fixVar varMap) t) --how do we define unify?? -recall, type signature is: unify :: Type a -> Type b -> Matching (Type c) unify BoolType BoolType = BoolType --easy, same for all constants unify (FunctionType t1 t2) (FunctionType t3 t4) = do t5 <- unify t1 t3 t6 <- unify t2 t4 return $ FunctionType t5 t6 unify (TVar a) (TVar b) = if a == b then TVar a else fail --existential types can be assigned another type at most once --some complicated stuff about hanlding existential types --everything else is a type error unify a b = fail SKOLEMIZATION - how you prevent an unassigned existential type variable from leaking! -before a type gets to global scope, replace all unassigned existential vars w/ new unique universal type variables */ #[derive(Debug, PartialEq, Clone)] pub enum Type { TVar(TypeVar), TConst(TypeConst), TFunc(Box, Box), } #[derive(Debug, PartialEq, Clone)] pub enum TypeVar { Univ(Rc), Exist(u64), } impl TypeVar { fn univ(label: &str) -> TypeVar { TypeVar::Univ(Rc::new(label.to_string())) } } #[derive(Debug, PartialEq, Clone)] pub enum TypeConst { UserT(Rc), Integer, Float, StringT, Boolean, Unit, Bottom, } type TypeCheckResult = Result; #[derive(Debug, PartialEq, Eq, Hash)] struct PathSpecifier(Rc); #[derive(Debug, PartialEq, Clone)] struct TypeContextEntry { ty: Type, constant: bool } pub struct TypeContext { symbol_table: HashMap, evar_table: HashMap, existential_type_label_count: u64 } impl TypeContext { pub fn new() -> TypeContext { TypeContext { symbol_table: HashMap::new(), evar_table: HashMap::new(), existential_type_label_count: 0, } } pub fn add_symbols(&mut self, ast: &AST) { use self::Declaration::*; use self::Type::*; use self::TypeConst::*; for statement in ast.0.iter() { match *statement { Statement::ExpressionStatement(_) => (), Statement::Declaration(ref decl) => match *decl { FuncSig(_) => (), Impl { .. } => (), TypeDecl(ref type_constructor, ref body) => { for variant in body.0.iter() { let (spec, ty) = match variant { &Variant::UnitStruct(ref data_constructor) => { let spec = PathSpecifier(data_constructor.clone()); let ty = TConst(UserT(type_constructor.name.clone())); (spec, ty) }, &Variant::TupleStruct(ref data_construcor, ref args) => { //TODO fix let arg = args.get(0).unwrap(); let type_arg = self.from_anno(arg); let spec = PathSpecifier(data_construcor.clone()); let ty = TFunc(Box::new(type_arg), Box::new(TConst(UserT(type_constructor.name.clone())))); (spec, ty) }, &Variant::Record(_, _) => unimplemented!(), }; let entry = TypeContextEntry { ty, constant: true }; self.symbol_table.insert(spec, entry); } }, TypeAlias { .. } => (), Binding {ref name, ref constant, ref expr} => { let spec = PathSpecifier(name.clone()); let ty = expr.1.as_ref() .map(|ty| self.from_anno(ty)) .unwrap_or_else(|| { self.alloc_existential_type() }); // this call to alloc_existential is OK b/c a binding only ever has one type, so if the annotation is absent, it's fine to just make one de novo let entry = TypeContextEntry { ty, constant: *constant }; self.symbol_table.insert(spec, entry); }, FuncDecl(ref signature, _) => { let spec = PathSpecifier(signature.name.clone()); let ty = self.from_signature(signature); let entry = TypeContextEntry { ty, constant: true }; self.symbol_table.insert(spec, entry); }, } } } } fn lookup(&mut self, binding: &Rc) -> Option { let key = PathSpecifier(binding.clone()); self.symbol_table.get(&key).map(|entry| entry.clone()) } pub fn debug_symbol_table(&self) -> String { format!("Symbol table:\n {:?}\nEvar table:\n{:?}", self.symbol_table, self.evar_table) } fn alloc_existential_type(&mut self) -> Type { let ret = Type::TVar(TypeVar::Exist(self.existential_type_label_count)); self.existential_type_label_count += 1; ret } fn from_anno(&mut self, anno: &TypeName) -> Type { use self::Type::*; use self::TypeConst::*; match anno { &TypeName::Singleton(TypeSingletonName { ref name, .. }) => { match name.as_ref().as_ref() { "Int" => TConst(Integer), "Float" => TConst(Float), "Bool" => TConst(Boolean), "String" => TConst(StringT), s => TVar(TypeVar::Univ(Rc::new(format!("{}",s)))), } }, &TypeName::Tuple(ref items) => { if items.len() == 1 { TConst(Unit) } else { TConst(Bottom) } } } } fn from_signature(&mut self, sig: &Signature) -> Type { use self::Type::*; use self::TypeConst::*; //TODO this won't work properly until you make sure that all (universal) type vars in the function have the same existential type var // actually this should never even put existential types into the symbol table at all //this will crash if more than 5 arg function is used let names = vec!["a", "b", "c", "d", "e", "f"]; let mut idx = 0; let mut get_type = || { let q = TVar(TypeVar::Univ(Rc::new(format!("{}", names.get(idx).unwrap())))); idx += 1; q }; let return_type = sig.type_anno.as_ref().map(|anno| self.from_anno(&anno)).unwrap_or_else(|| { get_type() }); if sig.params.len() == 0 { TFunc(Box::new(TConst(Unit)), Box::new(return_type)) } else { let mut output_type = return_type; for p in sig.params.iter() { let p_type = p.1.as_ref().map(|anno| self.from_anno(anno)).unwrap_or_else(|| { get_type() }); output_type = TFunc(Box::new(p_type), Box::new(output_type)); } output_type } } pub fn type_check(&mut self, ast: &AST) -> TypeCheckResult { use self::Type::*; use self::TypeConst::*; let mut last = TConst(Unit); for statement in ast.0.iter() { match statement { &Statement::Declaration(ref _decl) => { //return Err(format!("Declarations not supported")); }, &Statement::ExpressionStatement(ref expr) => { last = self.infer(expr)?; } } } Ok(last) } fn infer(&mut self, expr: &Expression) -> TypeCheckResult { match (&expr.0, &expr.1) { (exprtype, &Some(ref anno)) => { let tx = self.infer_no_anno(exprtype)?; let ty = self.from_anno(anno); self.unify(tx, ty) }, (exprtype, &None) => self.infer_no_anno(exprtype), } } fn infer_no_anno(&mut self, ex: &ExpressionType) -> TypeCheckResult { use self::ExpressionType::*; use self::Type::*; use self::TypeConst::*; Ok(match ex { &IntLiteral(_) => TConst(Integer), &FloatLiteral(_) => TConst(Float), &StringLiteral(_) => TConst(StringT), &BoolLiteral(_) => TConst(Boolean), &Value(ref name, _) => { self.lookup(name) .map(|entry| entry.ty) .ok_or(format!("Couldn't find {}", name))? }, &BinExp(ref op, ref lhs, ref rhs) => { let t_lhs = self.infer(lhs)?; match self.infer_op(op)? { TFunc(t1, t2) => { let _ = self.unify(t_lhs, *t1)?; let t_rhs = self.infer(rhs)?; let x = *t2; match x { TFunc(t3, t4) => { let _ = self.unify(t_rhs, *t3)?; *t4 }, _ => return Err(format!("Not a function type either")), } }, _ => return Err(format!("Op {:?} is not a function type", op)), } }, &Call { ref f, ref arguments } => { let tf = self.infer(f)?; let targ = self.infer(arguments.get(0).unwrap())?; match tf { TFunc(box t1, box t2) => { let _ = self.unify(t1, targ)?; t2 }, _ => return Err(format!("Not a function!")), } }, _ => TConst(Bottom), }) } fn infer_op(&mut self, op: &Operation) -> TypeCheckResult { use self::Type::*; use self::TypeConst::*; macro_rules! binoptype { ($lhs:expr, $rhs:expr, $out:expr) => { TFunc(Box::new($lhs), Box::new(TFunc(Box::new($rhs), Box::new($out)))) }; } Ok(match (*op.0).as_ref() { "+" => binoptype!(TConst(Integer), TConst(Integer), TConst(Integer)), "++" => binoptype!(TConst(StringT), TConst(StringT), TConst(StringT)), "-" => binoptype!(TConst(Integer), TConst(Integer), TConst(Integer)), "*" => binoptype!(TConst(Integer), TConst(Integer), TConst(Integer)), "/" => binoptype!(TConst(Integer), TConst(Integer), TConst(Integer)), "%" => binoptype!(TConst(Integer), TConst(Integer), TConst(Integer)), _ => TConst(Bottom) }) } fn unify(&mut self, t1: Type, t2: Type) -> TypeCheckResult { use self::Type::*; use self::TypeVar::*; println!("Calling unify with `{:?}` and `{:?}`", t1, t2); match (&t1, &t2) { (&TConst(ref c1), &TConst(ref c2)) if c1 == c2 => Ok(TConst(c1.clone())), (&TFunc(ref t1, ref t2), &TFunc(ref t3, ref t4)) => { let t5 = self.unify(*t1.clone().clone(), *t3.clone().clone())?; let t6 = self.unify(*t2.clone().clone(), *t4.clone().clone())?; Ok(TFunc(Box::new(t5), Box::new(t6))) }, (&TVar(Univ(ref a)), &TVar(Univ(ref b))) => { if a == b { Ok(TVar(Univ(a.clone()))) } else { Err(format!("Couldn't unify universal types {} and {}", a, b)) } }, //the interesting case!! (&TVar(Exist(ref a)), ref t2) => { let x = self.evar_table.get(a).map(|x| x.clone()); match x { Some(ref t1) => self.unify(t1.clone().clone(), t2.clone().clone()), None => { self.evar_table.insert(*a, t2.clone().clone()); Ok(t2.clone().clone()) } } }, (ref t1, &TVar(Exist(ref a))) => { let x = self.evar_table.get(a).map(|x| x.clone()); match x { Some(ref t2) => self.unify(t2.clone().clone(), t1.clone().clone()), None => { self.evar_table.insert(*a, t1.clone().clone()); Ok(t1.clone().clone()) } } }, _ => Err(format!("Types {:?} and {:?} don't unify", t1, t2)) } } } #[cfg(test)] mod tests { use super::{Type, TypeVar, TypeConst, TypeContext}; use super::Type::*; use super::TypeConst::*; use schala_lang::parsing::{parse, tokenize}; macro_rules! type_test { ($input:expr, $correct:expr) => { { let mut tc = TypeContext::new(); let ast = parse(tokenize($input)).0.unwrap() ; tc.add_symbols(&ast); assert_eq!($correct, tc.type_check(&ast).unwrap()) } } } #[test] fn basic_inference() { type_test!("30", TConst(Integer)); type_test!("fn x(a: Int): Bool {}; x(1)", TConst(Boolean)); } }