schala/src/schala_lang/type_check.rs

235 lines
6.6 KiB
Rust
Raw Normal View History

2017-10-08 12:22:04 -07:00
use std::collections::HashMap;
2017-10-08 13:51:56 -07:00
use std::rc::Rc;
2017-10-08 12:22:04 -07:00
2017-10-09 02:26:59 -07:00
use schala_lang::parsing::{AST, Statement, Declaration, Signature, Expression, ExpressionType, Operation, TypeName};
2017-10-04 02:07:30 -07:00
2017-10-10 01:11:24 -07:00
// from Niko's talk
/* fn type_check(expression, expected_ty) -> Ty {
let ty = bare_type_check(expression, expected_ty);
if ty is incompatible with expected_ty {
try_coerce(expression, ty, expected_ty)
} else {
ty
}
}
fn bare_type_check(expression, expected_ty) -> Ty { ... }
*/
// from https://www.youtube.com/watch?v=il3gD7XMdmA
// typeInfer :: Expr a -> Matching (Type a)
// unify :: Type a -> Type b -> Matching (Type c)
/// A type as tracked by the checker: either a concrete type (`Univ`) or an
/// existential placeholder (`Exist`) identified by a numeric label.
#[derive(Clone, Debug, PartialEq)]
pub enum TypeVariable {
    /// A concrete type.
    Univ(UVar),
    /// A placeholder type, labelled by a counter value.
    Exist(u64),
}

/// The concrete types the checker knows about.
#[derive(Clone, Debug, PartialEq)]
pub enum UVar {
    Integer,
    Float,
    String,
    Boolean,
    Unit,
    /// A single-argument function type, `argument -> result`; functions taking
    /// several parameters are expressed by nesting these.
    Function(Box<TypeVariable>, Box<TypeVariable>),
    Bottom,
}

