486 lines
14 KiB
Rust
486 lines
14 KiB
Rust
use std::rc::Rc;
|
|
use std::fmt::Write;
|
|
|
|
use ena::unify::{UnifyKey, InPlaceUnificationTable, UnificationTable, EqUnifyValue};
|
|
|
|
use crate::ast::*;
|
|
use crate::util::ScopeStack;
|
|
|
|
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub struct TypeData {
|
|
ty: Option<Type>
|
|
}
|
|
|
|
impl TypeData {
|
|
pub fn new() -> TypeData {
|
|
TypeData { ty: None }
|
|
}
|
|
}
|
|
|
|
/// Alias for a reference-counted, shared `String`.
/// NOTE(review): presumably used for names of types — confirm against callers outside this file.
pub type TypeName = Rc<String>;
|
|
|
|
pub struct TypeContext<'a> {
|
|
variable_map: ScopeStack<'a, Rc<String>, Type>,
|
|
unification_table: InPlaceUnificationTable<TypeVar>,
|
|
}
|
|
|
|
/// `InferResult` is the monad in which type inference takes place.
|
|
type InferResult<T> = Result<T, TypeError>;
|
|
|
|
/// Error reported when type checking fails.
#[derive(Debug, Clone)]
pub struct TypeError {
    /// Human-readable description of what went wrong.
    pub msg: String,
}
|
|
|
|
impl TypeError {
|
|
fn new<A, T>(msg: T) -> InferResult<A> where T: Into<String> {
|
|
Err(TypeError { msg: msg.into() })
|
|
}
|
|
}
|
|
|
|
#[allow(dead_code)] // avoids warning from Compound
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
pub enum Type {
|
|
Const(TypeConst),
|
|
Var(TypeVar),
|
|
Arrow {
|
|
params: Vec<Type>,
|
|
ret: Box<Type>
|
|
},
|
|
Compound {
|
|
ty_name: String,
|
|
args:Vec<Type>
|
|
}
|
|
}
|
|
|
|
/// A type variable; the wrapped number is its index in the unification table
/// (see the `UnifyKey` impl for this type).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TypeVar(usize);
|
|
|
|
impl UnifyKey for TypeVar {
|
|
type Value = Option<TypeConst>;
|
|
fn index(&self) -> u32 { self.0 as u32 }
|
|
fn from_index(u: u32) -> TypeVar { TypeVar(u as usize) }
|
|
fn tag() -> &'static str { "TypeVar" }
|
|
}
|
|
|
|
/// The built-in constant types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypeConst {
    /// Rendered as `()`.
    Unit,
    Nat,
    Int,
    Float,
    /// Rendered as `String`.
    StringT,
    Bool,
    Ordering,
    //UserDefined
}
|
|
|
|
impl TypeConst {
|
|
pub fn to_string(&self) -> String {
|
|
use self::TypeConst::*;
|
|
match self {
|
|
Unit => format!("()"),
|
|
Nat => format!("Nat"),
|
|
Int => format!("Int"),
|
|
Float => format!("Float"),
|
|
StringT => format!("String"),
|
|
Bool => format!("Bool"),
|
|
Ordering => format!("Ordering"),
|
|
}
|
|
}
|
|
}
|
|
|
|
// Empty marker impl required for using `TypeConst` as (part of) a unification value.
// NOTE(review): per the `ena` docs, `EqUnifyValue` types unify only when equal — confirm.
impl EqUnifyValue for TypeConst { }
|
|
|
|
/// Shorthand for building [`Type`] values.
///
/// * `ty!(Nat)` — the constant type named by the `TypeConst` variant
/// * `ty!(A -> B)` — a one-parameter function type between constant types
/// * `ty!(A -> B -> C)` — a two-parameter function type between constant types
/// * `ty!(params, ret)` — a function type built from two local variables: a `Vec<Type>`
///   of parameter types and a `Type` return type
///
/// Boxes are created with `Box::new` so the macro does not depend on the unstable
/// `box` expression syntax.
macro_rules! ty {
    ($type_name:ident) => { Type::Const(TypeConst::$type_name) };
    ($t1:ident -> $t2:ident) => {
        Type::Arrow { params: vec![ty!($t1)], ret: Box::new(ty!($t2)) }
    };
    ($t1:ident -> $t2:ident -> $t3:ident) => {
        Type::Arrow { params: vec![ty!($t1), ty!($t2)], ret: Box::new(ty!($t3)) }
    };
    ($type_list:ident, $ret_type:ident) => {
        Type::Arrow {
            params: $type_list,
            ret: Box::new($ret_type),
        }
    }
}
|
|
|
|
//TODO find a better way to capture the to/from string logic
|
|
impl Type {
|
|
pub fn to_string(&self) -> String {
|
|
use self::Type::*;
|
|
match self {
|
|
Const(c) => c.to_string(),
|
|
Var(v) => format!("t_{}", v.0),
|
|
Arrow { params, box ref ret } => {
|
|
if params.len() == 0 {
|
|
format!("-> {}", ret.to_string())
|
|
} else {
|
|
let mut buf = String::new();
|
|
for p in params.iter() {
|
|
write!(buf, "{} -> ", p.to_string()).unwrap();
|
|
}
|
|
write!(buf, "{}", ret.to_string()).unwrap();
|
|
buf
|
|
}
|
|
},
|
|
Compound { .. } => format!("<some compound type>")
|
|
}
|
|
}
|
|
|
|
fn from_string(string: &str) -> Option<Type> {
|
|
Some(match string {
|
|
"()" | "Unit" => ty!(Unit),
|
|
"Nat" => ty!(Nat),
|
|
"Int" => ty!(Int),
|
|
"Float" => ty!(Float),
|
|
"String" => ty!(StringT),
|
|
"Bool" => ty!(Bool),
|
|
"Ordering" => ty!(Ordering),
|
|
_ => return None
|
|
})
|
|
}
|
|
}
|
|
|
|
/*
|
|
/// `Type` is parameterized by whether the type variables can be just universal, or universal or
|
|
/// existential.
|
|
#[derive(Debug, Clone)]
|
|
enum Type<A> {
|
|
Var(A),
|
|
Const(TConst),
|
|
Arrow(Box<Type<A>>, Box<Type<A>>),
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
enum TVar {
|
|
Univ(UVar),
|
|
Exist(ExistentialVar)
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
struct UVar(Rc<String>);
|
|
|
|
#[derive(Debug, Clone)]
|
|
struct ExistentialVar(u32);
|
|
|
|
impl Type<UVar> {
|
|
fn to_tvar(&self) -> Type<TVar> {
|
|
match self {
|
|
Type::Var(UVar(name)) => Type::Var(TVar::Univ(UVar(name.clone()))),
|
|
Type::Const(ref c) => Type::Const(c.clone()),
|
|
Type::Arrow(a, b) => Type::Arrow(
|
|
Box::new(a.to_tvar()),
|
|
Box::new(b.to_tvar())
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Type<TVar> {
|
|
fn skolemize(&self) -> Type<UVar> {
|
|
match self {
|
|
Type::Var(TVar::Univ(uvar)) => Type::Var(uvar.clone()),
|
|
Type::Var(TVar::Exist(_)) => Type::Var(UVar(Rc::new(format!("sk")))),
|
|
Type::Const(ref c) => Type::Const(c.clone()),
|
|
Type::Arrow(a, b) => Type::Arrow(
|
|
Box::new(a.skolemize()),
|
|
Box::new(b.skolemize())
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl TypeIdentifier {
|
|
fn to_monotype(&self) -> Type<UVar> {
|
|
match self {
|
|
TypeIdentifier::Tuple(_) => Type::Const(TConst::Nat),
|
|
TypeIdentifier::Singleton(TypeSingletonName { name, .. }) => {
|
|
match &name[..] {
|
|
"Nat" => Type::Const(TConst::Nat),
|
|
"Int" => Type::Const(TConst::Int),
|
|
"Float" => Type::Const(TConst::Float),
|
|
"Bool" => Type::Const(TConst::Bool),
|
|
"String" => Type::Const(TConst::StringT),
|
|
_ => Type::Const(TConst::Nat),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
enum TConst {
|
|
User(Rc<String>),
|
|
Unit,
|
|
Nat,
|
|
Int,
|
|
Float,
|
|
StringT,
|
|
Bool,
|
|
}
|
|
|
|
impl TConst {
|
|
fn user(name: &str) -> TConst {
|
|
TConst::User(Rc::new(name.to_string()))
|
|
}
|
|
}
|
|
*/
|
|
|
|
impl<'a> TypeContext<'a> {
|
|
pub fn new() -> TypeContext<'a> {
|
|
TypeContext {
|
|
variable_map: ScopeStack::new(None),
|
|
unification_table: UnificationTable::new(),
|
|
}
|
|
}
|
|
|
|
/*
|
|
fn new_env(&'a self, new_var: Rc<String>, ty: Type) -> TypeContext<'a> {
|
|
let mut new_context = TypeContext {
|
|
variable_map: self.variable_map.new_scope(None),
|
|
unification_table: UnificationTable::new(), //???? not sure if i want this
|
|
};
|
|
|
|
new_context.variable_map.insert(new_var, ty);
|
|
new_context
|
|
}
|
|
*/
|
|
|
|
|
|
fn get_type_from_name(&self, name: &TypeIdentifier) -> InferResult<Type> {
|
|
use self::TypeIdentifier::*;
|
|
Ok(match name {
|
|
Singleton(TypeSingletonName { name,.. }) => {
|
|
match Type::from_string(&name) {
|
|
Some(ty) => ty,
|
|
None => return TypeError::new(format!("Unknown type name: {}", name))
|
|
}
|
|
},
|
|
Tuple(_) => return TypeError::new("tuples aren't ready yet"),
|
|
})
|
|
}
|
|
|
|
/// `typecheck` is the entry into the type-inference system, accepting an AST as an argument
|
|
/// Following the example of GHC, the compiler deliberately does typechecking before de-sugaring
|
|
/// the AST to ReducedAST
|
|
pub fn typecheck(&mut self, ast: &AST) -> Result<Type, TypeError> {
|
|
let mut returned_type = Type::Const(TypeConst::Unit);
|
|
for statement in ast.statements.iter() {
|
|
returned_type = self.statement(statement.node())?;
|
|
}
|
|
Ok(returned_type)
|
|
}
|
|
|
|
fn statement(&mut self, statement: &Statement) -> InferResult<Type> {
|
|
match &statement.kind {
|
|
StatementKind::Expression(e) => self.expr(e.node()),
|
|
StatementKind::Declaration(decl) => self.decl(&decl),
|
|
}
|
|
}
|
|
|
|
fn decl(&mut self, decl: &Declaration) -> InferResult<Type> {
|
|
use self::Declaration::*;
|
|
match decl {
|
|
Binding { name, expr, .. } => {
|
|
let ty = self.expr(expr.node())?;
|
|
self.variable_map.insert(name.clone(), ty);
|
|
},
|
|
_ => (),
|
|
}
|
|
Ok(ty!(Unit))
|
|
}
|
|
|
|
fn invoc(&mut self, invoc: &InvocationArgument) -> InferResult<Type> {
|
|
use InvocationArgument::*;
|
|
match invoc {
|
|
Positional(expr) => self.expr(expr.node()),
|
|
_ => Ok(ty!(Nat)) //TODO this is wrong
|
|
}
|
|
}
|
|
|
|
/// Infers the type of an expression, taking an optional type annotation into account.
///
/// Without an annotation the inferred type is returned as-is; with one, the annotated
/// type is unified with the inferred type and the result of that unification is returned.
fn expr(&mut self, expr: &Expression) -> InferResult<Type> {
    let inferred = self.expr_type(&expr.kind)?;
    match &expr.type_anno {
        None => Ok(inferred),
        Some(anno) => {
            let declared = self.get_type_from_name(anno)?;
            self.unify(declared, inferred)
        }
    }
}
|
|
|
|
fn expr_type(&mut self, expr: &ExpressionKind) -> InferResult<Type> {
|
|
use self::ExpressionKind::*;
|
|
Ok(match expr {
|
|
NatLiteral(_) => ty!(Nat),
|
|
BoolLiteral(_) => ty!(Bool),
|
|
FloatLiteral(_) => ty!(Float),
|
|
StringLiteral(_) => ty!(StringT),
|
|
PrefixExp(op, expr) => self.prefix(op, expr.node())?,
|
|
BinExp(op, lhs, rhs) => self.binexp(op, lhs.node(), rhs.node())?,
|
|
IfExpression { discriminator, body } => self.if_expr(discriminator, body)?,
|
|
Value(val) => self.handle_value(val.node())?,
|
|
Call { box ref f, arguments } => self.call(f, arguments)?,
|
|
Lambda { params, type_anno, body } => self.lambda(params, type_anno, body)?,
|
|
_ => ty!(Unit),
|
|
})
|
|
}
|
|
|
|
fn prefix(&mut self, op: &PrefixOp, expr: &Expression) -> InferResult<Type> {
|
|
let tf = match op.builtin.map(|b| b.get_type()) {
|
|
Some(ty) => ty,
|
|
None => return TypeError::new("no type found")
|
|
};
|
|
|
|
let tx = self.expr(expr)?;
|
|
self.handle_apply(tf, vec![tx])
|
|
}
|
|
|
|
fn binexp(&mut self, op: &BinOp, lhs: &Expression, rhs: &Expression) -> InferResult<Type> {
|
|
let tf = match op.builtin.map(|b| b.get_type()) {
|
|
Some(ty) => ty,
|
|
None => return TypeError::new("no type found"),
|
|
};
|
|
|
|
let t_lhs = self.expr(lhs)?;
|
|
let t_rhs = self.expr(rhs)?; //TODO is this order a problem? not sure
|
|
|
|
self.handle_apply(tf, vec![t_lhs, t_rhs])
|
|
}
|
|
|
|
fn if_expr(&mut self, discriminator: &Discriminator, body: &IfExpressionBody) -> InferResult<Type> {
|
|
use self::Discriminator::*; use self::IfExpressionBody::*;
|
|
match (discriminator, body) {
|
|
(Simple(expr), SimpleConditional(then_clause, else_clause)) => self.handle_simple_if(expr.node(), then_clause, else_clause),
|
|
_ => TypeError::new(format!("Complex conditionals not supported"))
|
|
}
|
|
}
|
|
|
|
fn handle_simple_if(&mut self, expr: &Expression, then_clause: &Block, else_clause: &Option<Block>) -> InferResult<Type> {
|
|
let t1 = self.expr(expr)?;
|
|
let t2 = self.block(then_clause)?;
|
|
let t3 = match else_clause {
|
|
Some(block) => self.block(block)?,
|
|
None => ty!(Unit)
|
|
};
|
|
|
|
let _ = self.unify(ty!(Bool), t1)?;
|
|
self.unify(t2, t3)
|
|
}
|
|
|
|
fn lambda(&mut self, params: &Vec<FormalParam>, type_anno: &Option<TypeIdentifier>, _body: &Block) -> InferResult<Type> {
|
|
let argument_types: InferResult<Vec<Type>> = params.iter().map(|param: &FormalParam| {
|
|
if let FormalParam { anno: Some(type_identifier), .. } = param {
|
|
self.get_type_from_name(type_identifier)
|
|
} else {
|
|
Ok(Type::Var(self.fresh_type_variable()))
|
|
}
|
|
}).collect();
|
|
let argument_types = argument_types?;
|
|
let ret_type = match type_anno.as_ref() {
|
|
Some(anno) => self.get_type_from_name(anno)?,
|
|
None => Type::Var(self.fresh_type_variable())
|
|
};
|
|
|
|
Ok(ty!(argument_types, ret_type))
|
|
}
|
|
|
|
/// Infers the type of a call: the callee's type applied to the types of its arguments.
fn call(&mut self, f: &Meta<Expression>, args: &Vec<InvocationArgument>) -> InferResult<Type> {
    let callee_type = self.expr(f.node())?;

    let mut arg_types = Vec::with_capacity(args.len());
    for arg in args.iter() {
        arg_types.push(self.invoc(arg)?);
    }

    self.handle_apply(callee_type, arg_types)
}
|
|
|
|
fn handle_apply(&mut self, tf: Type, args: Vec<Type>) -> InferResult<Type> {
|
|
Ok(match tf {
|
|
Type::Arrow { ref params, ret: box ref t_ret } if params.len() == args.len() => {
|
|
for (t_param, t_arg) in params.iter().zip(args.iter()) {
|
|
let _ = self.unify(t_param.clone(), t_arg.clone())?; //TODO I think this needs to reference a sub-scope
|
|
}
|
|
t_ret.clone()
|
|
},
|
|
Type::Arrow { .. } => return TypeError::new("Wrong length"),
|
|
_ => return TypeError::new(format!("Not a function"))
|
|
})
|
|
}
|
|
|
|
fn block(&mut self, block: &Block) -> InferResult<Type> {
|
|
let mut output = ty!(Unit);
|
|
for s in block.iter() {
|
|
let statement = s.node();
|
|
output = self.statement(statement)?;
|
|
}
|
|
Ok(output)
|
|
}
|
|
|
|
fn handle_value(&mut self, val: &QualifiedName) -> InferResult<Type> {
|
|
let QualifiedName(vec) = val;
|
|
let var = &vec[0];
|
|
match self.variable_map.lookup(var) {
|
|
Some(ty) => Ok(ty.clone()),
|
|
None => TypeError::new(format!("Couldn't find variable: {}", &var)),
|
|
}
|
|
}
|
|
|
|
fn unify(&mut self, t1: Type, t2: Type) -> InferResult<Type> {
|
|
use self::Type::*;
|
|
|
|
match (t1, t2) {
|
|
(Const(ref c1), Const(ref c2)) if c1 == c2 => Ok(Const(c1.clone())), //choice of c1 is arbitrary I *think*
|
|
(a @ Var(_), b @ Const(_)) => self.unify(b, a),
|
|
(Const(ref c1), Var(ref v2)) => {
|
|
self.unification_table.unify_var_value(v2.clone(), Some(c1.clone()))
|
|
.or_else(|_| TypeError::new(format!("Couldn't unify {:?} and {:?}", Const(c1.clone()), Var(*v2))))?;
|
|
Ok(Const(c1.clone()))
|
|
},
|
|
(Var(v1), Var(v2)) => {
|
|
//TODO add occurs check
|
|
self.unification_table.unify_var_var(v1.clone(), v2.clone())
|
|
.or_else(|e| {
|
|
println!("Unify error: {:?}", e);
|
|
TypeError::new(format!("Two type variables {:?} and {:?} couldn't unify", v1, v2))
|
|
})?;
|
|
Ok(Var(v1.clone())) //arbitrary decision I think
|
|
},
|
|
(a, b) => TypeError::new(format!("{:?} and {:?} do not unify", a, b)),
|
|
}
|
|
}
|
|
|
|
/// Allocates a new key in the unification table with no value bound to it (`None`).
fn fresh_type_variable(&mut self) -> TypeVar {
    self.unification_table.new_key(None)
}
|
|
}
|
|
|
|
#[cfg(test)]
mod typechecking_tests {
    use super::*;

    /// Parses `$string` with `crate::util::quick_ast`, typechecks it in a brand-new
    /// `TypeContext`, and asserts that the resulting type equals `$type`.
    macro_rules! assert_type_in_fresh_context {
        ($string:expr, $type:expr) => {{
            let mut context = TypeContext::new();
            let ast = crate::util::quick_ast($string);
            let inferred = context.typecheck(&ast).unwrap();
            assert_eq!(inferred, $type);
        }};
    }

    /// Literals are typed as their constant type.
    #[test]
    fn basic_test() {
        assert_type_in_fresh_context!("1", ty!(Nat));
        assert_type_in_fresh_context!(r#""drugs""#, ty!(StringT));
        assert_type_in_fresh_context!("true", ty!(Bool));
    }

    /// Operator typing; the assertions are disabled until they are ported.
    #[test]
    fn operators() {
        //TODO fix these with new operator regime
        /*
        assert_type_in_fresh_context!("-1", ty!(Int));
        assert_type_in_fresh_context!("1 + 2", ty!(Nat));
        assert_type_in_fresh_context!("-2", ty!(Int));
        assert_type_in_fresh_context!("!true", ty!(Bool));
        */
    }
}
|