diff --git a/schala-lang/language/src/reduced_ir/mod.rs b/schala-lang/language/src/reduced_ir/mod.rs index 84ae351..b2923da 100644 --- a/schala-lang/language/src/reduced_ir/mod.rs +++ b/schala-lang/language/src/reduced_ir/mod.rs @@ -1,12 +1,13 @@ -use crate::ast; -use crate::symbol_table::{DefId, SymbolSpec, SymbolTable}; -use crate::builtin::Builtin; +use std::{collections::HashMap, str::FromStr}; -use std::str::FromStr; -use std::collections::HashMap; +use crate::{ + ast, + builtin::Builtin, + symbol_table::{DefId, SymbolSpec, SymbolTable}, +}; -mod types; mod test; +mod types; pub use types::*; @@ -22,10 +23,7 @@ struct Reducer<'a> { impl<'a> Reducer<'a> { fn new(symbol_table: &'a SymbolTable) -> Self { - Self { - symbol_table, - functions: HashMap::new(), - } + Self { symbol_table, functions: HashMap::new() } } fn reduce(mut self, ast: &ast::AST) -> ReducedIR { @@ -43,20 +41,26 @@ impl<'a> Reducer<'a> { match &kind { ast::StatementKind::Expression(expr) => { entrypoint.push(Statement::Expression(self.expression(expr))); - }, - ast::StatementKind::Declaration(ast::Declaration::Binding { name: _, constant, expr, ..}) => { + } + ast::StatementKind::Declaration(ast::Declaration::Binding { + name: _, + constant, + expr, + .. + }) => { let symbol = self.symbol_table.lookup_symbol(item_id).unwrap(); let def_id = symbol.def_id().unwrap(); - entrypoint.push(Statement::Binding { id: def_id, constant: *constant, expr: self.expression(expr) }); - }, - _ => () + entrypoint.push(Statement::Binding { + id: def_id, + constant: *constant, + expr: self.expression(expr), + }); + } + _ => (), } } - ReducedIR { - functions: self.functions, - entrypoint, - } + ReducedIR { functions: self.functions, entrypoint } } fn top_level_statement(&mut self, statement: &ast::Statement) { @@ -65,12 +69,11 @@ impl<'a> Reducer<'a> { ast::StatementKind::Expression(_expr) => { //TODO expressions can in principle contain definitions, but I won't worry //about it now - }, - ast::StatementKind::Declaration(decl) => { + } + ast::StatementKind::Declaration(decl) => if let ast::Declaration::FuncDecl(_, statements) = decl { self.insert_function_definition(item_id, statements); - } - }, + }, ast::StatementKind::Import(..) => (), ast::StatementKind::Module(_modspec) => { //TODO handle modules @@ -81,32 +84,28 @@ impl<'a> Reducer<'a> { fn function_internal_statement(&mut self, statement: &ast::Statement) -> Option { let ast::Statement { id: item_id, kind, .. } = statement; match kind { - ast::StatementKind::Expression(expr) => { - Some(Statement::Expression(self.expression(expr))) - }, + ast::StatementKind::Expression(expr) => Some(Statement::Expression(self.expression(expr))), ast::StatementKind::Declaration(decl) => match decl { ast::Declaration::FuncDecl(_, statements) => { self.insert_function_definition(item_id, statements); None - }, - ast::Declaration::Binding { constant, expr, ..} => { + } + ast::Declaration::Binding { constant, expr, .. } => { let symbol = self.symbol_table.lookup_symbol(item_id).unwrap(); let def_id = symbol.def_id().unwrap(); Some(Statement::Binding { id: def_id, constant: *constant, expr: self.expression(expr) }) - }, + } - _ => None + _ => None, }, - _ => None + _ => None, } } fn insert_function_definition(&mut self, item_id: &ast::ItemId, statements: &ast::Block) { let symbol = self.symbol_table.lookup_symbol(item_id).unwrap(); let def_id = symbol.def_id().unwrap(); - let function_def = FunctionDefinition { - body: self.function_internal_block(statements) - }; + let function_def = FunctionDefinition { body: self.function_internal_block(statements) }; self.functions.insert(def_id, function_def); } @@ -118,31 +117,25 @@ impl<'a> Reducer<'a> { FloatLiteral(f) => Expression::Literal(Literal::Float(*f)), StringLiteral(s) => Expression::Literal(Literal::StringLit(s.clone())), BoolLiteral(b) => Expression::Literal(Literal::Bool(*b)), - BinExp(binop, lhs, rhs) => self.binop(binop, lhs, rhs), + BinExp(binop, lhs, rhs) => self.binop(binop, lhs, rhs), PrefixExp(op, arg) => self.prefix(op, arg), Value(qualified_name) => self.value(qualified_name), - Call { f, arguments } => Expression::Call { + Call { f, arguments } => Expression::Call { f: Box::new(self.expression(f)), - args: arguments - .iter() - .map(|arg| self.invocation_argument(arg)) - .collect(), + args: arguments.iter().map(|arg| self.invocation_argument(arg)).collect(), }, TupleLiteral(exprs) => Expression::Tuple(exprs.iter().map(|e| self.expression(e)).collect()), - IfExpression { discriminator, body, } => self.reduce_if_expression(discriminator.as_ref().map(|x| x.as_ref()), body), - Lambda { params, body, .. } => { - Expression::Callable(Callable::Lambda { - arity: params.len() as u8, - body: self.function_internal_block(body), - }) - }, + IfExpression { discriminator, body } => + self.reduce_if_expression(discriminator.as_ref().map(|x| x.as_ref()), body), + Lambda { params, body, .. } => Expression::Callable(Callable::Lambda { + arity: params.len() as u8, + body: self.function_internal_block(body), + }), NamedStruct { name, fields } => { let symbol = self.symbol_table.lookup_symbol(&name.id).unwrap(); let constructor = match symbol.spec() { - SymbolSpec::RecordConstructor { tag, members: _, type_id } => Expression::Callable(Callable::RecordConstructor { - type_id, - tag, - }), + SymbolSpec::RecordConstructor { tag, members: _, type_id } => + Expression::Callable(Callable::RecordConstructor { type_id, tag }), e => return Expression::ReductionError(format!("Bad symbol for NamedStruct: {:?}", e)), }; @@ -153,11 +146,8 @@ impl<'a> Reducer<'a> { unimplemented!() } - Expression::Call { - f: Box::new(constructor), - args: ordered_args, - } - }, + Expression::Call { f: Box::new(constructor), args: ordered_args } + } Index { .. } => Expression::ReductionError("Index expr not implemented".to_string()), WhileExpression { .. } => Expression::ReductionError("While expr not implemented".to_string()), ForExpression { .. } => Expression::ReductionError("For expr not implemented".to_string()), @@ -165,34 +155,27 @@ impl<'a> Reducer<'a> { } } - fn reduce_if_expression(&mut self, discriminator: Option<&ast::Expression>, body: &ast::IfExpressionBody) -> Expression { - use ast::IfExpressionBody::*; + fn reduce_if_expression( + &mut self, + discriminator: Option<&ast::Expression>, + body: &ast::IfExpressionBody, + ) -> Expression { + use ast::IfExpressionBody::*; let cond = Box::new(match discriminator { Some(expr) => self.expression(expr), None => return Expression::ReductionError("blank cond if-expr not supported".to_string()), }); match body { - SimpleConditional { - then_case, - else_case, - } => { + SimpleConditional { then_case, else_case } => { let then_clause = self.function_internal_block(then_case); let else_clause = match else_case.as_ref() { None => vec![], Some(stmts) => self.function_internal_block(stmts), }; - Expression::Conditional { - cond, - then_clause, - else_clause, - } - }, - SimplePatternMatch { - pattern, - then_case, - else_case, - } => { + Expression::Conditional { cond, then_clause, else_clause } + } + SimplePatternMatch { pattern, then_case, else_case } => { let alternatives = vec![ Alternative { pattern: match pattern.reduce(self.symbol_table) { @@ -211,24 +194,28 @@ impl<'a> Reducer<'a> { ]; Expression::CaseMatch { cond, alternatives } - }, + } CondList(ref condition_arms) => { let mut alternatives = vec![]; for arm in condition_arms { match arm.condition { - ast::Condition::Expression(ref _expr) => return Expression::ReductionError("case-expression".to_string()), + ast::Condition::Expression(ref _expr) => + return Expression::ReductionError("case-expression".to_string()), ast::Condition::Pattern(ref pat) => { let alt = Alternative { pattern: match pat.reduce(self.symbol_table) { Ok(p) => p, - Err(e) => return Expression::ReductionError(format!("Bad pattern: {:?}", e)), + Err(e) => + return Expression::ReductionError(format!("Bad pattern: {:?}", e)), }, item: self.function_internal_block(&arm.body), }; alternatives.push(alt); - }, - ast::Condition::TruncatedOp(_, _) => return Expression::ReductionError("case-expression-trunc-op".to_string()), - ast::Condition::Else => return Expression::ReductionError("case-expression-else".to_string()), + } + ast::Condition::TruncatedOp(_, _) => + return Expression::ReductionError("case-expression-trunc-op".to_string()), + ast::Condition::Else => + return Expression::ReductionError("case-expression-else".to_string()), } } Expression::CaseMatch { cond, alternatives } @@ -241,7 +228,7 @@ impl<'a> Reducer<'a> { match invoc { Positional(ex) => self.expression(ex), Keyword { .. } => Expression::ReductionError("Keyword arguments not supported".to_string()), - Ignored => Expression::ReductionError("Ignored arguments not supported".to_string()), + Ignored => Expression::ReductionError("Ignored arguments not supported".to_string()), } } @@ -252,12 +239,10 @@ impl<'a> Reducer<'a> { fn prefix(&mut self, prefix: &ast::PrefixOp, arg: &ast::Expression) -> Expression { let builtin: Option = TryFrom::try_from(prefix).ok(); match builtin { - Some(op) => { - Expression::Call { - f: Box::new(Expression::Callable(Callable::Builtin(op))), - args: vec![self.expression(arg)], - } - } + Some(op) => Expression::Call { + f: Box::new(Expression::Callable(Callable::Builtin(op))), + args: vec![self.expression(arg)], + }, None => { //TODO need this for custom prefix ops Expression::ReductionError("User-defined prefix ops not supported".to_string()) @@ -278,21 +263,18 @@ impl<'a> Reducer<'a> { } else { return ReductionError(format!("Couldn't look up name: {:?}", qualified_name)); } - }, + } _ => return ReductionError("Trying to assign to a non-name".to_string()), }; - Expression::Assign { - lval, - rval: Box::new(self.expression(rhs)), - } - }, + Expression::Assign { lval, rval: Box::new(self.expression(rhs)) } + } Some(op) => Expression::Call { f: Box::new(Expression::Callable(Callable::Builtin(op))), args: vec![self.expression(lhs), self.expression(rhs)], }, //TODO handle a user-defined operation - None => ReductionError("User-defined operations not supported".to_string()) + None => ReductionError("User-defined operations not supported".to_string()), } } @@ -301,7 +283,8 @@ impl<'a> Reducer<'a> { let symbol = match self.symbol_table.lookup_symbol(&qualified_name.id) { Some(s) => s, - None => return Expression::ReductionError(format!("No symbol found for name: {:?}", qualified_name)) + None => + return Expression::ReductionError(format!("No symbol found for name: {:?}", qualified_name)), }; let def_id = symbol.def_id(); @@ -311,14 +294,12 @@ impl<'a> Reducer<'a> { GlobalBinding => Expression::Lookup(Lookup::GlobalVar(def_id.unwrap())), LocalVariable => Expression::Lookup(Lookup::LocalVar(def_id.unwrap())), FunctionParam(n) => Expression::Lookup(Lookup::Param(n)), - DataConstructor { tag, arity, type_id } => Expression::Callable(Callable::DataConstructor { - type_id, - arity: arity as u32, - tag, - }), - RecordConstructor { .. } => { - Expression::ReductionError(format!("The symbol for value {:?} is unexpectdly a RecordConstructor", qualified_name)) - }, + DataConstructor { tag, type_id } => + Expression::Callable(Callable::DataConstructor { type_id, tag }), + RecordConstructor { .. } => Expression::ReductionError(format!( + "The symbol for value {:?} is unexpectdly a RecordConstructor", + qualified_name + )), } } } @@ -328,60 +309,53 @@ impl ast::Pattern { Ok(match self { ast::Pattern::Ignored => Pattern::Ignored, ast::Pattern::TuplePattern(subpatterns) => { - let items: Result, PatternError> = subpatterns.iter() - .map(|pat| pat.reduce(symbol_table)).into_iter().collect(); + let items: Result, PatternError> = + subpatterns.iter().map(|pat| pat.reduce(symbol_table)).into_iter().collect(); let items = items?; - Pattern::Tuple { - tag: None, - subpatterns: items, - } - }, + Pattern::Tuple { tag: None, subpatterns: items } + } ast::Pattern::Literal(lit) => Pattern::Literal(match lit { - ast::PatternLiteral::NumPattern { neg, num } => match (neg, num) { + ast::PatternLiteral::NumPattern { neg, num } => match (neg, num) { (false, ast::ExpressionKind::NatLiteral(n)) => Literal::Nat(*n), (false, ast::ExpressionKind::FloatLiteral(f)) => Literal::Float(*f), (true, ast::ExpressionKind::NatLiteral(n)) => Literal::Int(-(*n as i64)), (true, ast::ExpressionKind::FloatLiteral(f)) => Literal::Float(-f), - (_, e) => return Err(format!("Internal error, unexpected pattern literal: {:?}", e).into()) + (_, e) => + return Err(format!("Internal error, unexpected pattern literal: {:?}", e).into()), }, ast::PatternLiteral::StringPattern(s) => Literal::StringLit(s.clone()), ast::PatternLiteral::BoolPattern(b) => Literal::Bool(*b), }), ast::Pattern::TupleStruct(name, subpatterns) => { let symbol = symbol_table.lookup_symbol(&name.id).unwrap(); - if let SymbolSpec::DataConstructor { tag, type_id: _, arity: _ } = symbol.spec() { - let items: Result, PatternError> = subpatterns.iter().map(|pat| pat.reduce(symbol_table)) - .into_iter().collect(); + if let SymbolSpec::DataConstructor { tag, type_id: _ } = symbol.spec() { + let items: Result, PatternError> = + subpatterns.iter().map(|pat| pat.reduce(symbol_table)).into_iter().collect(); let items = items?; - Pattern::Tuple { - tag: Some(tag), - subpatterns: items, - } + Pattern::Tuple { tag: Some(tag), subpatterns: items } } else { - return Err("Internal error, trying to match something that's not a DataConstructor".into()); + return Err( + "Internal error, trying to match something that's not a DataConstructor".into() + ); } - }, + } ast::Pattern::VarOrName(name) => { let symbol = symbol_table.lookup_symbol(&name.id).unwrap(); match symbol.spec() { - SymbolSpec::DataConstructor { tag, type_id: _, arity: _ } => { - Pattern::Tuple { - tag: Some(tag), - subpatterns: vec![] - } - }, + SymbolSpec::DataConstructor { tag, type_id: _ } => + Pattern::Tuple { tag: Some(tag), subpatterns: vec![] }, SymbolSpec::LocalVariable => { let def_id = symbol.def_id().unwrap(); Pattern::Binding(def_id) - }, - spec => return Err(format!("Unexpected VarOrName symbol: {:?}", spec).into()) + } + spec => return Err(format!("Unexpected VarOrName symbol: {:?}", spec).into()), } - }, - ast::Pattern::Record(name, _specified_members/*Vec<(Rc, Pattern)>*/) => { + } + ast::Pattern::Record(name, _specified_members /*Vec<(Rc, Pattern)>*/) => { let symbol = symbol_table.lookup_symbol(&name.id).unwrap(); match symbol.spec() { SymbolSpec::RecordConstructor { tag: _, members: _, type_id: _ } => unimplemented!(), - spec => return Err(format!("Unexpected Record pattern symbol: {:?}", spec).into()) + spec => return Err(format!("Unexpected Record pattern symbol: {:?}", spec).into()), } } }) diff --git a/schala-lang/language/src/reduced_ir/types.rs b/schala-lang/language/src/reduced_ir/types.rs index 807b10c..5ef73aa 100644 --- a/schala-lang/language/src/reduced_ir/types.rs +++ b/schala-lang/language/src/reduced_ir/types.rs @@ -1,10 +1,10 @@ -use std::collections::HashMap; -use std::rc::Rc; -use std::convert::From; +use std::{collections::HashMap, convert::From, rc::Rc}; -use crate::builtin::Builtin; -use crate::symbol_table::{DefId, SymbolTable}; -use crate::type_inference::TypeId; +use crate::{ + builtin::Builtin, + symbol_table::{DefId, SymbolTable}, + type_inference::TypeId, +}; //TODO most of these Clone impls only exist to support function application, because the //tree-walking evaluator moves the reduced IR members. @@ -41,11 +41,7 @@ impl ReducedIR { #[derive(Debug, Clone)] pub enum Statement { Expression(Expression), - Binding { - id: DefId, - constant: bool, - expr: Expression - }, + Binding { id: DefId, constant: bool, expr: Expression }, } #[derive(Debug, Clone)] @@ -53,24 +49,11 @@ pub enum Expression { Literal(Literal), Tuple(Vec), Lookup(Lookup), - Assign { - lval: DefId, - rval: Box, - }, + Assign { lval: DefId, rval: Box }, Callable(Callable), - Call { - f: Box, - args: Vec - }, - Conditional { - cond: Box, - then_clause: Vec, - else_clause: Vec, - }, - CaseMatch { - cond: Box, - alternatives: Vec, - }, + Call { f: Box, args: Vec }, + Conditional { cond: Box, then_clause: Vec, else_clause: Vec }, + CaseMatch { cond: Box, alternatives: Vec }, ReductionError(String), } @@ -82,26 +65,16 @@ impl Expression { #[derive(Debug)] pub struct FunctionDefinition { - pub body: Vec + pub body: Vec, } #[derive(Debug, Clone)] pub enum Callable { Builtin(Builtin), UserDefined(DefId), - Lambda { - arity: u8, - body: Vec - }, - DataConstructor { - type_id: TypeId, - arity: u32, - tag: u32 - }, - RecordConstructor { - type_id: TypeId, - tag: u32, - }, + Lambda { arity: u8, body: Vec }, + DataConstructor { type_id: TypeId, tag: u32 }, + RecordConstructor { type_id: TypeId, tag: u32 }, } #[derive(Debug, Clone)] @@ -129,13 +102,10 @@ pub struct Alternative { #[derive(Debug, Clone)] pub enum Pattern { - Tuple { - subpatterns: Vec, - tag: Option, - }, + Tuple { subpatterns: Vec, tag: Option }, Literal(Literal), Ignored, - Binding(DefId) + Binding(DefId), } #[allow(dead_code)] diff --git a/schala-lang/language/src/symbol_table/mod.rs b/schala-lang/language/src/symbol_table/mod.rs index 9089840..4e9a948 100644 --- a/schala-lang/language/src/symbol_table/mod.rs +++ b/schala-lang/language/src/symbol_table/mod.rs @@ -257,9 +257,7 @@ impl fmt::Display for Symbol { #[derive(Debug, Clone)] pub enum SymbolSpec { Func, - // The tag and arity here are *surface* tags, computed from the order in which they were - // defined. The type context may create a different ordering. - DataConstructor { tag: u32, arity: usize, type_id: TypeId }, + DataConstructor { tag: u32, type_id: TypeId }, RecordConstructor { tag: u32, members: HashMap, TypeId>, type_id: TypeId }, GlobalBinding, //Only for global variables, not for function-local ones or ones within a `let` scope context LocalVariable, @@ -271,8 +269,7 @@ impl fmt::Display for SymbolSpec { use self::SymbolSpec::*; match self { Func => write!(f, "Func"), - DataConstructor { tag, type_id, arity } => - write!(f, "DataConstructor(tag: {}, arity: {}, type: {})", tag, arity, type_id), + DataConstructor { tag, type_id } => write!(f, "DataConstructor(tag: {}, type: {})", tag, type_id), RecordConstructor { type_id, tag, .. } => write!(f, "RecordConstructor(tag: {})( -> {})", tag, type_id), GlobalBinding => write!(f, "GlobalBinding"), @@ -484,10 +481,8 @@ impl<'a> SymbolTableRunner<'a> { let id = fqsn_id_map.get(&fqsn).unwrap(); let tag = index as u32; let spec = match &variant.members { - type_inference::VariantMembers::Unit => - SymbolSpec::DataConstructor { tag, arity: 0, type_id }, - type_inference::VariantMembers::Tuple(items) => - SymbolSpec::DataConstructor { tag, arity: items.len(), type_id }, + type_inference::VariantMembers::Unit => SymbolSpec::DataConstructor { tag, type_id }, + type_inference::VariantMembers::Tuple(..) => SymbolSpec::DataConstructor { tag, type_id }, type_inference::VariantMembers::Record(..) => SymbolSpec::RecordConstructor { tag, members: HashMap::new(), type_id }, }; diff --git a/schala-lang/language/src/tree_walk_eval/evaluator.rs b/schala-lang/language/src/tree_walk_eval/evaluator.rs index 20c5a15..b1e0a29 100644 --- a/schala-lang/language/src/tree_walk_eval/evaluator.rs +++ b/schala-lang/language/src/tree_walk_eval/evaluator.rs @@ -106,8 +106,14 @@ impl<'a, 'b> Evaluator<'a, 'b> { Primitive::unit() } Expression::Call { box f, args } => self.call_expression(f, args)?, - Expression::Callable(Callable::DataConstructor { type_id, arity, tag }) if arity == 0 => - Primitive::Object { type_id, tag, items: vec![] }, + Expression::Callable(Callable::DataConstructor { type_id, tag }) => { + let arity = self.type_context.lookup_variant_arity(&type_id, tag).unwrap(); + if arity == 0 { + Primitive::Object { type_id, tag, items: vec![] } + } else { + Primitive::Callable(Callable::DataConstructor { type_id, tag }) + } + } Expression::Callable(func) => Primitive::Callable(func), Expression::Conditional { box cond, then_clause, else_clause } => { let cond = self.expression(cond)?; @@ -205,7 +211,8 @@ impl<'a, 'b> Evaluator<'a, 'b> { } self.apply_function(body, args) } - Callable::DataConstructor { type_id, arity, tag } => { + Callable::DataConstructor { type_id, tag } => { + let arity = self.type_context.lookup_variant_arity(&type_id, tag).unwrap(); if arity as usize != args.len() { return Err(format!( "Constructor expression requries {} arguments, only {} provided", diff --git a/schala-lang/language/src/type_inference/mod.rs b/schala-lang/language/src/type_inference/mod.rs index ad04b82..9a8dd0a 100644 --- a/schala-lang/language/src/type_inference/mod.rs +++ b/schala-lang/language/src/type_inference/mod.rs @@ -31,12 +31,10 @@ impl TypeContext { let record_variant = matches!(members.get(0).unwrap(), VariantMemberBuilder::KeyVal(..)); if record_variant { - let pending_members = members - .into_iter() - .map(|var| match var { - VariantMemberBuilder::KeyVal(name, ty) => (name, ty), - _ => panic!("Compiler internal error: variant mismatch"), - }); + let pending_members = members.into_iter().map(|var| match var { + VariantMemberBuilder::KeyVal(name, ty) => (name, ty), + _ => panic!("Compiler internal error: variant mismatch"), + }); //TODO make this mapping meaningful let type_ids = pending_members @@ -46,12 +44,10 @@ impl TypeContext { pending_variants .push(Variant { name: variant_builder.name, members: VariantMembers::Record(type_ids) }); } else { - let pending_members = members - .into_iter() - .map(|var| match var { - VariantMemberBuilder::Pending(pending_type) => pending_type, - _ => panic!("Compiler internal error: variant mismatch"), - }); + let pending_members = members.into_iter().map(|var| match var { + VariantMemberBuilder::Pending(pending_type) => pending_type, + _ => panic!("Compiler internal error: variant mismatch"), + }); //TODO make this mapping meaningful let type_ids = pending_members.into_iter().map(|_ty_id| self.type_id_store.fresh()).collect(); @@ -78,6 +74,16 @@ impl TypeContext { .map(|variant| variant.name.as_ref()) } + pub fn lookup_variant_arity(&self, type_id: &TypeId, tag: u32) -> Option { + self.defined_types.get(type_id).and_then(|defined| defined.variants.get(tag as usize)).map( + |variant| match &variant.members { + VariantMembers::Unit => 0, + VariantMembers::Tuple(items) => items.len() as u32, + VariantMembers::Record(items) => items.len() as u32, + }, + ) + } + pub fn lookup_type(&self, type_id: &TypeId) -> Option<&DefinedType> { self.defined_types.get(type_id) }