use std::rc::Rc; use std::convert::TryFrom; use std::fmt; use ena::unify::{UnifyKey, InPlaceUnificationTable, UnificationTable, EqUnifyValue}; use crate::builtin::Builtin; use crate::ast::*; use crate::util::ScopeStack; use crate::util::deref_optional_box; #[derive(Debug, Clone, PartialEq)] pub struct TypeData { ty: Option } impl TypeData { #[allow(dead_code)] pub fn new() -> TypeData { TypeData { ty: None } } } //TODO need to hook this into the actual typechecking system somehow #[derive(Debug, Clone)] pub struct TypeId { local_name: Rc } impl TypeId { //TODO this is definitely incomplete pub fn lookup_name(name: &str) -> TypeId { TypeId { local_name: Rc::new(name.to_string()) } } pub fn local_name(&self) -> &str { self.local_name.as_ref() } } impl fmt::Display for TypeId { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "TypeId:{}", self.local_name) } } pub struct TypeContext<'a> { variable_map: ScopeStack<'a, Rc, Type>, unification_table: InPlaceUnificationTable, } /// `InferResult` is the monad in which type inference takes place. type InferResult = Result; #[derive(Debug, Clone)] pub struct TypeError { pub msg: String } impl TypeError { fn new(msg: T) -> InferResult where T: Into { Err(TypeError { msg: msg.into() }) } } #[allow(dead_code)] // avoids warning from Compound #[derive(Debug, Clone, PartialEq)] pub enum Type { Const(TypeConst), Var(TypeVar), Arrow { params: Vec, ret: Box }, Compound { ty_name: String, args:Vec } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct TypeVar(usize); impl UnifyKey for TypeVar { type Value = Option; fn index(&self) -> u32 { self.0 as u32 } fn from_index(u: u32) -> TypeVar { TypeVar(u as usize) } fn tag() -> &'static str { "TypeVar" } } #[derive(Debug, Clone, PartialEq, Eq)] pub enum TypeConst { Unit, Nat, Int, Float, StringT, Bool, Ordering, //UserDefined } impl TypeConst { /* #[allow(dead_code)] pub fn to_string(&self) -> String { use self::TypeConst::*; match self { Unit => "()".to_string(), Nat => "Nat".to_string(), Int => "Int".to_string(), Float => "Float".to_string(), StringT => "String".to_string(), Bool => "Bool".to_string(), Ordering => "Ordering".to_string(), } } */ } impl EqUnifyValue for TypeConst { } macro_rules! ty { ($type_name:ident) => { Type::Const(TypeConst::$type_name) }; ($t1:ident -> $t2:ident) => { Type::Arrow { params: vec![ty!($t1)], ret: box ty!($t2) } }; ($t1:ident -> $t2:ident -> $t3:ident) => { Type::Arrow { params: vec![ty!($t1), ty!($t2)], ret: box ty!($t3) } }; ($type_list:ident, $ret_type:ident) => { Type::Arrow { params: $type_list, ret: box $ret_type, } } } //TODO find a better way to capture the to/from string logic impl Type { /* #[allow(dead_code)] pub fn to_string(&self) -> String { use self::Type::*; match self { Const(c) => c.to_string(), Var(v) => format!("t_{}", v.0), Arrow { params, box ref ret } => { if params.is_empty() { format!("-> {}", ret.to_string()) } else { let mut buf = String::new(); for p in params.iter() { write!(buf, "{} -> ", p.to_string()).unwrap(); } write!(buf, "{}", ret.to_string()).unwrap(); buf } }, Compound { .. } => "".to_string() } } */ fn from_string(string: &str) -> Option { Some(match string { "()" | "Unit" => ty!(Unit), "Nat" => ty!(Nat), "Int" => ty!(Int), "Float" => ty!(Float), "String" => ty!(StringT), "Bool" => ty!(Bool), "Ordering" => ty!(Ordering), _ => return None }) } } /* /// `Type` is parameterized by whether the type variables can be just universal, or universal or /// existential. #[derive(Debug, Clone)] enum Type { Var(A), Const(TConst), Arrow(Box>, Box>), } #[derive(Debug, Clone)] enum TVar { Univ(UVar), Exist(ExistentialVar) } #[derive(Debug, Clone)] struct UVar(Rc); #[derive(Debug, Clone)] struct ExistentialVar(u32); impl Type { fn to_tvar(&self) -> Type { match self { Type::Var(UVar(name)) => Type::Var(TVar::Univ(UVar(name.clone()))), Type::Const(ref c) => Type::Const(c.clone()), Type::Arrow(a, b) => Type::Arrow( Box::new(a.to_tvar()), Box::new(b.to_tvar()) ) } } } impl Type { fn skolemize(&self) -> Type { match self { Type::Var(TVar::Univ(uvar)) => Type::Var(uvar.clone()), Type::Var(TVar::Exist(_)) => Type::Var(UVar(Rc::new(format!("sk")))), Type::Const(ref c) => Type::Const(c.clone()), Type::Arrow(a, b) => Type::Arrow( Box::new(a.skolemize()), Box::new(b.skolemize()) ) } } } impl TypeIdentifier { fn to_monotype(&self) -> Type { match self { TypeIdentifier::Tuple(_) => Type::Const(TConst::Nat), TypeIdentifier::Singleton(TypeSingletonName { name, .. }) => { match &name[..] { "Nat" => Type::Const(TConst::Nat), "Int" => Type::Const(TConst::Int), "Float" => Type::Const(TConst::Float), "Bool" => Type::Const(TConst::Bool), "String" => Type::Const(TConst::StringT), _ => Type::Const(TConst::Nat), } } } } } #[derive(Debug, Clone)] enum TConst { User(Rc), Unit, Nat, Int, Float, StringT, Bool, } impl TConst { fn user(name: &str) -> TConst { TConst::User(Rc::new(name.to_string())) } } */ impl<'a> TypeContext<'a> { pub fn new() -> TypeContext<'a> { TypeContext { variable_map: ScopeStack::new(None), unification_table: UnificationTable::new(), } } /* fn new_env(&'a self, new_var: Rc, ty: Type) -> TypeContext<'a> { let mut new_context = TypeContext { variable_map: self.variable_map.new_scope(None), unification_table: UnificationTable::new(), //???? not sure if i want this }; new_context.variable_map.insert(new_var, ty); new_context } */ fn get_type_from_name(&self, name: &TypeIdentifier) -> InferResult { use self::TypeIdentifier::*; Ok(match name { Singleton(TypeSingletonName { name,.. }) => { match Type::from_string(name) { Some(ty) => ty, None => return TypeError::new(format!("Unknown type name: {}", name)) } }, Tuple(_) => return TypeError::new("tuples aren't ready yet"), }) } /// `typecheck` is the entry into the type-inference system, accepting an AST as an argument /// Following the example of GHC, the compiler deliberately does typechecking before de-sugaring /// the AST to ReducedAST pub fn typecheck(&mut self, ast: &AST) -> Result { let mut returned_type = Type::Const(TypeConst::Unit); for statement in ast.statements.statements.iter() { returned_type = self.statement(statement)?; } Ok(returned_type) } fn statement(&mut self, statement: &Statement) -> InferResult { match &statement.kind { StatementKind::Expression(e) => self.expr(e), StatementKind::Declaration(decl) => self.decl(decl), StatementKind::Import(_) => Ok(ty!(Unit)), StatementKind::Module(_) => Ok(ty!(Unit)), } } fn decl(&mut self, decl: &Declaration) -> InferResult { use self::Declaration::*; if let Binding { name, expr, .. } = decl { let ty = self.expr(expr)?; self.variable_map.insert(name.clone(), ty); } Ok(ty!(Unit)) } fn invoc(&mut self, invoc: &InvocationArgument) -> InferResult { use InvocationArgument::*; match invoc { Positional(expr) => self.expr(expr), _ => Ok(ty!(Nat)) //TODO this is wrong } } fn expr(&mut self, expr: &Expression) -> InferResult { match expr { Expression { kind, type_anno: Some(anno), .. } => { let t1 = self.expr_type(kind)?; let t2 = self.get_type_from_name(anno)?; self.unify(t2, t1) }, Expression { kind, type_anno: None, .. } => self.expr_type(kind) } } fn expr_type(&mut self, expr: &ExpressionKind) -> InferResult { use self::ExpressionKind::*; Ok(match expr { NatLiteral(_) => ty!(Nat), BoolLiteral(_) => ty!(Bool), FloatLiteral(_) => ty!(Float), StringLiteral(_) => ty!(StringT), PrefixExp(op, expr) => self.prefix(op, expr)?, BinExp(op, lhs, rhs) => self.binexp(op, lhs, rhs)?, IfExpression { discriminator, body } => self.if_expr(deref_optional_box(discriminator), &**body)?, Value(val) => self.handle_value(val)?, Call { box ref f, arguments } => self.call(f, arguments)?, Lambda { params, type_anno, body } => self.lambda(params, type_anno, body)?, _ => ty!(Unit), }) } fn prefix(&mut self, op: &PrefixOp, expr: &Expression) -> InferResult { let builtin: Option = TryFrom::try_from(op).ok(); let tf = match builtin.map(|b| b.get_type()) { Some(ty) => ty, None => return TypeError::new("no type found") }; let tx = self.expr(expr)?; self.handle_apply(tf, vec![tx]) } fn binexp(&mut self, op: &BinOp, lhs: &Expression, rhs: &Expression) -> InferResult { let builtin: Option = TryFrom::try_from(op).ok(); let tf = match builtin.map(|b| b.get_type()) { Some(ty) => ty, None => return TypeError::new("no type found"), }; let t_lhs = self.expr(lhs)?; let t_rhs = self.expr(rhs)?; //TODO is this order a problem? not sure self.handle_apply(tf, vec![t_lhs, t_rhs]) } fn if_expr(&mut self, discriminator: Option<&Expression>, body: &IfExpressionBody) -> InferResult { use self::IfExpressionBody::*; match (discriminator, body) { (Some(expr), SimpleConditional{ then_case, else_case }) => self.handle_simple_if(expr, then_case, else_case), _ => TypeError::new("Complex conditionals not supported".to_string()) } } #[allow(clippy::ptr_arg)] fn handle_simple_if(&mut self, expr: &Expression, then_clause: &Block, else_clause: &Option) -> InferResult { let t1 = self.expr(expr)?; let t2 = self.block(then_clause)?; let t3 = match else_clause { Some(block) => self.block(block)?, None => ty!(Unit) }; let _ = self.unify(ty!(Bool), t1)?; self.unify(t2, t3) } #[allow(clippy::ptr_arg)] fn lambda(&mut self, params: &Vec, type_anno: &Option, _body: &Block) -> InferResult { let argument_types: InferResult> = params.iter().map(|param: &FormalParam| { if let FormalParam { anno: Some(type_identifier), .. } = param { self.get_type_from_name(type_identifier) } else { Ok(Type::Var(self.fresh_type_variable())) } }).collect(); let argument_types = argument_types?; let ret_type = match type_anno.as_ref() { Some(anno) => self.get_type_from_name(anno)?, None => Type::Var(self.fresh_type_variable()) }; Ok(ty!(argument_types, ret_type)) } fn call(&mut self, f: &Expression, args: &[ InvocationArgument ]) -> InferResult { let tf = self.expr(f)?; let arg_types: InferResult> = args.iter().map(|ex| self.invoc(ex)).collect(); let arg_types = arg_types?; self.handle_apply(tf, arg_types) } fn handle_apply(&mut self, tf: Type, args: Vec) -> InferResult { Ok(match tf { Type::Arrow { ref params, ret: box ref t_ret } if params.len() == args.len() => { for (t_param, t_arg) in params.iter().zip(args.iter()) { let _ = self.unify(t_param.clone(), t_arg.clone())?; //TODO I think this needs to reference a sub-scope } t_ret.clone() }, Type::Arrow { .. } => return TypeError::new("Wrong length"), _ => return TypeError::new("Not a function".to_string()) }) } #[allow(clippy::ptr_arg)] fn block(&mut self, block: &Block) -> InferResult { let mut output = ty!(Unit); for statement in block.statements.iter() { output = self.statement(statement)?; } Ok(output) } fn handle_value(&mut self, val: &QualifiedName) -> InferResult { let QualifiedName { components: vec, .. } = val; let var = &vec[0]; match self.variable_map.lookup(var) { Some(ty) => Ok(ty.clone()), None => TypeError::new(format!("Couldn't find variable: {}", &var)), } } fn unify(&mut self, t1: Type, t2: Type) -> InferResult { use self::Type::*; match (t1, t2) { (Const(ref c1), Const(ref c2)) if c1 == c2 => Ok(Const(c1.clone())), //choice of c1 is arbitrary I *think* (a @ Var(_), b @ Const(_)) => self.unify(b, a), (Const(ref c1), Var(ref v2)) => { self.unification_table.unify_var_value(*v2, Some(c1.clone())) .or_else(|_| TypeError::new(format!("Couldn't unify {:?} and {:?}", Const(c1.clone()), Var(*v2))))?; Ok(Const(c1.clone())) }, (Var(v1), Var(v2)) => { //TODO add occurs check self.unification_table.unify_var_var(v1, v2) .or_else(|e| { println!("Unify error: {:?}", e); TypeError::new(format!("Two type variables {:?} and {:?} couldn't unify", v1, v2)) })?; Ok(Var(v1)) //arbitrary decision I think }, (a, b) => TypeError::new(format!("{:?} and {:?} do not unify", a, b)), } } fn fresh_type_variable(&mut self) -> TypeVar { self.unification_table.new_key(None) } } #[cfg(test)] mod typechecking_tests { use super::*; macro_rules! assert_type_in_fresh_context { ($string:expr, $type:expr) => { let mut tc = TypeContext::new(); let ast = &crate::util::quick_ast($string); let ty = tc.typecheck(ast).unwrap(); assert_eq!(ty, $type) } } #[test] fn basic_test() { assert_type_in_fresh_context!("1", ty!(Nat)); assert_type_in_fresh_context!(r#""drugs""#, ty!(StringT)); assert_type_in_fresh_context!("true", ty!(Bool)); } #[test] fn operators() { //TODO fix these with new operator regime /* assert_type_in_fresh_context!("-1", ty!(Int)); assert_type_in_fresh_context!("1 + 2", ty!(Nat)); assert_type_in_fresh_context!("-2", ty!(Int)); assert_type_in_fresh_context!("!true", ty!(Bool)); */ } }