/// Outcome of a type-checking step; failures carry a human-readable message.
type TypeCheckResult = Result<TypeVariable, String>;
2017-10-08 12:22:04 -07:00
/// Symbol-table key: wraps the name a symbol was declared under.
#[derive(Debug, Hash, PartialEq, Eq)]
struct PathSpecifier(Rc<String>);
#[derive(Debug, PartialEq, Clone)]
struct TypeContextEntry {
type_var: TypeVariable,
constant: bool
2017-10-08 13:51:56 -07:00
}
2017-10-08 12:22:04 -07:00
2017-10-09 00:59:52 -07:00
/// Type-checking state: the known symbols and a counter used to mint fresh
/// existential type labels.
pub struct TypeContext {
    /// Symbols registered so far, keyed by name.
    symbol_table: HashMap<PathSpecifier, TypeContextEntry>,
    /// Label the next fresh existential type will receive.
    existential_type_label_count: u64,
}
2017-10-09 00:59:52 -07:00
impl TypeContext {
pub fn new() -> TypeContext {
2017-10-09 02:26:59 -07:00
TypeContext {
symbol_table: HashMap::new(),
existential_type_label_count: 0,
}
}
2017-10-09 00:59:52 -07:00
pub fn add_symbols(&mut self, ast: &AST) {
2017-10-08 13:51:56 -07:00
use self::Declaration::*;
for statement in ast.0.iter() {
2017-10-09 00:59:52 -07:00
match *statement {
Statement::ExpressionStatement(_) => (),
Statement::Declaration(ref decl) => {
match *decl {
FuncSig(_) => (),
Impl { .. } => (),
TypeDecl { .. } => (),
TypeAlias { .. } => (),
Binding {ref name, ref constant, ref expr} => {
2017-10-09 04:02:50 -07:00
let spec = PathSpecifier(name.clone());
let type_var = expr.1.as_ref()
2017-10-09 02:26:59 -07:00
.map(|ty| self.from_anno(ty))
.unwrap_or_else(|| { self.get_existential_type() });
let entry = TypeContextEntry { type_var, constant: *constant };
self.symbol_table.insert(spec, entry);
2017-10-09 00:59:52 -07:00
},
2017-10-09 02:26:59 -07:00
FuncDecl(ref signature, _) => {
2017-10-09 04:02:50 -07:00
let spec = PathSpecifier(signature.name.clone());
let type_var = self.from_signature(signature);
let entry = TypeContextEntry { type_var, constant: true };
self.symbol_table.insert(spec, entry);
2017-10-08 13:51:56 -07:00
},
}
}
}
}
}
fn lookup(&mut self, binding: &Rc<String>) -> Option<TypeContextEntry> {
2017-10-09 04:02:50 -07:00
let key = PathSpecifier(binding.clone());
self.symbol_table.get(&key).map(|entry| entry.clone())
}
2017-10-08 13:57:43 -07:00
pub fn debug_symbol_table(&self) -> String {
2017-10-09 00:59:52 -07:00
format!("Symbol table:\n {:?}", self.symbol_table)
2017-10-08 13:57:43 -07:00
}
2017-10-09 02:26:59 -07:00
fn get_existential_type(&mut self) -> TypeVariable {
let ret = TypeVariable::Exist(self.existential_type_label_count);
self.existential_type_label_count += 1;
ret
}
fn from_anno(&mut self, anno: &TypeName) -> TypeVariable {
use self::TypeVariable::*;
use self::UVar::*;
match anno {
&TypeName::Singleton { ref name, .. } => {
match name.as_ref().as_ref() {
"Int" => Univ(Integer),
"Bool" => Univ(Boolean),
_ => self.get_existential_type()
}
},
_ => Univ(Bottom),
}
}
fn from_signature(&mut self, sig: &Signature) -> TypeVariable {
use self::TypeVariable::Univ;
use self::UVar::{Unit, Function};
let return_type = sig.type_anno.as_ref().map(|anno| self.from_anno(&anno)).unwrap_or_else(|| { self.get_existential_type() });
if sig.params.len() == 0 {
Univ(Function(Box::new(Univ(Unit)), Box::new(return_type)))
} else {
let mut output_type = return_type;
for p in sig.params.iter() {
let p_type = p.1.as_ref().map(|anno| self.from_anno(anno)).unwrap_or_else(|| { self.get_existential_type() });
output_type = Univ(Function(Box::new(p_type), Box::new(output_type)));
}
output_type
}
}
2017-10-04 02:07:30 -07:00
pub fn type_check(&mut self, ast: &AST) -> TypeCheckResult {
let mut last = TypeVariable::Univ(UVar::Unit);
2017-10-04 02:07:30 -07:00
for statement in ast.0.iter() {
match statement {
&Statement::Declaration(ref _decl) => {
//return Err(format!("Declarations not supported"));
2017-10-04 02:07:30 -07:00
},
&Statement::ExpressionStatement(ref expr) => {
last = self.infer(expr)?;
2017-10-04 02:07:30 -07:00
}
}
}
Ok(last)
2017-10-04 02:07:30 -07:00
}
2017-10-08 12:22:04 -07:00
fn infer(&mut self, expr: &Expression) -> TypeCheckResult {
2017-10-08 12:22:04 -07:00
use self::ExpressionType::*;
use self::TypeVariable::*;
2017-10-08 12:22:04 -07:00
Ok(match (&expr.0, &expr.1) {
(ref _t, &Some(ref anno)) => {
//TODO make this better,
2017-10-09 02:26:59 -07:00
self.from_anno(anno)
},
(&IntLiteral(_), _) => Univ(UVar::Integer),
2017-10-09 04:02:50 -07:00
(&FloatLiteral(_), _) => Univ(UVar::Float),
2017-10-09 11:42:53 -07:00
(&StringLiteral(_), _) => Univ(UVar::String),
(&BoolLiteral(_), _) => Univ(UVar::Boolean),
(&Variable(ref name), _) => self.lookup(name).map(|entry| entry.type_var)
.ok_or(format!("Couldn't find {}", name))?,
2017-10-10 01:04:19 -07:00
(&BinExp(ref op, box ref lhs, box ref rhs), _) => {
let _f_type = self.infer_op(op);
let _lhs_type = self.infer(&lhs);
let _rhs_type = self.infer(&rhs);
unimplemented!()
},
(&Call { ref f, ref arguments }, _) => {
let f_type = self.infer(&*f)?;
let arg_type = self.infer(arguments.get(0).unwrap())?; // TODO fix later
match f_type {
Univ(UVar::Function(box t1, box ret_type)) => {
let _ = self.unify(&t1, &arg_type)?;
ret_type
},
_ => return Err(format!("Type error"))
}
},
_ => Univ(UVar::Unit),
})
2017-10-08 12:22:04 -07:00
}
2017-10-10 01:04:19 -07:00
fn infer_op(&mut self, _op: &Operation) -> TypeCheckResult {
use self::TypeVariable::*;
Ok(
Univ(UVar::Function(
Box::new(Univ(UVar::Integer)),
Box::new(Univ(UVar::Function(
Box::new(Univ(UVar::Integer)),
Box::new(Univ(UVar::Integer))
)))
))
)
}
2017-10-09 00:22:42 -07:00
fn unify(&mut self, t1: &TypeVariable, t2: &TypeVariable) -> TypeCheckResult {
if t1 == t2 {
Ok(t1.clone())
} else {
Err(format!("Types {:?} and {:?} don't unify", t1, t2))
}
}
}
2017-10-09 12:26:25 -07:00
#[cfg(test)]
mod tests {
    use super::{TypeContext, TypeVariable, UVar};
    use super::TypeVariable::*;
    use schala_lang::parsing::{parse, tokenize};

    // Tokenizes, parses and type-checks `$input`, then asserts that the inferred
    // type equals `$correct`. Panics if parsing or type checking fails.
    macro_rules! type_test {
        ($input:expr, $correct:expr) => {{
            let tokens = tokenize($input);
            let ast = parse(tokens).0.unwrap();
            let mut context = TypeContext::new();
            let inferred = context.type_check(&ast).unwrap();
            assert_eq!($correct, inferred)
        }};
    }

    #[test]
    fn basic_inference() {
        type_test!("30", Univ(UVar::Integer))
    }
}