diff --git a/schala-lang/language/src/typechecking.rs b/schala-lang/language/src/typechecking.rs index 4fa41a8..95ba5ec 100644 --- a/schala-lang/language/src/typechecking.rs +++ b/schala-lang/language/src/typechecking.rs @@ -6,7 +6,7 @@ use util::ScopeStack; pub type TypeName = Rc; pub struct TypeContext<'a> { - variable_map: ScopeStack<'a, Rc, Type<()>>, + variable_map: ScopeStack<'a, Rc, Type>, evar_count: u32 } @@ -28,16 +28,47 @@ enum Type { Arrow(Box>, Box>), } +#[derive(Debug, Clone)] enum TVar { - Univ(UniversalVar), + Univ(UVar), Exist(ExistentialVar) } -struct UniversalVar(Rc); +#[derive(Debug, Clone)] +struct UVar(Rc); + +#[derive(Debug, Clone)] struct ExistentialVar(u32); +impl Type { + fn to_tvar(&self) -> Type { + match self { + Type::Var(UVar(name)) => Type::Var(TVar::Univ(UVar(name.clone()))), + Type::Const(ref c) => Type::Const(c.clone()), + Type::Arrow(a, b) => Type::Arrow( + Box::new(a.to_tvar()), + Box::new(b.to_tvar()) + ) + } + } +} + +impl Type { + fn skolemize(&self) -> Type { + match self { + Type::Var(TVar::Univ(uvar)) => Type::Var(uvar.clone()), + Type::Var(TVar::Exist(evar)) => unimplemented!(), + Type::Const(ref c) => Type::Const(c.clone()), + Type::Arrow(a, b) => Type::Arrow( + Box::new(a.skolemize()), + Box::new(b.skolemize()) + ) + } + } +} + impl TypeIdentifier { - fn to_monotype(&self) -> Type<()> { + fn to_monotype(&self) -> Type { match self { TypeIdentifier::Tuple(items) => unimplemented!(), TypeIdentifier::Singleton(TypeSingletonName { name, .. }) => { @@ -94,7 +125,7 @@ impl<'a> TypeContext<'a> { } impl<'a> TypeContext<'a> { - fn infer_ast(&mut self, ast: &AST) -> InferResult> { + fn infer_ast(&mut self, ast: &AST) -> InferResult> { let mut output = Type::Const(TConst::Unit); for statement in ast.0.iter() { output = self.infer_statement(statement)?; @@ -102,29 +133,29 @@ impl<'a> TypeContext<'a> { Ok(output) } - fn infer_statement(&mut self, stmt: &Statement) -> InferResult> { + fn infer_statement(&mut self, stmt: &Statement) -> InferResult> { match stmt { Statement::ExpressionStatement(ref expr) => self.infer_expr(expr), Statement::Declaration(ref decl) => self.infer_decl(decl), } } - fn infer_expr(&mut self, expr: &Expression) -> InferResult> { + fn infer_expr(&mut self, expr: &Expression) -> InferResult> { match expr { Expression(expr_type, Some(type_anno)) => { let tx = self.infer_expr_type(expr_type)?; let ty = type_anno.to_monotype(); - self.unify(&ty, &tx) + self.unify(&ty.to_tvar(), &tx.to_tvar()).map(|x| x.skolemize()) }, Expression(expr_type, None) => self.infer_expr_type(expr_type) } } - fn infer_decl(&mut self, expr: &Declaration) -> InferResult> { + fn infer_decl(&mut self, expr: &Declaration) -> InferResult> { Ok(Type::Const(TConst::user("unimplemented"))) } - fn infer_expr_type(&mut self, expr_type: &ExpressionType) -> InferResult> { + fn infer_expr_type(&mut self, expr_type: &ExpressionType) -> InferResult> { use self::ExpressionType::*; Ok(match expr_type { NatLiteral(_) => Type::Const(TConst::Nat), @@ -135,17 +166,17 @@ impl<'a> TypeContext<'a> { //TODO handle the distinction between 0-arg constructors and variables at some point // need symbol table for that match self.variable_map.lookup(name) { - Some(ty) => ty.clone(), + Some(ty) => ty.clone().skolemize(), None => return TypeError::new(&format!("Variable {} not found", name)) } }, IfExpression { discriminator, body } => self.infer_if_expr(discriminator, body)?, Call { f, arguments } => { - let tf: Type<()> = self.infer_expr(f)?; //has to be an Arrow Type + let tf = self.infer_expr(f)?; //has to be an Arrow Type let targ = self.infer_expr(&arguments[0])?; // TODO make this work with functions with more than one arg match tf { Type::Arrow(t1, t2) => { - self.unify(&t1, &targ)?; + self.unify(&t1.to_tvar(), &targ.to_tvar())?; *t2.clone() }, _ => return TypeError::new("not a function") @@ -163,7 +194,7 @@ impl<'a> TypeContext<'a> { }) } - fn infer_if_expr(&mut self, discriminator: &Discriminator, body: &IfExpressionBody) -> InferResult> { + fn infer_if_expr(&mut self, discriminator: &Discriminator, body: &IfExpressionBody) -> InferResult> { let test = match discriminator { Discriminator::Simple(expr) => expr, _ => return TypeError::new("Dame desu") @@ -177,7 +208,7 @@ impl<'a> TypeContext<'a> { unimplemented!() } - fn infer_block(&mut self, block: &Block) -> InferResult> { + fn infer_block(&mut self, block: &Block) -> InferResult> { let mut output = Type::Const(TConst::Unit); for statement in block.iter() { output = self.infer_statement(statement)?; @@ -185,7 +216,7 @@ impl<'a> TypeContext<'a> { Ok(output) } - fn unify(&mut self, t1: &Type<()>, t2: &Type<()>) -> InferResult> { + fn unify(&mut self, t1: &Type, t2: &Type) -> InferResult> { unimplemented!() }