diff --git a/schala-lang/src/typechecking.rs b/schala-lang/src/typechecking.rs index ec2f459..0fdfc48 100644 --- a/schala-lang/src/typechecking.rs +++ b/schala-lang/src/typechecking.rs @@ -117,6 +117,29 @@ impl PolyType { #[derive(Debug, PartialEq, Clone)] struct Substitution(HashMap, MonoType>); +impl Substitution { + fn new() -> Substitution { + Substitution(HashMap::new()) + } + + fn bind_variable(name: &Rc, var: &MonoType) -> Substitution { + Substitution(hashmap! { + name.clone() => var.clone() + }) + } + + fn merge(self, other: Substitution) -> Substitution { + let mut map = HashMap::new(); + for (name, ty) in self.0.into_iter() { + map.insert(name, ty); + } + for (name, ty) in other.0.into_iter() { + map.insert(name, ty); + } + Substitution(map) + } +} + #[derive(Debug)] struct TypeEnvironment { @@ -267,6 +290,21 @@ impl Infer { fn infer_block(&mut self, block: &Vec) -> Result { Ok(MonoType::Const(TypeConst::Unit)) } + + fn unify(&mut self, a: MonoType, b: MonoType) -> Result { + use self::InferError::*; use self::MonoType::*; + Ok(match (a, b) { + (Const(ref a), Const(ref b)) if a == b => Substitution::new(), + (Var(ref name), ref var) => Substitution::bind_variable(name, var), + (ref var, Var(ref name)) => Substitution::bind_variable(name, var), + (Function(box a1, box b1), Function(box a2, box b2)) => { + let s1 = self.unify(a1, a2)?; + let s2 = self.unify(b1.apply_substitution(&s1), b2.apply_substitution(&s1))?; + s1.merge(s2) + }, + (a, b) => return Err(CannotUnify(a, b)) + }) + } }