use std::rc::Rc; //THINGS TODO // look at the haskell compiler, see where in its flow the typechecking happens // -nope, ghc deliberately does typechecking before desugaring to core // cf. a history of haskell, peyton-jones use ena::unify::{UnifyKey, InPlaceUnificationTable, UnificationTable, EqUnifyValue}; use crate::ast::*; use crate::util::ScopeStack; use crate::builtin::{PrefixOp, BinOp}; use std::collections::HashMap; use std::hash::Hash; #[derive(Debug, Clone, PartialEq)] pub struct TypeData { ty: Option } impl TypeData { pub fn new() -> TypeData { TypeData { ty: None } } } pub type TypeName = Rc; pub struct TypeContext<'a> { variable_map: ScopeStack<'a, Rc, Type>, unification_table: InPlaceUnificationTable, //evar_count: u32 } /// `InferResult` is the monad in which type inference takes place. type InferResult = Result; #[derive(Debug, Clone)] pub struct TypeError { pub msg: String } impl TypeError { fn new(msg: T) -> InferResult where T: Into { //TODO make these kinds of error-producing functions CoW-ready Err(TypeError { msg: msg.into() }) } } #[derive(Debug, Clone, PartialEq)] pub enum Type { Const(TypeConst), Var(TypeVar), Arrow(Box, Box), Compound { ty_name: String, args:Vec } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct TypeVar(usize); impl UnifyKey for TypeVar { type Value = Option; fn index(&self) -> u32 { self.0 as u32 } fn from_index(u: u32) -> TypeVar { TypeVar(u as usize) } fn tag() -> &'static str { "TypeVar" } } #[derive(Debug, Clone, PartialEq, Eq)] pub enum TypeConst { Unit, Nat, Int, Float, StringT, Bool, Ordering, UserDefined } impl EqUnifyValue for TypeConst { } macro_rules! ty { ($type_name:ident) => { Type::Const(TypeConst::$type_name) }; ($t1:ident -> $t2:ident) => { Type::Arrow(Box::new(ty!($t1)), Box::new(ty!($t2))) }; ($t1:ident -> $t2:ident -> $t3:ident) => { Type::Arrow(Box::new(ty!($t1)), Box::new(ty!($t2 -> $t3))) }; } //TODO find a better way to capture the to/from string logic impl Type { pub fn to_string(&self) -> String { use self::Type::*; use self::TypeConst::*; match self { Const(Unit) => format!("()"), Const(Nat) => format!("Nat"), Const(Int) => format!("Int"), Const(Float) => format!("Float"), Const(StringT) => format!("String"), Const(Bool) => format!("Bool"), Const(Ordering) => format!("Ordering"), _ => format!("UNKNOWN TYPE"), } } fn from_string(string: &str) -> Option { Some(match string { "()" | "Unit" => ty!(Unit), "Nat" => ty!(Nat), "Int" => ty!(Int), "Float" => ty!(Float), "String" => ty!(StringT), "Bool" => ty!(Bool), "Ordering" => ty!(Ordering), _ => return None }) } } /* /// `Type` is parameterized by whether the type variables can be just universal, or universal or /// existential. #[derive(Debug, Clone)] enum Type { Var(A), Const(TConst), Arrow(Box>, Box>), } #[derive(Debug, Clone)] enum TVar { Univ(UVar), Exist(ExistentialVar) } #[derive(Debug, Clone)] struct UVar(Rc); #[derive(Debug, Clone)] struct ExistentialVar(u32); impl Type { fn to_tvar(&self) -> Type { match self { Type::Var(UVar(name)) => Type::Var(TVar::Univ(UVar(name.clone()))), Type::Const(ref c) => Type::Const(c.clone()), Type::Arrow(a, b) => Type::Arrow( Box::new(a.to_tvar()), Box::new(b.to_tvar()) ) } } } impl Type { fn skolemize(&self) -> Type { match self { Type::Var(TVar::Univ(uvar)) => Type::Var(uvar.clone()), Type::Var(TVar::Exist(_)) => Type::Var(UVar(Rc::new(format!("sk")))), Type::Const(ref c) => Type::Const(c.clone()), Type::Arrow(a, b) => Type::Arrow( Box::new(a.skolemize()), Box::new(b.skolemize()) ) } } } impl TypeIdentifier { fn to_monotype(&self) -> Type { match self { TypeIdentifier::Tuple(_) => Type::Const(TConst::Nat), TypeIdentifier::Singleton(TypeSingletonName { name, .. }) => { match &name[..] { "Nat" => Type::Const(TConst::Nat), "Int" => Type::Const(TConst::Int), "Float" => Type::Const(TConst::Float), "Bool" => Type::Const(TConst::Bool), "String" => Type::Const(TConst::StringT), _ => Type::Const(TConst::Nat), } } } } } #[derive(Debug, Clone)] enum TConst { User(Rc), Unit, Nat, Int, Float, StringT, Bool, } impl TConst { fn user(name: &str) -> TConst { TConst::User(Rc::new(name.to_string())) } } */ impl<'a> TypeContext<'a> { pub fn new() -> TypeContext<'a> { TypeContext { variable_map: ScopeStack::new(None), unification_table: UnificationTable::new(), //evar_count: 0 } } fn get_type_from_name(&self, name: &TypeIdentifier) -> InferResult { use self::TypeIdentifier::*; Ok(match name { Singleton(TypeSingletonName { name, params }) => { match Type::from_string(&name) { Some(ty) => ty, None => return TypeError::new("Unknown type name") } }, Tuple(_) => return TypeError::new("tuples aren't ready yet"), }) } pub fn typecheck(&mut self, ast: &AST) -> Result { let mut returned_type = Type::Const(TypeConst::Unit); for statement in ast.0.iter() { returned_type = self.statement(statement.node())?; } Ok(returned_type) } fn statement(&mut self, statement: &Statement) -> InferResult { match statement { Statement::ExpressionStatement(e) => self.expr(e.node()), Statement::Declaration(decl) => self.decl(decl), } } fn decl(&mut self, decl: &Declaration) -> InferResult { use self::Declaration::*; match decl { Binding { name, expr, .. } => { let ty = self.expr(expr.node())?; self.variable_map.insert(name.clone(), ty); }, _ => (), } Ok(ty!(Unit)) } fn expr(&mut self, expr: &Expression) -> InferResult { match expr { Expression(expr_type, Some(anno)) => { let t1 = self.expr_type(expr_type)?; let t2 = self.get_type_from_name(anno)?; self.unify(t2, t1) }, Expression(expr_type, None) => self.expr_type(expr_type) } } fn expr_type(&mut self, expr: &ExpressionKind) -> InferResult { use self::ExpressionKind::*; Ok(match expr { NatLiteral(_) => ty!(Nat), BoolLiteral(_) => ty!(Bool), FloatLiteral(_) => ty!(Float), StringLiteral(_) => ty!(StringT), PrefixExp(op, expr) => self.prefix(op, expr.node())?, BinExp(op, lhs, rhs) => self.binexp(op, lhs.node(), rhs.node())?, IfExpression { discriminator, body } => self.if_expr(discriminator, body)?, Value(val) => self.handle_value(val)?, Lambda { params, type_anno, body } => self.lambda(params, type_anno, body)?, _ => ty!(Unit), }) } fn prefix(&mut self, op: &PrefixOp, expr: &Expression) -> InferResult { let f = match op.get_type() { Ok(ty) => ty, Err(e) => return TypeError::new(e) }; let x = self.expr(expr)?; self.handle_apply(f, x) } fn binexp(&mut self, op: &BinOp, lhs: &Expression, rhs: &Expression) -> InferResult { let tf = match op.get_type() { Ok(ty) => ty, Err(e) => return TypeError::new(e), }; let t_lhs = self.expr(lhs)?; let t_curried = self.handle_apply(tf, t_lhs)?; let t_rhs = self.expr(rhs)?; self.handle_apply(t_curried, t_rhs) } fn handle_apply(&mut self, tf: Type, tx: Type) -> InferResult { Ok(match tf { Type::Arrow(box ref t1, box ref t2) => { let _ = self.unify(t1.clone(), tx)?; t2.clone() }, _ => return TypeError::new(format!("Not a function")) }) } fn if_expr(&mut self, discriminator: &Discriminator, body: &IfExpressionBody) -> InferResult { use self::Discriminator::*; use self::IfExpressionBody::*; match (discriminator, body) { (Simple(expr), SimpleConditional(then_clause, else_clause)) => self.handle_simple_if(expr, then_clause, else_clause), _ => TypeError::new(format!("Complex conditionals not supported")) } } fn handle_simple_if(&mut self, expr: &Expression, then_clause: &Block, else_clause: &Option) -> InferResult { let t1 = self.expr(expr)?; let t2 = self.block(then_clause)?; let t3 = match else_clause { Some(block) => self.block(block)?, None => ty!(Unit) }; let _ = self.unify(ty!(Bool), t1)?; self.unify(t2, t3) } fn lambda(&mut self, params: &Vec, type_anno: &Option, body: &Block) -> InferResult { Ok(ty!(Unit)) } fn block(&mut self, block: &Block) -> InferResult { let mut output = ty!(Unit); for s in block.iter() { let statement = s.node(); output = self.statement(statement)?; } Ok(output) } fn handle_value(&mut self, val: &Rc) -> InferResult { match self.variable_map.lookup(val) { Some(ty) => Ok(ty.clone()), None => TypeError::new(format!("Couldn't find variable: {}", val)) } } fn unify(&mut self, t1: Type, t2: Type) -> InferResult { use self::Type::*; use std::collections::hash_map::Entry; match (t1, t2) { (Const(ref c1), Const(ref c2)) if c1 == c2 => Ok(Const(c1.clone())), //choice of c1 is arbitrary I *think* (Const(ref c1), Var(ref v2)) => { self.unification_table.unify_var_value(v2.clone(), Some(c1.clone())) .or_else(|_| TypeError::new(format!("Couldn't unify {:?} and {:?}", Const(c1.clone()), Var(*v2))))?; Ok(Const(c1.clone())) }, (a @ Var(_), b @ Const(_)) => self.unify(b, a), (Var(v1), Var(v2)) => { panic!() }, (a, b) => TypeError::new(format!("{:?} and {:?} do not unify", a, b)), } } } #[cfg(test)] mod typechecking_tests { use super::*; use crate::ast::AST; fn parse(input: &str) -> AST { let tokens = crate::tokenizing::tokenize(input); let mut parser = crate::parsing::Parser::new(tokens); parser.parse().unwrap() } macro_rules! assert_type_in_fresh_context { ($string:expr, $type:expr) => { let mut tc = TypeContext::new(); let ref ast = parse($string); let ty = tc.typecheck(ast).unwrap(); assert_eq!(ty, $type) } } #[test] fn basic_test() { assert_type_in_fresh_context!("1", ty!(Nat)); assert_type_in_fresh_context!(r#""drugs""#, ty!(StringT)); assert_type_in_fresh_context!("true", ty!(Bool)); assert_type_in_fresh_context!("-1", ty!(Int)); } #[test] fn operators() { assert_type_in_fresh_context!("1 + 2", ty!(Nat)); assert_type_in_fresh_context!("-2", ty!(Int)); assert_type_in_fresh_context!("!true", ty!(Bool)); } }