schala/schala-lang/src/type_inference/mod.rs

228 lines
6.8 KiB
Rust

use std::{collections::HashMap, convert::From};
use crate::{
ast::{TypeIdentifier, AST},
identifier::{define_id_kind, Id, IdStore},
};
// Declares the `TypeItem` id kind (via the project's `define_id_kind!` macro); ids of this kind
// are handed out for types by `IdStore<TypeItem>`.
define_id_kind!(TypeItem);
/// Identifier for a type, allocated from a [`TypeContext`]'s id store.
pub type TypeId = Id<TypeItem>;
/// Registry of the types defined in program source code, keyed by the [`TypeId`] each one
/// received when it was registered.
pub struct TypeContext {
    // Registered types, keyed by the id returned from `register_type`.
    defined_types: HashMap<TypeId, DefinedType>,
    // Source of fresh `TypeId`s, used both for registered types and for (currently placeholder)
    // member type ids.
    type_id_store: IdStore<TypeItem>,
}
impl TypeContext {
pub fn new() -> Self {
Self { defined_types: HashMap::new(), type_id_store: IdStore::new() }
}
pub fn register_type(&mut self, builder: TypeBuilder) -> TypeId {
let type_id = self.type_id_store.fresh();
let mut pending_variants = vec![];
for variant_builder in builder.variants.into_iter() {
let members = variant_builder.members;
if members.is_empty() {
pending_variants.push(Variant { name: variant_builder.name, members: VariantMembers::Unit });
continue;
}
let record_variant = matches!(members.get(0).unwrap(), VariantMemberBuilder::KeyVal(..));
if record_variant {
let pending_members = members.into_iter().map(|var| match var {
VariantMemberBuilder::KeyVal(name, ty) => (name, ty),
_ => panic!("Compiler internal error: variant mismatch"),
});
//TODO make this mapping meaningful
let type_ids = pending_members
.into_iter()
.map(|(name, _ty_id)| (name, self.type_id_store.fresh()))
.collect();
pending_variants
.push(Variant { name: variant_builder.name, members: VariantMembers::Record(type_ids) });
} else {
let pending_members = members.into_iter().map(|var| match var {
VariantMemberBuilder::Pending(pending_type) => pending_type,
_ => panic!("Compiler internal error: variant mismatch"),
});
//TODO make this mapping meaningful
let type_ids = pending_members.into_iter().map(|_ty_id| self.type_id_store.fresh()).collect();
pending_variants
.push(Variant { name: variant_builder.name, members: VariantMembers::Tuple(type_ids) });
}
}
// Eventually, I will want to have a better way of determining which numeric tag goes with
// which variant. For now, just sort them alphabetically.
pending_variants.sort_unstable_by(|a, b| a.name.cmp(&b.name));
let defined = DefinedType { name: builder.name, variants: pending_variants };
self.defined_types.insert(type_id, defined);
type_id
}
pub fn variant_local_name(&self, type_id: &TypeId, tag: u32) -> Option<&str> {
self.defined_types
.get(type_id)
.and_then(|defined| defined.variants.get(tag as usize))
.map(|variant| variant.name.as_ref())
}
pub fn lookup_variant_arity(&self, type_id: &TypeId, tag: u32) -> Option<u32> {
self.defined_types.get(type_id).and_then(|defined| defined.variants.get(tag as usize)).map(
|variant| match &variant.members {
VariantMembers::Unit => 0,
VariantMembers::Tuple(items) => items.len() as u32,
VariantMembers::Record(items) => items.len() as u32,
},
)
}
pub fn lookup_record_members(&self, type_id: &TypeId, tag: u32) -> Option<&[(String, TypeId)]> {
self.defined_types.get(type_id).and_then(|defined| defined.variants.get(tag as usize)).and_then(
|variant| match &variant.members {
VariantMembers::Record(items) => Some(items.as_ref()),
_ => None,
},
)
}
pub fn lookup_type(&self, type_id: &TypeId) -> Option<&DefinedType> {
self.defined_types.get(type_id)
}
//TODO return some kind of overall type later?
pub fn typecheck(&mut self, ast: &AST) -> Result<(), TypeError> {
Ok(())
}
}
/// A type defined in program source code, as opposed to a builtin.
// NOTE(review): dead_code is allowed, presumably because not every field is read yet — confirm
// and drop the attribute once it is.
#[allow(dead_code)]
#[derive(Debug)]
pub struct DefinedType {
    /// The type's name, as given to the `TypeBuilder` it was registered from.
    pub name: String,
    /// The type's variants, in tag order: a variant's tag is its index in this list.
    pub variants: Vec<Variant>,
}
/// One variant of a [`DefinedType`].
#[derive(Debug)]
pub struct Variant {
    /// The variant's name, as given to the `VariantBuilder` it was built from.
    pub name: String,
    /// The variant's payload shape and the type ids of its members.
    pub members: VariantMembers,
}
/// The payload carried by a [`Variant`].
#[derive(Debug)]
pub enum VariantMembers {
    /// The variant has no members.
    Unit,
    /// Positional members: one type id per member, in the order they were added.
    // Should be non-empty; a variant with no members is represented as `Unit` instead.
    Tuple(Vec<TypeId>),
    /// Named members: `(name, type id)` pairs, in the order they were added.
    Record(Vec<(String, TypeId)>),
}
/// Represents a type mentioned as a member of another type during the type registration process.
/// It may not have been registered itself in the relevant context.
// NOTE(review): dead_code is allowed, presumably because `inner` is not read anywhere yet —
// confirm and drop the attribute once it is.
#[allow(dead_code)]
#[derive(Debug)]
pub struct PendingType {
    // The type as written in the source AST.
    inner: TypeIdentifier,
}
impl From<&TypeIdentifier> for PendingType {
fn from(type_identifier: &TypeIdentifier) -> Self {
Self { inner: type_identifier.clone() }
}
}
/// Accumulates the description of a user-defined type before it is registered with
/// [`TypeContext::register_type`].
#[derive(Debug)]
pub struct TypeBuilder {
    // Name of the type being described.
    name: String,
    // Variants in the order they were added; tags are assigned at registration.
    variants: Vec<VariantBuilder>,
}
impl TypeBuilder {
pub fn new(name: &str) -> Self {
Self { name: name.to_string(), variants: vec![] }
}
pub fn add_variant(&mut self, vb: VariantBuilder) {
self.variants.push(vb);
}
}
/// Accumulates the description of one variant of a type being built with a [`TypeBuilder`].
#[derive(Debug)]
pub struct VariantBuilder {
    // Name of the variant being described.
    name: String,
    // Members in the order they were added.
    members: Vec<VariantMemberBuilder>,
}
impl VariantBuilder {
pub fn new(name: &str) -> Self {
Self { name: name.to_string(), members: vec![] }
}
pub fn add_member(&mut self, member_ty: PendingType) {
self.members.push(VariantMemberBuilder::Pending(member_ty));
}
// You can't call this and `add_member` on the same fn, there should be a runtime error when
// that's detected.
pub fn add_record_member(&mut self, name: &str, ty: PendingType) {
self.members.push(VariantMemberBuilder::KeyVal(name.to_string(), ty));
}
}
/// A single member recorded by a [`VariantBuilder`].
#[derive(Debug)]
enum VariantMemberBuilder {
    /// Positional member, added by `VariantBuilder::add_member`.
    Pending(PendingType),
    /// Named member `(name, type)`, added by `VariantBuilder::add_record_member`.
    KeyVal(String, PendingType),
}
/// Error type returned by [`TypeContext::typecheck`].
#[derive(Debug, Clone)]
pub struct TypeError {
    /// Description of what went wrong.
    pub msg: String,
}
/// Types that take no type arguments; wrapped by [`Type::Const`].
// NOTE(review): dead_code is allowed, presumably because not every variant is constructed yet —
// confirm and drop the attribute once they are.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypeConst {
    Unit,
    Nat,
    Int,
    Float,
    // The string type; the `T` suffix presumably avoids confusion with Rust's `String`.
    StringT,
    Bool,
    Ordering,
}
/// A type expression.
// NOTE(review): dead_code is allowed, presumably because not every variant is constructed yet —
// confirm and drop the attribute once they are.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum Type {
    /// A type with no arguments, such as `Int` or `Bool`.
    Const(TypeConst),
    // Not yet implemented: a type-variable variant, `Var(TypeVar)`.
    /// A function-like type from `params` to `ret` (what `ty!(A -> B)` builds).
    Arrow { params: Vec<Type>, ret: Box<Type> },
    /// A named type `ty_name` together with its type arguments `args`.
    Compound { ty_name: String, args: Vec<Type> },
}
/// Shorthand for building [`Type`] values.
///
/// - `ty!(Int)` expands to `Type::Const(TypeConst::Int)`.
/// - `ty!(Int -> Bool)` expands to an arrow type with one parameter.
/// - `ty!(Int -> Int -> Bool)` expands to an arrow type with two parameters.
/// - `ty!(params, ret)` expands to an arrow type built from a `Vec<Type>` variable `params` and
///   a `Type` variable `ret`, both of which are moved.
///
/// Boxing uses `Box::new` because the unstable `box` expression syntax has been removed from the
/// compiler. Every path is fully qualified, so expansion does not depend on `Type` or `TypeConst`
/// being imported at the call site.
macro_rules! ty {
    ($type_name:ident) => {
        crate::type_inference::Type::Const(crate::type_inference::TypeConst::$type_name)
    };
    ($t1:ident -> $t2:ident) => {
        crate::type_inference::Type::Arrow {
            params: vec![ty!($t1)],
            ret: Box::new(ty!($t2)),
        }
    };
    ($t1:ident -> $t2:ident -> $t3:ident) => {
        crate::type_inference::Type::Arrow {
            params: vec![ty!($t1), ty!($t2)],
            ret: Box::new(ty!($t3)),
        }
    };
    ($type_list:ident, $ret_type:ident) => {
        crate::type_inference::Type::Arrow { params: $type_list, ret: Box::new($ret_type) }
    };
}