schala/src/schala_lang/type_check.rs

460 lines
14 KiB
Rust
Raw Normal View History

2017-10-08 12:22:04 -07:00
use std::collections::HashMap;
2017-10-08 13:51:56 -07:00
use std::rc::Rc;
2017-10-08 12:22:04 -07:00
2017-10-11 02:33:46 -07:00
//SKOLEMIZATION - how you prevent an unassigned existential type variable from leaking!
2017-10-10 17:29:28 -07:00
use schala_lang::parsing::{AST, Statement, Declaration, Signature, Expression, ExpressionType, Operation, Variant, TypeName};
2017-10-04 02:07:30 -07:00
2017-10-10 01:11:24 -07:00
// from Niko's talk
/* fn type_check(expression, expected_ty) -> Ty {
    let ty = bare_type_check(expression, expected_ty);
    if ty is incompatible with expected_ty {
        try_coerce(expression, ty, expected_ty)
    } else {
        ty
    }
}
fn bare_type_check(expression, expected_ty) -> Ty { ... }
*/
// from https://www.youtube.com/watch?v=il3gD7XMdmA
// typeInfer :: Expr a -> Matching (Type a)
// unify :: Type a -> Type b -> Matching (Type c)
2017-10-10 02:17:07 -07:00
#[derive(Debug, PartialEq, Clone)]
pub enum Type {
TVar(TypeVar),
TConst(TypeConst),
2017-10-11 01:55:45 -07:00
TFunc(Box<Type>, Box<Type>),
2017-10-10 02:17:07 -07:00
}
2017-10-10 01:11:24 -07:00
/// A type variable.
#[derive(Clone, Debug, PartialEq)]
pub enum TypeVar {
    /// A universally quantified variable, identified by its name.
    Univ(Rc<String>),
    /// An existential variable, identified by a numeric label.
    Exist(u64),
}
impl TypeVar {
    /// Builds a universal type variable named `label`.
    fn univ(label: &str) -> TypeVar {
        let name = Rc::new(String::from(label));
        TypeVar::Univ(name)
    }
}
2017-10-10 01:11:24 -07:00
/// The concrete (non-variable, non-function) types known to the checker.
#[derive(Clone, Debug, PartialEq)]
pub enum TypeConst {
    /// A user-declared type, identified by the name of its type constructor.
    UserT(Rc<String>),
    /// Type of integer literals.
    Integer,
    /// Type of floating-point literals.
    Float,
    /// Type of string literals.
    StringT,
    /// Type of boolean literals.
    Boolean,
    /// The unit type.
    Unit,
    /// Placeholder produced for constructs the checker does not handle yet.
    Bottom,
}
2017-10-10 02:17:07 -07:00
/// Outcome of a type-checking step: the resulting `Type`, or a human-readable error message.
type TypeCheckResult = Result<Type, String>;
2017-10-10 01:11:24 -07:00
2017-10-08 12:22:04 -07:00
/// Key of the symbol table: the name a symbol is looked up by.
#[derive(Debug, Eq, Hash, PartialEq)]
struct PathSpecifier(Rc<String>);
#[derive(Debug, PartialEq, Clone)]
struct TypeContextEntry {
2017-10-10 22:14:55 -07:00
ty: Type,
constant: bool
2017-10-08 13:51:56 -07:00
}
2017-10-08 12:22:04 -07:00
2017-10-09 00:59:52 -07:00
/// Mutable state of the type checker.
pub struct TypeContext {
    /// Types of all known symbols (bindings, functions, data constructors).
    symbol_table: HashMap<PathSpecifier, TypeContextEntry>,
    /// Bindings established so far for existential type variables, keyed by label.
    evar_table: HashMap<u64, Type>,
    /// Label that the next freshly allocated existential variable will receive.
    existential_type_label_count: u64,
}
2017-10-09 00:59:52 -07:00
impl TypeContext {
pub fn new() -> TypeContext {
2017-10-09 02:26:59 -07:00
TypeContext {
symbol_table: HashMap::new(),
evar_table: HashMap::new(),
2017-10-09 02:26:59 -07:00
existential_type_label_count: 0,
}
}
2017-10-09 00:59:52 -07:00
pub fn add_symbols(&mut self, ast: &AST) {
2017-10-08 13:51:56 -07:00
use self::Declaration::*;
2017-10-10 17:29:28 -07:00
use self::Type::*;
use self::TypeConst::*;
2017-10-08 13:51:56 -07:00
for statement in ast.0.iter() {
2017-10-09 00:59:52 -07:00
match *statement {
Statement::ExpressionStatement(_) => (),
Statement::Declaration(ref decl) => {
match *decl {
FuncSig(_) => (),
Impl { .. } => (),
2017-10-10 17:29:28 -07:00
TypeDecl(ref type_constructor, ref body) => {
for variant in body.0.iter() {
2017-10-10 22:14:55 -07:00
let (spec, ty) = match variant {
2017-10-10 17:29:28 -07:00
&Variant::UnitStruct(ref data_constructor) => {
let spec = PathSpecifier(data_constructor.clone());
2017-10-10 22:14:55 -07:00
let ty = TConst(UserT(type_constructor.clone()));
(spec, ty)
2017-10-10 17:29:28 -07:00
},
&Variant::TupleStruct(ref data_construcor, ref args) => {
//TODO fix
let arg = args.get(0).unwrap();
let type_arg = self.from_anno(arg);
let spec = PathSpecifier(data_construcor.clone());
2017-10-11 01:55:45 -07:00
let ty = TFunc(Box::new(type_arg), Box::new(TConst(UserT(type_constructor.clone()))));
2017-10-10 22:14:55 -07:00
(spec, ty)
2017-10-10 17:29:28 -07:00
},
&Variant::Record(_, _) => unimplemented!(),
};
2017-10-10 22:14:55 -07:00
let entry = TypeContextEntry { ty, constant: true };
2017-10-10 17:29:28 -07:00
self.symbol_table.insert(spec, entry);
}
},
2017-10-09 00:59:52 -07:00
TypeAlias { .. } => (),
Binding {ref name, ref constant, ref expr} => {
2017-10-09 04:02:50 -07:00
let spec = PathSpecifier(name.clone());
2017-10-10 22:14:55 -07:00
let ty = expr.1.as_ref()
2017-10-09 02:26:59 -07:00
.map(|ty| self.from_anno(ty))
.unwrap_or_else(|| { self.alloc_existential_type() }); // this call to alloc_existential is OK b/c a binding only ever has one type, so if the annotation is absent, it's fine to just make one de novo
2017-10-10 22:14:55 -07:00
let entry = TypeContextEntry { ty, constant: *constant };
self.symbol_table.insert(spec, entry);
2017-10-09 00:59:52 -07:00
},
2017-10-09 02:26:59 -07:00
FuncDecl(ref signature, _) => {
2017-10-09 04:02:50 -07:00
let spec = PathSpecifier(signature.name.clone());
2017-10-10 22:14:55 -07:00
let ty = self.from_signature(signature);
let entry = TypeContextEntry { ty, constant: true };
self.symbol_table.insert(spec, entry);
2017-10-08 13:51:56 -07:00
},
}
}
}
}
}
fn lookup(&mut self, binding: &Rc<String>) -> Option<TypeContextEntry> {
2017-10-09 04:02:50 -07:00
let key = PathSpecifier(binding.clone());
self.symbol_table.get(&key).map(|entry| entry.clone())
}
2017-10-08 13:57:43 -07:00
pub fn debug_symbol_table(&self) -> String {
2017-10-11 02:33:46 -07:00
format!("Symbol table:\n {:?}\nEvar table:\n{:?}", self.symbol_table, self.evar_table)
2017-10-08 13:57:43 -07:00
}
fn alloc_existential_type(&mut self) -> Type {
2017-10-10 02:17:07 -07:00
let ret = Type::TVar(TypeVar::Exist(self.existential_type_label_count));
2017-10-09 02:26:59 -07:00
self.existential_type_label_count += 1;
ret
}
2017-10-10 02:17:07 -07:00
fn from_anno(&mut self, anno: &TypeName) -> Type {
use self::Type::*;
use self::TypeConst::*;
2017-10-09 02:26:59 -07:00
match anno {
&TypeName::Singleton { ref name, .. } => {
match name.as_ref().as_ref() {
2017-10-10 02:17:07 -07:00
"Int" => TConst(Integer),
"Bool" => TConst(Boolean),
2017-10-10 21:51:45 -07:00
"String" => TConst(StringT),
s => TVar(TypeVar::Univ(Rc::new(format!("{}",s)))),
2017-10-09 02:26:59 -07:00
}
},
2017-10-10 21:51:45 -07:00
&TypeName::Tuple(ref items) => {
if items.len() == 1 {
TConst(Unit)
} else {
TConst(Bottom)
}
}
2017-10-09 02:26:59 -07:00
}
}
2017-10-10 02:17:07 -07:00
/// Builds the (curried) function type described by a signature.
///
/// A function without parameters has type `Unit -> R`. A function with
/// parameters `p1, p2, ..., pn` has type `P1 -> (P2 -> ... (Pn -> R))`, so the
/// first declared parameter is the outermost argument. A missing return or
/// parameter annotation is replaced by a fresh universal type variable.
fn from_signature(&mut self, sig: &Signature) -> Type {
    use self::Type::*;
    use self::TypeConst::*;

    //TODO this won't work properly until you make sure that all (universal) type vars in the function have the same existential type var
    // actually this should never even put existential types into the symbol table at all

    // The first six fresh variables keep their short names; beyond that a
    // numbered name is generated instead of running out of names.
    const NAMES: [&'static str; 6] = ["a", "b", "c", "d", "e", "f"];
    let mut next_idx = 0usize;
    let mut fresh_univ = || {
        let label = if next_idx < NAMES.len() {
            NAMES[next_idx].to_string()
        } else {
            format!("t{}", next_idx)
        };
        next_idx += 1;
        TVar(TypeVar::Univ(Rc::new(label)))
    };

    let return_type = match sig.type_anno {
        Some(ref anno) => self.from_anno(anno),
        None => fresh_univ(),
    };

    if sig.params.is_empty() {
        return TFunc(Box::new(TConst(Unit)), Box::new(return_type));
    }

    // Resolve the parameter types in declaration order first, so fresh
    // variables are named in the order the parameters appear.
    let mut param_types = Vec::with_capacity(sig.params.len());
    for p in sig.params.iter() {
        let p_type = match p.1 {
            Some(ref anno) => self.from_anno(anno),
            None => fresh_univ(),
        };
        param_types.push(p_type);
    }

    // Fold from the last parameter inwards so the first parameter ends up
    // outermost: a -> (b -> ret).
    param_types
        .into_iter()
        .rev()
        .fold(return_type, |acc, p_type| TFunc(Box::new(p_type), Box::new(acc)))
}
2017-10-04 02:07:30 -07:00
pub fn type_check(&mut self, ast: &AST) -> TypeCheckResult {
2017-10-10 02:17:07 -07:00
use self::Type::*;
use self::TypeConst::*;
let mut last = TConst(Unit);
2017-10-04 02:07:30 -07:00
for statement in ast.0.iter() {
match statement {
&Statement::Declaration(ref _decl) => {
//return Err(format!("Declarations not supported"));
2017-10-04 02:07:30 -07:00
},
&Statement::ExpressionStatement(ref expr) => {
last = self.infer(expr)?;
2017-10-04 02:07:30 -07:00
}
}
}
Ok(last)
2017-10-04 02:07:30 -07:00
}
2017-10-08 12:22:04 -07:00
/*
fn infer(&mut self, expr: &Expression) -> TypeCheckResult {
2017-10-08 12:22:04 -07:00
use self::ExpressionType::*;
2017-10-10 02:17:07 -07:00
use self::Type::*;
use self::TypeConst::*;
2017-10-08 12:22:04 -07:00
Ok(match (&expr.0, &expr.1) {
2017-10-10 21:23:24 -07:00
(&IntLiteral(_), anno) => {
match *anno {
None => TConst(Integer),
Some(ref t) => self.from_anno(t)
}
}
(&FloatLiteral(_), anno) => {
match *anno {
None => TConst(Float),
Some(ref t) => self.from_anno(t),
}
},
(&StringLiteral(_), anno) => {
match *anno {
None => TConst(StringT),
Some(ref t) => self.from_anno(t),
}
},
(&BoolLiteral(_), anno) => {
match *anno {
None => TConst(Boolean),
Some(ref t) => self.from_anno(t),
}
},
2017-10-10 21:23:24 -07:00
(&Value(ref name), ref _anno) => {
2017-10-10 02:32:02 -07:00
self.lookup(name)
2017-10-10 22:14:55 -07:00
.map(|entry| entry.ty)
2017-10-10 02:32:02 -07:00
.ok_or(format!("Couldn't find {}", name))?
},
2017-10-10 21:23:24 -07:00
(&BinExp(ref op, box ref lhs, box ref rhs), ref _anno) => {
2017-10-10 02:41:17 -07:00
let op_type = self.infer_op(op)?;
let lhs_type = self.infer(&lhs)?;
match op_type {
TConst(FunctionT(box t1, box t2)) => {
let _ = self.unify(t1, lhs_type)?;
2017-10-10 02:41:17 -07:00
let rhs_type = self.infer(&rhs)?;
match t2 {
TConst(FunctionT(box t3, box t_ret)) => {
let _ = self.unify(t3, rhs_type)?;
2017-10-10 02:41:17 -07:00
t_ret
},
_ => return Err(format!("Another bad type for operator"))
}
},
_ => return Err(format!("Bad type for operator")),
}
2017-10-10 01:04:19 -07:00
},
2017-10-10 21:23:24 -07:00
(&Call { ref f, ref arguments }, ref _anno) => {
let f_type = self.infer(&*f)?;
let arg_type = self.infer(arguments.get(0).unwrap())?; // TODO fix later
match f_type {
2017-10-10 02:17:07 -07:00
TConst(FunctionT(box t1, box ret_type)) => {
let _ = self.unify(t1, arg_type)?;
ret_type
},
_ => return Err(format!("Type error"))
}
},
2017-10-10 02:17:07 -07:00
_ => TConst(Unit)
})
2017-10-08 12:22:04 -07:00
}
fn infer_op(&mut self, op: &Operation) -> TypeCheckResult {
2017-10-10 02:17:07 -07:00
use self::Type::*;
use self::TypeConst::*;
let opstr: &str = &op.0;
if opstr == "+" {
return Ok(
TConst(FunctionT(
Box::new(TVar(TypeVar::univ("a"))),
Box::new(TConst(FunctionT(
Box::new(TVar(TypeVar::univ("a"))),
Box::new(TVar(TypeVar::univ("a")))
)))
))
)
}
2017-10-10 01:04:19 -07:00
Ok(
2017-10-10 02:17:07 -07:00
TConst(FunctionT(
Box::new(TConst(Integer)),
Box::new(TConst(FunctionT(
Box::new(TConst(Integer)),
Box::new(TConst(Integer))
2017-10-10 01:04:19 -07:00
)))
))
)
}
*/
/// Infers the type of an expression. When the expression carries a type
/// annotation, the inferred type is unified with the annotated one and the
/// result of that unification is returned.
fn infer(&mut self, expr: &Expression) -> TypeCheckResult {
    let inferred = self.infer_no_anno(&expr.0)?;
    match expr.1 {
        Some(ref anno) => {
            let annotated = self.from_anno(anno);
            self.unify(inferred, annotated)
        },
        None => Ok(inferred),
    }
}
fn infer_no_anno(&mut self, ex: &ExpressionType) -> TypeCheckResult {
2017-10-12 21:46:12 -07:00
use self::ExpressionType::*;
use self::Type::*;
use self::TypeConst::*;
Ok(match ex {
&IntLiteral(_) => TConst(Integer),
&FloatLiteral(_) => TConst(Float),
&StringLiteral(_) => TConst(StringT),
&BoolLiteral(_) => TConst(Boolean),
&Value(ref name) => {
self.lookup(name)
.map(|entry| entry.ty)
.ok_or(format!("Couldn't find {}", name))?
},
&BinExp(ref op, ref lhs, ref rhs) => {
let t_lhs = self.infer(lhs)?;
match self.infer_op(op)? {
TFunc(t1, t2) => {
let _ = self.unify(t_lhs, *t1)?;
let t_rhs = self.infer(rhs)?;
let x = *t2;
match x {
TFunc(t3, t4) => {
let _ = self.unify(t_rhs, *t3)?;
*t4
},
_ => return Err(format!("Not a function type either")),
}
},
_ => return Err(format!("Op {:?} is not a function type", op)),
}
},
&Call { ref f, ref arguments } => {
let tf = self.infer(f)?;
let targ = self.infer(arguments.get(0).unwrap())?;
match tf {
2017-10-11 01:55:45 -07:00
TFunc(box t1, box t2) => {
let _ = self.unify(t1, targ)?;
t2
},
_ => return Err(format!("Not a function!")),
}
},
_ => TConst(Bottom),
})
}
2017-10-10 01:04:19 -07:00
fn infer_op(&mut self, op: &Operation) -> TypeCheckResult {
2017-10-10 04:38:59 -07:00
use self::Type::*;
use self::TypeConst::*;
macro_rules! binoptype {
($lhs:expr, $rhs:expr, $out:expr) => { TFunc(Box::new($lhs), Box::new(TFunc(Box::new($rhs), Box::new($out)))) };
}
Ok(match (*op.0).as_ref() {
"+" => binoptype!(TConst(Integer), TConst(Integer), TConst(Integer)),
"++" => binoptype!(TConst(StringT), TConst(StringT), TConst(StringT)),
"-" => binoptype!(TConst(Integer), TConst(Integer), TConst(Integer)),
"*" => binoptype!(TConst(Integer), TConst(Integer), TConst(Integer)),
"/" => binoptype!(TConst(Integer), TConst(Integer), TConst(Integer)),
"%" => binoptype!(TConst(Integer), TConst(Integer), TConst(Integer)),
_ => TConst(Bottom)
})
}
/// Unifies two types and returns the unified type.
///
/// * Equal constants unify to that constant.
/// * Function types unify component-wise (argument with argument, result
///   with result).
/// * Universal variables unify only with a universal variable of the same name.
/// * An existential variable that already has a binding is replaced by that
///   binding and unification continues; an unbound one is bound to the
///   other type, which is then returned.
///
/// # Errors
///
/// Returns an error message when the two types cannot be unified.
fn unify(&mut self, t1: Type, t2: Type) -> TypeCheckResult {
    use self::Type::*;
    use self::TypeVar::*;

    match (&t1, &t2) {
        (&TConst(ref c1), &TConst(ref c2)) if c1 == c2 => Ok(TConst(c1.clone())),
        (&TFunc(ref arg1, ref ret1), &TFunc(ref arg2, ref ret2)) => {
            let arg = self.unify((**arg1).clone(), (**arg2).clone())?;
            let ret = self.unify((**ret1).clone(), (**ret2).clone())?;
            Ok(TFunc(Box::new(arg), Box::new(ret)))
        },
        (&TVar(Univ(ref a)), &TVar(Univ(ref b))) => {
            if a == b {
                Ok(TVar(Univ(a.clone())))
            } else {
                Err(format!("Couldn't unify universal types {} and {}", a, b))
            }
        },
        //the interesting case!!
        // An existential on either side is handled the same way; the left
        // side takes precedence when both sides are existentials.
        (&TVar(Exist(id)), other) | (other, &TVar(Exist(id))) => {
            let bound = self.evar_table.get(&id).cloned();
            match bound {
                Some(bound_ty) => self.unify(bound_ty, other.clone()),
                None => {
                    // Follow `other` through already-bound existentials. If the
                    // chain ends at this very variable, binding it would create a
                    // cycle in the table and make later unifications recurse
                    // forever, so the variable is left unbound instead.
                    let mut end = other.clone();
                    loop {
                        let next = match end {
                            TVar(Exist(ref next_id)) => self.evar_table.get(next_id).cloned(),
                            _ => None,
                        };
                        match next {
                            Some(ty) => end = ty,
                            None => break,
                        }
                    }

                    if end == TVar(Exist(id)) {
                        Ok(TVar(Exist(id)))
                    } else {
                        self.evar_table.insert(id, other.clone());
                        Ok(other.clone())
                    }
                },
            }
        },
        _ => Err(format!("Types {:?} and {:?} don't unify", t1, t2)),
    }
}
}
2017-10-09 12:26:25 -07:00
#[cfg(test)]
mod tests {
    use super::{Type, TypeVar, TypeConst, TypeContext};
    use super::Type::*;
    use super::TypeConst::*;
    use schala_lang::parsing::{parse, tokenize};

    /// Parses `$input`, registers its declarations, type-checks it and asserts
    /// that the resulting type equals `$correct`.
    macro_rules! type_test {
        ($input:expr, $correct:expr) => {{
            let tokens = tokenize($input);
            let ast = parse(tokens).0.unwrap();
            let mut context = TypeContext::new();
            context.add_symbols(&ast);
            let inferred = context.type_check(&ast).unwrap();
            assert_eq!($correct, inferred)
        }};
    }

    #[test]
    fn basic_inference() {
        type_test!("30", TConst(Integer));
        type_test!("fn x(a: Int): Bool {}; x(1)", TConst(Boolean));
    }
}