diff --git a/src/schala_lang/type_check.rs b/src/schala_lang/type_check.rs index dce4d53..79394ac 100644 --- a/src/schala_lang/type_check.rs +++ b/src/schala_lang/type_check.rs @@ -20,32 +20,37 @@ use schala_lang::parsing::{AST, Statement, Declaration, Signature, Expression, E // typeInfer :: Expr a -> Matching (Type a) // unify :: Type a -> Type b -> Matching (Type c) +#[derive(Debug, PartialEq, Clone)] +pub enum Type { + TVar(TypeVar), + TConst(TypeConst), +} #[derive(Debug, PartialEq, Clone)] -pub enum TypeVariable { - Univ(UVar), +pub enum TypeVar { + Univ(String), Exist(u64), } #[derive(Debug, PartialEq, Clone)] -pub enum UVar { +pub enum TypeConst { Integer, Float, - String, + StringT, Boolean, Unit, - Function(Box, Box), + FunctionT(Box, Box), Bottom, } -type TypeCheckResult = Result; +type TypeCheckResult = Result; #[derive(Debug, PartialEq, Eq, Hash)] struct PathSpecifier(Rc); #[derive(Debug, PartialEq, Clone)] struct TypeContextEntry { - type_var: TypeVariable, + type_var: Type, constant: bool } @@ -99,46 +104,51 @@ impl TypeContext { pub fn debug_symbol_table(&self) -> String { format!("Symbol table:\n {:?}", self.symbol_table) } - fn get_existential_type(&mut self) -> TypeVariable { - let ret = TypeVariable::Exist(self.existential_type_label_count); + fn get_existential_type(&mut self) -> Type { + let ret = Type::TVar(TypeVar::Exist(self.existential_type_label_count)); self.existential_type_label_count += 1; ret } - fn from_anno(&mut self, anno: &TypeName) -> TypeVariable { - use self::TypeVariable::*; - use self::UVar::*; + fn from_anno(&mut self, anno: &TypeName) -> Type { + use self::Type::*; + use self::TypeConst::*; match anno { &TypeName::Singleton { ref name, .. } => { match name.as_ref().as_ref() { - "Int" => Univ(Integer), - "Bool" => Univ(Boolean), + "Int" => TConst(Integer), + "Bool" => TConst(Boolean), _ => self.get_existential_type() } }, - _ => Univ(Bottom), + _ => TConst(Bottom) } } - fn from_signature(&mut self, sig: &Signature) -> TypeVariable { - use self::TypeVariable::Univ; - use self::UVar::{Unit, Function}; + fn from_signature(&mut self, sig: &Signature) -> Type { + use self::Type::*; + use self::TypeConst::*; + let return_type = sig.type_anno.as_ref().map(|anno| self.from_anno(&anno)).unwrap_or_else(|| { self.get_existential_type() }); if sig.params.len() == 0 { - Univ(Function(Box::new(Univ(Unit)), Box::new(return_type))) + TConst(FunctionT(Box::new(TConst(Unit)), Box::new(return_type))) } else { let mut output_type = return_type; for p in sig.params.iter() { let p_type = p.1.as_ref().map(|anno| self.from_anno(anno)).unwrap_or_else(|| { self.get_existential_type() }); - output_type = Univ(Function(Box::new(p_type), Box::new(output_type))); + output_type = TConst(FunctionT(Box::new(p_type), Box::new(output_type))); } output_type } } pub fn type_check(&mut self, ast: &AST) -> TypeCheckResult { - let mut last = TypeVariable::Univ(UVar::Unit); + use self::Type::*; + use self::TypeConst::*; + + let mut last = TConst(Unit); + for statement in ast.0.iter() { match statement { &Statement::Declaration(ref _decl) => { @@ -154,17 +164,18 @@ impl TypeContext { fn infer(&mut self, expr: &Expression) -> TypeCheckResult { use self::ExpressionType::*; - use self::TypeVariable::*; + use self::Type::*; + use self::TypeConst::*; Ok(match (&expr.0, &expr.1) { (ref _t, &Some(ref anno)) => { //TODO make this better, self.from_anno(anno) }, - (&IntLiteral(_), _) => Univ(UVar::Integer), - (&FloatLiteral(_), _) => Univ(UVar::Float), - (&StringLiteral(_), _) => Univ(UVar::String), - (&BoolLiteral(_), _) => Univ(UVar::Boolean), + (&IntLiteral(_), _) => TConst(Integer), + (&FloatLiteral(_), _) => TConst(Float), + (&StringLiteral(_), _) => TConst(StringT), + (&BoolLiteral(_), _) => TConst(Boolean), (&Variable(ref name), _) => self.lookup(name).map(|entry| entry.type_var) .ok_or(format!("Couldn't find {}", name))?, (&BinExp(ref op, box ref lhs, box ref rhs), _) => { @@ -177,31 +188,33 @@ impl TypeContext { let f_type = self.infer(&*f)?; let arg_type = self.infer(arguments.get(0).unwrap())?; // TODO fix later match f_type { - Univ(UVar::Function(box t1, box ret_type)) => { + TConst(FunctionT(box t1, box ret_type)) => { let _ = self.unify(&t1, &arg_type)?; ret_type }, _ => return Err(format!("Type error")) } }, - _ => Univ(UVar::Unit), + _ => TConst(Unit) }) } fn infer_op(&mut self, _op: &Operation) -> TypeCheckResult { - use self::TypeVariable::*; + use self::Type::*; + use self::TypeConst::*; + Ok( - Univ(UVar::Function( - Box::new(Univ(UVar::Integer)), - Box::new(Univ(UVar::Function( - Box::new(Univ(UVar::Integer)), - Box::new(Univ(UVar::Integer)) + TConst(FunctionT( + Box::new(TConst(Integer)), + Box::new(TConst(FunctionT( + Box::new(TConst(Integer)), + Box::new(TConst(Integer)) ))) )) ) } - fn unify(&mut self, t1: &TypeVariable, t2: &TypeVariable) -> TypeCheckResult { + fn unify(&mut self, t1: &Type, t2: &Type) -> TypeCheckResult { if t1 == t2 { Ok(t1.clone()) } else { @@ -212,8 +225,9 @@ impl TypeContext { #[cfg(test)] mod tests { - use super::{TypeContext, TypeVariable, UVar}; - use super::TypeVariable::*; + use super::{Type, TypeVar, TypeConst, TypeContext}; + use super::Type::*; + use super::TypeConst::*; use schala_lang::parsing::{parse, tokenize}; macro_rules! type_test { @@ -228,7 +242,7 @@ mod tests { #[test] fn basic_inference() { - type_test!("30", Univ(UVar::Integer)) + type_test!("30", TConst(Integer)); } }