Remove arity from ReducedIR, symbol table

Instead look this up via the type context
This commit is contained in:
Greg Shuflin 2021-10-29 19:00:27 -07:00
parent 6b9ca92e00
commit 304df5c50e
5 changed files with 154 additions and 202 deletions

View File

@ -1,12 +1,13 @@
use crate::ast;
use crate::symbol_table::{DefId, SymbolSpec, SymbolTable};
use crate::builtin::Builtin;
use std::{collections::HashMap, str::FromStr};
use std::str::FromStr;
use std::collections::HashMap;
use crate::{
ast,
builtin::Builtin,
symbol_table::{DefId, SymbolSpec, SymbolTable},
};
mod types;
mod test;
mod types;
pub use types::*;
@ -22,10 +23,7 @@ struct Reducer<'a> {
impl<'a> Reducer<'a> {
fn new(symbol_table: &'a SymbolTable) -> Self {
Self {
symbol_table,
functions: HashMap::new(),
}
Self { symbol_table, functions: HashMap::new() }
}
fn reduce(mut self, ast: &ast::AST) -> ReducedIR {
@ -43,20 +41,26 @@ impl<'a> Reducer<'a> {
match &kind {
ast::StatementKind::Expression(expr) => {
entrypoint.push(Statement::Expression(self.expression(expr)));
},
ast::StatementKind::Declaration(ast::Declaration::Binding { name: _, constant, expr, ..}) => {
}
ast::StatementKind::Declaration(ast::Declaration::Binding {
name: _,
constant,
expr,
..
}) => {
let symbol = self.symbol_table.lookup_symbol(item_id).unwrap();
let def_id = symbol.def_id().unwrap();
entrypoint.push(Statement::Binding { id: def_id, constant: *constant, expr: self.expression(expr) });
},
_ => ()
entrypoint.push(Statement::Binding {
id: def_id,
constant: *constant,
expr: self.expression(expr),
});
}
_ => (),
}
}
ReducedIR {
functions: self.functions,
entrypoint,
}
ReducedIR { functions: self.functions, entrypoint }
}
fn top_level_statement(&mut self, statement: &ast::Statement) {
@ -65,12 +69,11 @@ impl<'a> Reducer<'a> {
ast::StatementKind::Expression(_expr) => {
//TODO expressions can in principle contain definitions, but I won't worry
//about it now
},
ast::StatementKind::Declaration(decl) => {
}
ast::StatementKind::Declaration(decl) =>
if let ast::Declaration::FuncDecl(_, statements) = decl {
self.insert_function_definition(item_id, statements);
}
},
},
ast::StatementKind::Import(..) => (),
ast::StatementKind::Module(_modspec) => {
//TODO handle modules
@ -81,32 +84,28 @@ impl<'a> Reducer<'a> {
fn function_internal_statement(&mut self, statement: &ast::Statement) -> Option<Statement> {
let ast::Statement { id: item_id, kind, .. } = statement;
match kind {
ast::StatementKind::Expression(expr) => {
Some(Statement::Expression(self.expression(expr)))
},
ast::StatementKind::Expression(expr) => Some(Statement::Expression(self.expression(expr))),
ast::StatementKind::Declaration(decl) => match decl {
ast::Declaration::FuncDecl(_, statements) => {
self.insert_function_definition(item_id, statements);
None
},
ast::Declaration::Binding { constant, expr, ..} => {
}
ast::Declaration::Binding { constant, expr, .. } => {
let symbol = self.symbol_table.lookup_symbol(item_id).unwrap();
let def_id = symbol.def_id().unwrap();
Some(Statement::Binding { id: def_id, constant: *constant, expr: self.expression(expr) })
},
}
_ => None
_ => None,
},
_ => None
_ => None,
}
}
fn insert_function_definition(&mut self, item_id: &ast::ItemId, statements: &ast::Block) {
let symbol = self.symbol_table.lookup_symbol(item_id).unwrap();
let def_id = symbol.def_id().unwrap();
let function_def = FunctionDefinition {
body: self.function_internal_block(statements)
};
let function_def = FunctionDefinition { body: self.function_internal_block(statements) };
self.functions.insert(def_id, function_def);
}
@ -118,31 +117,25 @@ impl<'a> Reducer<'a> {
FloatLiteral(f) => Expression::Literal(Literal::Float(*f)),
StringLiteral(s) => Expression::Literal(Literal::StringLit(s.clone())),
BoolLiteral(b) => Expression::Literal(Literal::Bool(*b)),
BinExp(binop, lhs, rhs) => self.binop(binop, lhs, rhs),
BinExp(binop, lhs, rhs) => self.binop(binop, lhs, rhs),
PrefixExp(op, arg) => self.prefix(op, arg),
Value(qualified_name) => self.value(qualified_name),
Call { f, arguments } => Expression::Call {
Call { f, arguments } => Expression::Call {
f: Box::new(self.expression(f)),
args: arguments
.iter()
.map(|arg| self.invocation_argument(arg))
.collect(),
args: arguments.iter().map(|arg| self.invocation_argument(arg)).collect(),
},
TupleLiteral(exprs) => Expression::Tuple(exprs.iter().map(|e| self.expression(e)).collect()),
IfExpression { discriminator, body, } => self.reduce_if_expression(discriminator.as_ref().map(|x| x.as_ref()), body),
Lambda { params, body, .. } => {
Expression::Callable(Callable::Lambda {
arity: params.len() as u8,
body: self.function_internal_block(body),
})
},
IfExpression { discriminator, body } =>
self.reduce_if_expression(discriminator.as_ref().map(|x| x.as_ref()), body),
Lambda { params, body, .. } => Expression::Callable(Callable::Lambda {
arity: params.len() as u8,
body: self.function_internal_block(body),
}),
NamedStruct { name, fields } => {
let symbol = self.symbol_table.lookup_symbol(&name.id).unwrap();
let constructor = match symbol.spec() {
SymbolSpec::RecordConstructor { tag, members: _, type_id } => Expression::Callable(Callable::RecordConstructor {
type_id,
tag,
}),
SymbolSpec::RecordConstructor { tag, members: _, type_id } =>
Expression::Callable(Callable::RecordConstructor { type_id, tag }),
e => return Expression::ReductionError(format!("Bad symbol for NamedStruct: {:?}", e)),
};
@ -153,11 +146,8 @@ impl<'a> Reducer<'a> {
unimplemented!()
}
Expression::Call {
f: Box::new(constructor),
args: ordered_args,
}
},
Expression::Call { f: Box::new(constructor), args: ordered_args }
}
Index { .. } => Expression::ReductionError("Index expr not implemented".to_string()),
WhileExpression { .. } => Expression::ReductionError("While expr not implemented".to_string()),
ForExpression { .. } => Expression::ReductionError("For expr not implemented".to_string()),
@ -165,34 +155,27 @@ impl<'a> Reducer<'a> {
}
}
fn reduce_if_expression(&mut self, discriminator: Option<&ast::Expression>, body: &ast::IfExpressionBody) -> Expression {
use ast::IfExpressionBody::*;
fn reduce_if_expression(
&mut self,
discriminator: Option<&ast::Expression>,
body: &ast::IfExpressionBody,
) -> Expression {
use ast::IfExpressionBody::*;
let cond = Box::new(match discriminator {
Some(expr) => self.expression(expr),
None => return Expression::ReductionError("blank cond if-expr not supported".to_string()),
});
match body {
SimpleConditional {
then_case,
else_case,
} => {
SimpleConditional { then_case, else_case } => {
let then_clause = self.function_internal_block(then_case);
let else_clause = match else_case.as_ref() {
None => vec![],
Some(stmts) => self.function_internal_block(stmts),
};
Expression::Conditional {
cond,
then_clause,
else_clause,
}
},
SimplePatternMatch {
pattern,
then_case,
else_case,
} => {
Expression::Conditional { cond, then_clause, else_clause }
}
SimplePatternMatch { pattern, then_case, else_case } => {
let alternatives = vec![
Alternative {
pattern: match pattern.reduce(self.symbol_table) {
@ -211,24 +194,28 @@ impl<'a> Reducer<'a> {
];
Expression::CaseMatch { cond, alternatives }
},
}
CondList(ref condition_arms) => {
let mut alternatives = vec![];
for arm in condition_arms {
match arm.condition {
ast::Condition::Expression(ref _expr) => return Expression::ReductionError("case-expression".to_string()),
ast::Condition::Expression(ref _expr) =>
return Expression::ReductionError("case-expression".to_string()),
ast::Condition::Pattern(ref pat) => {
let alt = Alternative {
pattern: match pat.reduce(self.symbol_table) {
Ok(p) => p,
Err(e) => return Expression::ReductionError(format!("Bad pattern: {:?}", e)),
Err(e) =>
return Expression::ReductionError(format!("Bad pattern: {:?}", e)),
},
item: self.function_internal_block(&arm.body),
};
alternatives.push(alt);
},
ast::Condition::TruncatedOp(_, _) => return Expression::ReductionError("case-expression-trunc-op".to_string()),
ast::Condition::Else => return Expression::ReductionError("case-expression-else".to_string()),
}
ast::Condition::TruncatedOp(_, _) =>
return Expression::ReductionError("case-expression-trunc-op".to_string()),
ast::Condition::Else =>
return Expression::ReductionError("case-expression-else".to_string()),
}
}
Expression::CaseMatch { cond, alternatives }
@ -241,7 +228,7 @@ impl<'a> Reducer<'a> {
match invoc {
Positional(ex) => self.expression(ex),
Keyword { .. } => Expression::ReductionError("Keyword arguments not supported".to_string()),
Ignored => Expression::ReductionError("Ignored arguments not supported".to_string()),
Ignored => Expression::ReductionError("Ignored arguments not supported".to_string()),
}
}
@ -252,12 +239,10 @@ impl<'a> Reducer<'a> {
fn prefix(&mut self, prefix: &ast::PrefixOp, arg: &ast::Expression) -> Expression {
let builtin: Option<Builtin> = TryFrom::try_from(prefix).ok();
match builtin {
Some(op) => {
Expression::Call {
f: Box::new(Expression::Callable(Callable::Builtin(op))),
args: vec![self.expression(arg)],
}
}
Some(op) => Expression::Call {
f: Box::new(Expression::Callable(Callable::Builtin(op))),
args: vec![self.expression(arg)],
},
None => {
//TODO need this for custom prefix ops
Expression::ReductionError("User-defined prefix ops not supported".to_string())
@ -278,21 +263,18 @@ impl<'a> Reducer<'a> {
} else {
return ReductionError(format!("Couldn't look up name: {:?}", qualified_name));
}
},
}
_ => return ReductionError("Trying to assign to a non-name".to_string()),
};
Expression::Assign {
lval,
rval: Box::new(self.expression(rhs)),
}
},
Expression::Assign { lval, rval: Box::new(self.expression(rhs)) }
}
Some(op) => Expression::Call {
f: Box::new(Expression::Callable(Callable::Builtin(op))),
args: vec![self.expression(lhs), self.expression(rhs)],
},
//TODO handle a user-defined operation
None => ReductionError("User-defined operations not supported".to_string())
None => ReductionError("User-defined operations not supported".to_string()),
}
}
@ -301,7 +283,8 @@ impl<'a> Reducer<'a> {
let symbol = match self.symbol_table.lookup_symbol(&qualified_name.id) {
Some(s) => s,
None => return Expression::ReductionError(format!("No symbol found for name: {:?}", qualified_name))
None =>
return Expression::ReductionError(format!("No symbol found for name: {:?}", qualified_name)),
};
let def_id = symbol.def_id();
@ -311,14 +294,12 @@ impl<'a> Reducer<'a> {
GlobalBinding => Expression::Lookup(Lookup::GlobalVar(def_id.unwrap())),
LocalVariable => Expression::Lookup(Lookup::LocalVar(def_id.unwrap())),
FunctionParam(n) => Expression::Lookup(Lookup::Param(n)),
DataConstructor { tag, arity, type_id } => Expression::Callable(Callable::DataConstructor {
type_id,
arity: arity as u32,
tag,
}),
RecordConstructor { .. } => {
Expression::ReductionError(format!("The symbol for value {:?} is unexpectdly a RecordConstructor", qualified_name))
},
DataConstructor { tag, type_id } =>
Expression::Callable(Callable::DataConstructor { type_id, tag }),
RecordConstructor { .. } => Expression::ReductionError(format!(
"The symbol for value {:?} is unexpectdly a RecordConstructor",
qualified_name
)),
}
}
}
@ -328,60 +309,53 @@ impl ast::Pattern {
Ok(match self {
ast::Pattern::Ignored => Pattern::Ignored,
ast::Pattern::TuplePattern(subpatterns) => {
let items: Result<Vec<Pattern>, PatternError> = subpatterns.iter()
.map(|pat| pat.reduce(symbol_table)).into_iter().collect();
let items: Result<Vec<Pattern>, PatternError> =
subpatterns.iter().map(|pat| pat.reduce(symbol_table)).into_iter().collect();
let items = items?;
Pattern::Tuple {
tag: None,
subpatterns: items,
}
},
Pattern::Tuple { tag: None, subpatterns: items }
}
ast::Pattern::Literal(lit) => Pattern::Literal(match lit {
ast::PatternLiteral::NumPattern { neg, num } => match (neg, num) {
ast::PatternLiteral::NumPattern { neg, num } => match (neg, num) {
(false, ast::ExpressionKind::NatLiteral(n)) => Literal::Nat(*n),
(false, ast::ExpressionKind::FloatLiteral(f)) => Literal::Float(*f),
(true, ast::ExpressionKind::NatLiteral(n)) => Literal::Int(-(*n as i64)),
(true, ast::ExpressionKind::FloatLiteral(f)) => Literal::Float(-f),
(_, e) => return Err(format!("Internal error, unexpected pattern literal: {:?}", e).into())
(_, e) =>
return Err(format!("Internal error, unexpected pattern literal: {:?}", e).into()),
},
ast::PatternLiteral::StringPattern(s) => Literal::StringLit(s.clone()),
ast::PatternLiteral::BoolPattern(b) => Literal::Bool(*b),
}),
ast::Pattern::TupleStruct(name, subpatterns) => {
let symbol = symbol_table.lookup_symbol(&name.id).unwrap();
if let SymbolSpec::DataConstructor { tag, type_id: _, arity: _ } = symbol.spec() {
let items: Result<Vec<Pattern>, PatternError> = subpatterns.iter().map(|pat| pat.reduce(symbol_table))
.into_iter().collect();
if let SymbolSpec::DataConstructor { tag, type_id: _ } = symbol.spec() {
let items: Result<Vec<Pattern>, PatternError> =
subpatterns.iter().map(|pat| pat.reduce(symbol_table)).into_iter().collect();
let items = items?;
Pattern::Tuple {
tag: Some(tag),
subpatterns: items,
}
Pattern::Tuple { tag: Some(tag), subpatterns: items }
} else {
return Err("Internal error, trying to match something that's not a DataConstructor".into());
return Err(
"Internal error, trying to match something that's not a DataConstructor".into()
);
}
},
}
ast::Pattern::VarOrName(name) => {
let symbol = symbol_table.lookup_symbol(&name.id).unwrap();
match symbol.spec() {
SymbolSpec::DataConstructor { tag, type_id: _, arity: _ } => {
Pattern::Tuple {
tag: Some(tag),
subpatterns: vec![]
}
},
SymbolSpec::DataConstructor { tag, type_id: _ } =>
Pattern::Tuple { tag: Some(tag), subpatterns: vec![] },
SymbolSpec::LocalVariable => {
let def_id = symbol.def_id().unwrap();
Pattern::Binding(def_id)
},
spec => return Err(format!("Unexpected VarOrName symbol: {:?}", spec).into())
}
spec => return Err(format!("Unexpected VarOrName symbol: {:?}", spec).into()),
}
},
ast::Pattern::Record(name, _specified_members/*Vec<(Rc<String>, Pattern)>*/) => {
}
ast::Pattern::Record(name, _specified_members /*Vec<(Rc<String>, Pattern)>*/) => {
let symbol = symbol_table.lookup_symbol(&name.id).unwrap();
match symbol.spec() {
SymbolSpec::RecordConstructor { tag: _, members: _, type_id: _ } => unimplemented!(),
spec => return Err(format!("Unexpected Record pattern symbol: {:?}", spec).into())
spec => return Err(format!("Unexpected Record pattern symbol: {:?}", spec).into()),
}
}
})

View File

@ -1,10 +1,10 @@
use std::collections::HashMap;
use std::rc::Rc;
use std::convert::From;
use std::{collections::HashMap, convert::From, rc::Rc};
use crate::builtin::Builtin;
use crate::symbol_table::{DefId, SymbolTable};
use crate::type_inference::TypeId;
use crate::{
builtin::Builtin,
symbol_table::{DefId, SymbolTable},
type_inference::TypeId,
};
//TODO most of these Clone impls only exist to support function application, because the
//tree-walking evaluator moves the reduced IR members.
@ -41,11 +41,7 @@ impl ReducedIR {
#[derive(Debug, Clone)]
pub enum Statement {
Expression(Expression),
Binding {
id: DefId,
constant: bool,
expr: Expression
},
Binding { id: DefId, constant: bool, expr: Expression },
}
#[derive(Debug, Clone)]
@ -53,24 +49,11 @@ pub enum Expression {
Literal(Literal),
Tuple(Vec<Expression>),
Lookup(Lookup),
Assign {
lval: DefId,
rval: Box<Expression>,
},
Assign { lval: DefId, rval: Box<Expression> },
Callable(Callable),
Call {
f: Box<Expression>,
args: Vec<Expression>
},
Conditional {
cond: Box<Expression>,
then_clause: Vec<Statement>,
else_clause: Vec<Statement>,
},
CaseMatch {
cond: Box<Expression>,
alternatives: Vec<Alternative>,
},
Call { f: Box<Expression>, args: Vec<Expression> },
Conditional { cond: Box<Expression>, then_clause: Vec<Statement>, else_clause: Vec<Statement> },
CaseMatch { cond: Box<Expression>, alternatives: Vec<Alternative> },
ReductionError(String),
}
@ -82,26 +65,16 @@ impl Expression {
#[derive(Debug)]
pub struct FunctionDefinition {
pub body: Vec<Statement>
pub body: Vec<Statement>,
}
#[derive(Debug, Clone)]
pub enum Callable {
Builtin(Builtin),
UserDefined(DefId),
Lambda {
arity: u8,
body: Vec<Statement>
},
DataConstructor {
type_id: TypeId,
arity: u32,
tag: u32
},
RecordConstructor {
type_id: TypeId,
tag: u32,
},
Lambda { arity: u8, body: Vec<Statement> },
DataConstructor { type_id: TypeId, tag: u32 },
RecordConstructor { type_id: TypeId, tag: u32 },
}
#[derive(Debug, Clone)]
@ -129,13 +102,10 @@ pub struct Alternative {
#[derive(Debug, Clone)]
pub enum Pattern {
Tuple {
subpatterns: Vec<Pattern>,
tag: Option<u32>,
},
Tuple { subpatterns: Vec<Pattern>, tag: Option<u32> },
Literal(Literal),
Ignored,
Binding(DefId)
Binding(DefId),
}
#[allow(dead_code)]

View File

@ -257,9 +257,7 @@ impl fmt::Display for Symbol {
#[derive(Debug, Clone)]
pub enum SymbolSpec {
Func,
// The tag and arity here are *surface* tags, computed from the order in which they were
// defined. The type context may create a different ordering.
DataConstructor { tag: u32, arity: usize, type_id: TypeId },
DataConstructor { tag: u32, type_id: TypeId },
RecordConstructor { tag: u32, members: HashMap<Rc<String>, TypeId>, type_id: TypeId },
GlobalBinding, //Only for global variables, not for function-local ones or ones within a `let` scope context
LocalVariable,
@ -271,8 +269,7 @@ impl fmt::Display for SymbolSpec {
use self::SymbolSpec::*;
match self {
Func => write!(f, "Func"),
DataConstructor { tag, type_id, arity } =>
write!(f, "DataConstructor(tag: {}, arity: {}, type: {})", tag, arity, type_id),
DataConstructor { tag, type_id } => write!(f, "DataConstructor(tag: {}, type: {})", tag, type_id),
RecordConstructor { type_id, tag, .. } =>
write!(f, "RecordConstructor(tag: {})(<members> -> {})", tag, type_id),
GlobalBinding => write!(f, "GlobalBinding"),
@ -484,10 +481,8 @@ impl<'a> SymbolTableRunner<'a> {
let id = fqsn_id_map.get(&fqsn).unwrap();
let tag = index as u32;
let spec = match &variant.members {
type_inference::VariantMembers::Unit =>
SymbolSpec::DataConstructor { tag, arity: 0, type_id },
type_inference::VariantMembers::Tuple(items) =>
SymbolSpec::DataConstructor { tag, arity: items.len(), type_id },
type_inference::VariantMembers::Unit => SymbolSpec::DataConstructor { tag, type_id },
type_inference::VariantMembers::Tuple(..) => SymbolSpec::DataConstructor { tag, type_id },
type_inference::VariantMembers::Record(..) =>
SymbolSpec::RecordConstructor { tag, members: HashMap::new(), type_id },
};

View File

@ -106,8 +106,14 @@ impl<'a, 'b> Evaluator<'a, 'b> {
Primitive::unit()
}
Expression::Call { box f, args } => self.call_expression(f, args)?,
Expression::Callable(Callable::DataConstructor { type_id, arity, tag }) if arity == 0 =>
Primitive::Object { type_id, tag, items: vec![] },
Expression::Callable(Callable::DataConstructor { type_id, tag }) => {
let arity = self.type_context.lookup_variant_arity(&type_id, tag).unwrap();
if arity == 0 {
Primitive::Object { type_id, tag, items: vec![] }
} else {
Primitive::Callable(Callable::DataConstructor { type_id, tag })
}
}
Expression::Callable(func) => Primitive::Callable(func),
Expression::Conditional { box cond, then_clause, else_clause } => {
let cond = self.expression(cond)?;
@ -205,7 +211,8 @@ impl<'a, 'b> Evaluator<'a, 'b> {
}
self.apply_function(body, args)
}
Callable::DataConstructor { type_id, arity, tag } => {
Callable::DataConstructor { type_id, tag } => {
let arity = self.type_context.lookup_variant_arity(&type_id, tag).unwrap();
if arity as usize != args.len() {
return Err(format!(
"Constructor expression requries {} arguments, only {} provided",

View File

@ -31,12 +31,10 @@ impl TypeContext {
let record_variant = matches!(members.get(0).unwrap(), VariantMemberBuilder::KeyVal(..));
if record_variant {
let pending_members = members
.into_iter()
.map(|var| match var {
VariantMemberBuilder::KeyVal(name, ty) => (name, ty),
_ => panic!("Compiler internal error: variant mismatch"),
});
let pending_members = members.into_iter().map(|var| match var {
VariantMemberBuilder::KeyVal(name, ty) => (name, ty),
_ => panic!("Compiler internal error: variant mismatch"),
});
//TODO make this mapping meaningful
let type_ids = pending_members
@ -46,12 +44,10 @@ impl TypeContext {
pending_variants
.push(Variant { name: variant_builder.name, members: VariantMembers::Record(type_ids) });
} else {
let pending_members = members
.into_iter()
.map(|var| match var {
VariantMemberBuilder::Pending(pending_type) => pending_type,
_ => panic!("Compiler internal error: variant mismatch"),
});
let pending_members = members.into_iter().map(|var| match var {
VariantMemberBuilder::Pending(pending_type) => pending_type,
_ => panic!("Compiler internal error: variant mismatch"),
});
//TODO make this mapping meaningful
let type_ids = pending_members.into_iter().map(|_ty_id| self.type_id_store.fresh()).collect();
@ -78,6 +74,16 @@ impl TypeContext {
.map(|variant| variant.name.as_ref())
}
pub fn lookup_variant_arity(&self, type_id: &TypeId, tag: u32) -> Option<u32> {
self.defined_types.get(type_id).and_then(|defined| defined.variants.get(tag as usize)).map(
|variant| match &variant.members {
VariantMembers::Unit => 0,
VariantMembers::Tuple(items) => items.len() as u32,
VariantMembers::Record(items) => items.len() as u32,
},
)
}
pub fn lookup_type(&self, type_id: &TypeId) -> Option<&DefinedType> {
self.defined_types.get(type_id)
}