2017-10-08 12:22:04 -07:00
use std ::collections ::HashMap ;
2017-10-08 13:51:56 -07:00
use std ::rc ::Rc ;
2017-10-08 12:22:04 -07:00
2017-10-11 02:33:46 -07:00
2018-03-23 18:43:43 -07:00
use parsing ::{ AST , Statement , Declaration , Signature , Expression , ExpressionType , Operation , Variant , TypeName , TypeSingletonName } ;
2017-10-04 02:07:30 -07:00
2017-10-10 01:11:24 -07:00
// from Niko's talk
/* fn type_check(expression, expected_ty) -> Ty {
let ty = bare_type_check ( expression , expected_type ) ;
if ty is incompatible with expected_ty {
try_coerce ( expression , ty , expected_ty )
} else {
ty
}
}
fn bare_type_check ( expression , expected_type ) -> Ty { .. . }
* /
2018-02-11 02:37:52 -08:00
/* H-M ALGO NOTES
from https ://www.youtube.com/watch?v=il3gD7XMdmA
( also check out http ://dev.stephendiehl.com/fun/006_hindley_milner.html)
typeInfer :: Expr a -> Matching ( Type a )
unify :: Type a -> Type b -> Matching ( Type c )
( Matching a ) is a monad in which unification is done
ex :
typeInfer ( If e1 e2 e3 ) = do
t1 < - typeInfer e1
t2 < - typeInfer e2
t3 < - typeInfer e3
_ < - unify t1 BoolType
unify t2 t3 - - b / c t2 and t3 have to be the same type
typeInfer ( Const ( ConstInt _ ) ) = IntType - - same for other literals
- - function application
typeInfer ( Apply f x ) = do
tf < - typeInfer f
tx < - typeInfer x
case tf of
FunctionType t1 t2 -> do
_ < - unify t1 tx
return t2
_ -> fail " Not a function "
- - type annotation
typeInfer ( Typed x t ) = do
tx < - typeInfer x
unify tx t
- - variable and let expressions - need to pass around a map of variable names to types here
typeInfer :: [ ( Var , Type Var ) ] -> Expr Var -> Matching ( Type Var )
typeInfer ctx ( Var x ) = case ( lookup x ctx ) of
Just t -> return t
Nothing -> fail " Unknown variable "
- - let x = e1 in e2
typeInfer ctx ( Let x e1 e2 ) = do
t1 < - typeInfer ctx e1
typeInfer ( ( x , t1 ) :: ctx ) e2
- - lambdas are complicated ( this represents ʎ x . e )
typeInfer ctx ( Lambda x e ) = do
t1 < - allocExistentialVariable
t2 < - typeInfer ( ( x , t1 ) :: ctx ) e
return $ FunctionType t1 t2 - - ie . t1 -> t2
- - to solve the problem of map :: ( a -> b ) -> [ a ] -> [ b ]
when we use a variable whose type has universal tvars , convert those universal
tvars to existential ones
- and each distinct universal tvar needs to map to the same existential type
- so we change typeinfer :
typeInfer ctx ( Var x ) = do
case ( lookup x ctx ) of
Nothing -> .. .
Just t -> do
let uvars = nub ( toList t ) - - nub removes duplicates , so this gets unique universally quantified variables
evars < - mapM ( const allocExistentialVariable ) uvars
let varMap = zip uvars evars
let fixVar varMap v = fromJust $ lookup v varMap
return ( fmap ( fixVar varMap ) t )
- - how do we define unify ? ?
- recall , type signature is :
unify :: Type a -> Type b -> Matching ( Type c )
unify BoolType BoolType = BoolType - - easy , same for all constants
unify ( FunctionType t1 t2 ) ( FunctionType t3 t4 ) = do
t5 < - unify t1 t3
t6 < - unify t2 t4
return $ FunctionType t5 t6
unify ( TVar a ) ( TVar b ) = if a = = b then TVar a else fail
- - existential types can be assigned another type at most once
- - some complicated stuff about handling existential types
- - everything else is a type error
unify a b = fail
SKOLEMIZATION - how you prevent an unassigned existential type variable from leaking !
- before a type gets to global scope , replace all unassigned existential vars w / new unique universal
type variables
* /
2017-10-10 01:11:24 -07:00
2017-10-10 02:17:07 -07:00
#[ derive(Debug, PartialEq, Clone) ]
pub enum Type {
TVar ( TypeVar ) ,
TConst ( TypeConst ) ,
2017-10-11 01:55:45 -07:00
TFunc ( Box < Type > , Box < Type > ) ,
2017-10-10 02:17:07 -07:00
}
2017-10-10 01:11:24 -07:00
/// A type variable.
#[derive(Clone, Debug, PartialEq)]
pub enum TypeVar {
    /// A universal type variable, identified by its textual label.
    Univ(Rc<String>),
    /// An existential type variable, identified by a numeric label.
    Exist(u64),
}
2017-10-10 04:26:40 -07:00
impl TypeVar {
    /// Builds a universal type variable whose label is a copy of `label`.
    fn univ(label: &str) -> TypeVar {
        let owned = String::from(label);
        TypeVar::Univ(Rc::new(owned))
    }
}
2017-10-10 01:11:24 -07:00
/// The concrete (non-variable, non-function) types known to the checker.
#[derive(Clone, Debug, PartialEq)]
pub enum TypeConst {
    /// A user-declared type, identified by name.
    UserT(Rc<String>),
    /// Built-in integer type.
    Integer,
    /// Built-in floating-point type.
    Float,
    /// Built-in string type.
    StringT,
    /// Built-in boolean type.
    Boolean,
    /// The unit type.
    Unit,
    /// Placeholder type the checker yields for constructs it does not handle.
    Bottom,
}
2017-10-10 02:17:07 -07:00
/// Outcome of a type-checking step: the resulting `Type` on success, or a
/// human-readable error message on failure.
type TypeCheckResult = Result < Type , String > ;
2017-10-10 01:11:24 -07:00
2017-10-08 12:22:04 -07:00
/// Key type of the symbol table: wraps the name of a binding.
#[derive(Debug, Eq, Hash, PartialEq)]
struct PathSpecifier(Rc<String>);
2017-10-09 02:38:33 -07:00
#[ derive(Debug, PartialEq, Clone) ]
struct TypeContextEntry {
2017-10-10 22:14:55 -07:00
ty : Type ,
2017-10-09 02:38:33 -07:00
constant : bool
2017-10-08 13:51:56 -07:00
}
2017-10-08 12:22:04 -07:00
2017-10-09 00:59:52 -07:00
pub struct TypeContext {
2017-10-09 02:38:33 -07:00
symbol_table : HashMap < PathSpecifier , TypeContextEntry > ,
2017-10-11 02:11:12 -07:00
evar_table : HashMap < u64 , Type > ,
2017-10-09 02:26:59 -07:00
existential_type_label_count : u64
2017-10-07 21:57:51 -07:00
}
2017-10-09 00:59:52 -07:00
impl TypeContext {
pub fn new ( ) -> TypeContext {
2017-10-09 02:26:59 -07:00
TypeContext {
symbol_table : HashMap ::new ( ) ,
2017-10-11 02:11:12 -07:00
evar_table : HashMap ::new ( ) ,
2017-10-09 02:26:59 -07:00
existential_type_label_count : 0 ,
}
2017-10-07 21:57:51 -07:00
}
2017-10-09 00:59:52 -07:00
pub fn add_symbols ( & mut self , ast : & AST ) {
2017-10-08 13:51:56 -07:00
use self ::Declaration ::* ;
2017-10-10 17:29:28 -07:00
use self ::Type ::* ;
use self ::TypeConst ::* ;
2017-10-08 13:51:56 -07:00
for statement in ast . 0. iter ( ) {
2017-10-09 00:59:52 -07:00
match * statement {
Statement ::ExpressionStatement ( _ ) = > ( ) ,
2018-02-08 01:15:27 -08:00
Statement ::Declaration ( ref decl ) = > match * decl {
FuncSig ( _ ) = > ( ) ,
Impl { .. } = > ( ) ,
TypeDecl ( ref type_constructor , ref body ) = > {
for variant in body . 0. iter ( ) {
let ( spec , ty ) = match variant {
& Variant ::UnitStruct ( ref data_constructor ) = > {
let spec = PathSpecifier ( data_constructor . clone ( ) ) ;
2018-02-12 00:51:53 -08:00
let ty = TConst ( UserT ( type_constructor . name . clone ( ) ) ) ;
2018-02-08 01:15:27 -08:00
( spec , ty )
} ,
& Variant ::TupleStruct ( ref data_construcor , ref args ) = > {
//TODO fix
let arg = args . get ( 0 ) . unwrap ( ) ;
let type_arg = self . from_anno ( arg ) ;
let spec = PathSpecifier ( data_construcor . clone ( ) ) ;
2018-02-12 00:51:53 -08:00
let ty = TFunc ( Box ::new ( type_arg ) , Box ::new ( TConst ( UserT ( type_constructor . name . clone ( ) ) ) ) ) ;
2018-02-08 01:15:27 -08:00
( spec , ty )
} ,
& Variant ::Record ( _ , _ ) = > unimplemented! ( ) ,
} ;
2017-10-10 22:14:55 -07:00
let entry = TypeContextEntry { ty , constant : true } ;
2017-10-09 02:38:33 -07:00
self . symbol_table . insert ( spec , entry ) ;
2018-02-08 01:15:27 -08:00
}
} ,
TypeAlias { .. } = > ( ) ,
Binding { ref name , ref constant , ref expr } = > {
let spec = PathSpecifier ( name . clone ( ) ) ;
let ty = expr . 1. as_ref ( )
. map ( | ty | self . from_anno ( ty ) )
. unwrap_or_else ( | | { self . alloc_existential_type ( ) } ) ; // this call to alloc_existential is OK b/c a binding only ever has one type, so if the annotation is absent, it's fine to just make one de novo
let entry = TypeContextEntry { ty , constant : * constant } ;
self . symbol_table . insert ( spec , entry ) ;
} ,
FuncDecl ( ref signature , _ ) = > {
let spec = PathSpecifier ( signature . name . clone ( ) ) ;
let ty = self . from_signature ( signature ) ;
let entry = TypeContextEntry { ty , constant : true } ;
self . symbol_table . insert ( spec , entry ) ;
} ,
2017-10-08 13:51:56 -07:00
}
}
}
2017-10-07 21:57:51 -07:00
}
2017-10-09 02:38:33 -07:00
fn lookup ( & mut self , binding : & Rc < String > ) -> Option < TypeContextEntry > {
2017-10-09 04:02:50 -07:00
let key = PathSpecifier ( binding . clone ( ) ) ;
2017-10-09 02:38:33 -07:00
self . symbol_table . get ( & key ) . map ( | entry | entry . clone ( ) )
2017-10-08 23:33:53 -07:00
}
2017-10-08 13:57:43 -07:00
pub fn debug_symbol_table ( & self ) -> String {
2017-10-11 02:33:46 -07:00
format! ( " Symbol table: \n {:?} \n Evar table: \n {:?} " , self . symbol_table , self . evar_table )
2017-10-08 13:57:43 -07:00
}
2017-10-11 01:50:04 -07:00
fn alloc_existential_type ( & mut self ) -> Type {
2017-10-10 02:17:07 -07:00
let ret = Type ::TVar ( TypeVar ::Exist ( self . existential_type_label_count ) ) ;
2017-10-09 02:26:59 -07:00
self . existential_type_label_count + = 1 ;
ret
}
2017-10-10 02:17:07 -07:00
fn from_anno ( & mut self , anno : & TypeName ) -> Type {
use self ::Type ::* ;
use self ::TypeConst ::* ;
2017-10-09 02:26:59 -07:00
match anno {
2018-02-12 00:25:48 -08:00
& TypeName ::Singleton ( TypeSingletonName { ref name , .. } ) = > {
2017-10-09 02:26:59 -07:00
match name . as_ref ( ) . as_ref ( ) {
2017-10-10 02:17:07 -07:00
" Int " = > TConst ( Integer ) ,
2017-10-13 00:01:43 -07:00
" Float " = > TConst ( Float ) ,
2017-10-10 02:17:07 -07:00
" Bool " = > TConst ( Boolean ) ,
2017-10-10 21:51:45 -07:00
" String " = > TConst ( StringT ) ,
2017-10-11 01:50:04 -07:00
s = > TVar ( TypeVar ::Univ ( Rc ::new ( format! ( " {} " , s ) ) ) ) ,
2017-10-09 02:26:59 -07:00
}
} ,
2017-10-10 21:51:45 -07:00
& TypeName ::Tuple ( ref items ) = > {
if items . len ( ) = = 1 {
TConst ( Unit )
} else {
TConst ( Bottom )
}
}
2017-10-09 02:26:59 -07:00
}
}
2017-10-10 02:17:07 -07:00
fn from_signature ( & mut self , sig : & Signature ) -> Type {
use self ::Type ::* ;
use self ::TypeConst ::* ;
2017-10-11 01:50:04 -07:00
//TODO this won't work properly until you make sure that all (universal) type vars in the function have the same existential type var
// actually this should never even put existential types into the symbol table at all
//this will crash if more than 5 arg function is used
let names = vec! [ " a " , " b " , " c " , " d " , " e " , " f " ] ;
let mut idx = 0 ;
let mut get_type = | | { let q = TVar ( TypeVar ::Univ ( Rc ::new ( format! ( " {} " , names . get ( idx ) . unwrap ( ) ) ) ) ) ; idx + = 1 ; q } ;
let return_type = sig . type_anno . as_ref ( ) . map ( | anno | self . from_anno ( & anno ) ) . unwrap_or_else ( | | { get_type ( ) } ) ;
2017-10-09 02:26:59 -07:00
if sig . params . len ( ) = = 0 {
2017-10-11 01:55:45 -07:00
TFunc ( Box ::new ( TConst ( Unit ) ) , Box ::new ( return_type ) )
2017-10-09 02:26:59 -07:00
} else {
let mut output_type = return_type ;
for p in sig . params . iter ( ) {
2017-10-11 01:50:04 -07:00
let p_type = p . 1. as_ref ( ) . map ( | anno | self . from_anno ( anno ) ) . unwrap_or_else ( | | { get_type ( ) } ) ;
2017-10-11 01:55:45 -07:00
output_type = TFunc ( Box ::new ( p_type ) , Box ::new ( output_type ) ) ;
2017-10-09 02:26:59 -07:00
}
output_type
}
}
2017-10-04 02:07:30 -07:00
pub fn type_check ( & mut self , ast : & AST ) -> TypeCheckResult {
2017-10-10 02:17:07 -07:00
use self ::Type ::* ;
use self ::TypeConst ::* ;
let mut last = TConst ( Unit ) ;
2017-10-04 02:07:30 -07:00
for statement in ast . 0. iter ( ) {
match statement {
& Statement ::Declaration ( ref _decl ) = > {
2017-10-08 22:48:10 -07:00
//return Err(format!("Declarations not supported"));
2017-10-04 02:07:30 -07:00
} ,
& Statement ::ExpressionStatement ( ref expr ) = > {
2017-10-08 22:48:10 -07:00
last = self . infer ( expr ) ? ;
2017-10-04 02:07:30 -07:00
}
}
}
2017-10-08 22:48:10 -07:00
Ok ( last )
2017-10-04 02:07:30 -07:00
}
2017-10-11 01:50:04 -07:00
/// Infers the type of an expression. When the expression carries a type
/// annotation, the inferred type is unified with the annotated one and the
/// result of that unification is returned.
fn infer(&mut self, expr: &Expression) -> TypeCheckResult {
    let inferred = self.infer_no_anno(&expr.0)?;
    match expr.1 {
        Some(ref anno) => {
            let annotated = self.from_anno(anno);
            self.unify(inferred, annotated)
        },
        None => Ok(inferred),
    }
}
fn infer_no_anno ( & mut self , ex : & ExpressionType ) -> TypeCheckResult {
2017-10-12 21:46:12 -07:00
use self ::ExpressionType ::* ;
use self ::Type ::* ;
use self ::TypeConst ::* ;
2017-10-12 20:14:33 -07:00
Ok ( match ex {
& IntLiteral ( _ ) = > TConst ( Integer ) ,
2017-10-12 23:59:52 -07:00
& FloatLiteral ( _ ) = > TConst ( Float ) ,
& StringLiteral ( _ ) = > TConst ( StringT ) ,
2017-10-12 20:14:33 -07:00
& BoolLiteral ( _ ) = > TConst ( Boolean ) ,
2017-10-14 13:54:17 -07:00
& Value ( ref name , _ ) = > {
2017-10-11 01:50:04 -07:00
self . lookup ( name )
. map ( | entry | entry . ty )
. ok_or ( format! ( " Couldn't find {} " , name ) ) ?
} ,
2017-10-12 23:59:52 -07:00
& BinExp ( ref op , ref lhs , ref rhs ) = > {
let t_lhs = self . infer ( lhs ) ? ;
match self . infer_op ( op ) ? {
TFunc ( t1 , t2 ) = > {
let _ = self . unify ( t_lhs , * t1 ) ? ;
let t_rhs = self . infer ( rhs ) ? ;
let x = * t2 ;
match x {
TFunc ( t3 , t4 ) = > {
let _ = self . unify ( t_rhs , * t3 ) ? ;
* t4
} ,
_ = > return Err ( format! ( " Not a function type either " ) ) ,
}
} ,
_ = > return Err ( format! ( " Op {:?} is not a function type " , op ) ) ,
}
} ,
2017-10-12 20:14:33 -07:00
& Call { ref f , ref arguments } = > {
2017-10-11 01:50:04 -07:00
let tf = self . infer ( f ) ? ;
let targ = self . infer ( arguments . get ( 0 ) . unwrap ( ) ) ? ;
match tf {
2017-10-11 01:55:45 -07:00
TFunc ( box t1 , box t2 ) = > {
2017-10-12 23:59:52 -07:00
let _ = self . unify ( t1 , targ ) ? ;
2017-10-11 01:50:04 -07:00
t2
} ,
_ = > return Err ( format! ( " Not a function! " ) ) ,
}
} ,
_ = > TConst ( Bottom ) ,
} )
}
2017-10-10 01:04:19 -07:00
2017-10-12 23:59:52 -07:00
fn infer_op ( & mut self , op : & Operation ) -> TypeCheckResult {
2017-10-10 04:38:59 -07:00
use self ::Type ::* ;
use self ::TypeConst ::* ;
2017-10-12 23:59:52 -07:00
macro_rules ! binoptype {
( $lhs :expr , $rhs :expr , $out :expr ) = > { TFunc ( Box ::new ( $lhs ) , Box ::new ( TFunc ( Box ::new ( $rhs ) , Box ::new ( $out ) ) ) ) } ;
}
Ok ( match ( * op . 0 ) . as_ref ( ) {
" + " = > binoptype! ( TConst ( Integer ) , TConst ( Integer ) , TConst ( Integer ) ) ,
" ++ " = > binoptype! ( TConst ( StringT ) , TConst ( StringT ) , TConst ( StringT ) ) ,
" - " = > binoptype! ( TConst ( Integer ) , TConst ( Integer ) , TConst ( Integer ) ) ,
" * " = > binoptype! ( TConst ( Integer ) , TConst ( Integer ) , TConst ( Integer ) ) ,
" / " = > binoptype! ( TConst ( Integer ) , TConst ( Integer ) , TConst ( Integer ) ) ,
" % " = > binoptype! ( TConst ( Integer ) , TConst ( Integer ) , TConst ( Integer ) ) ,
_ = > TConst ( Bottom )
} )
}
fn unify ( & mut self , t1 : Type , t2 : Type ) -> TypeCheckResult {
use self ::Type ::* ;
2017-10-11 02:11:12 -07:00
use self ::TypeVar ::* ;
2017-10-10 04:38:59 -07:00
2017-10-12 23:59:52 -07:00
println! ( " Calling unify with ` {:?} ` and ` {:?} ` " , t1 , t2 ) ;
2017-10-10 04:38:59 -07:00
match ( & t1 , & t2 ) {
( & TConst ( ref c1 ) , & TConst ( ref c2 ) ) if c1 = = c2 = > Ok ( TConst ( c1 . clone ( ) ) ) ,
2017-10-11 02:03:50 -07:00
( & TFunc ( ref t1 , ref t2 ) , & TFunc ( ref t3 , ref t4 ) ) = > {
let t5 = self . unify ( * t1 . clone ( ) . clone ( ) , * t3 . clone ( ) . clone ( ) ) ? ;
let t6 = self . unify ( * t2 . clone ( ) . clone ( ) , * t4 . clone ( ) . clone ( ) ) ? ;
Ok ( TFunc ( Box ::new ( t5 ) , Box ::new ( t6 ) ) )
} ,
2017-10-11 02:11:12 -07:00
( & TVar ( Univ ( ref a ) ) , & TVar ( Univ ( ref b ) ) ) = > {
if a = = b {
Ok ( TVar ( Univ ( a . clone ( ) ) ) )
} else {
Err ( format! ( " Couldn't unify universal types {} and {} " , a , b ) )
}
} ,
2017-10-11 02:33:46 -07:00
//the interesting case!!
( & TVar ( Exist ( ref a ) ) , ref t2 ) = > {
let x = self . evar_table . get ( a ) . map ( | x | x . clone ( ) ) ;
match x {
Some ( ref t1 ) = > self . unify ( t1 . clone ( ) . clone ( ) , t2 . clone ( ) . clone ( ) ) ,
None = > {
self . evar_table . insert ( * a , t2 . clone ( ) . clone ( ) ) ;
Ok ( t2 . clone ( ) . clone ( ) )
}
}
} ,
( ref t1 , & TVar ( Exist ( ref a ) ) ) = > {
let x = self . evar_table . get ( a ) . map ( | x | x . clone ( ) ) ;
match x {
Some ( ref t2 ) = > self . unify ( t2 . clone ( ) . clone ( ) , t1 . clone ( ) . clone ( ) ) ,
None = > {
self . evar_table . insert ( * a , t1 . clone ( ) . clone ( ) ) ;
Ok ( t1 . clone ( ) . clone ( ) )
}
2017-10-11 02:11:12 -07:00
}
} ,
2017-10-10 04:38:59 -07:00
_ = > Err ( format! ( " Types {:?} and {:?} don't unify " , t1 , t2 ) )
2017-10-08 23:33:53 -07:00
}
}
}
2017-10-08 22:48:10 -07:00
2017-10-09 12:26:25 -07:00
#[cfg(test)]
mod tests {
    use super::{Type, TypeVar, TypeConst, TypeContext};
    use super::Type::*;
    use super::TypeConst::*;
    use schala_lang::parsing::{parse, tokenize};

    /// Parses `$source`, registers its declarations in a fresh context, and
    /// asserts that type-checking the program yields `$expected`.
    macro_rules! type_test {
        ($source:expr, $expected:expr) => {
            {
                let mut context = TypeContext::new();
                let tokens = tokenize($source);
                let ast = parse(tokens).0.unwrap();
                context.add_symbols(&ast);
                let actual = context.type_check(&ast).unwrap();
                assert_eq!($expected, actual)
            }
        }
    }

    #[test]
    fn basic_inference() {
        type_test!(" 30 ", TConst(Integer));
        type_test!(" fn x(a: Int): Bool {}; x(1) ", TConst(Boolean));
    }
}