447 lines
20 KiB
Rust
447 lines
20 KiB
Rust
use std::{collections::HashMap, rc::Rc, str::FromStr};
|
|
|
|
use crate::{
|
|
ast,
|
|
builtin::Builtin,
|
|
symbol_table::{DefId, SymbolSpec, SymbolTable},
|
|
type_inference::{TypeContext, TypeId},
|
|
};
|
|
|
|
mod test;
|
|
mod types;
|
|
|
|
pub use types::*;
|
|
|
|
pub fn reduce(ast: &ast::AST, symbol_table: &SymbolTable, type_context: &TypeContext) -> ReducedIR {
|
|
let reducer = Reducer::new(symbol_table, type_context);
|
|
reducer.reduce(ast)
|
|
}
|
|
|
|
struct Reducer<'a, 'b> {
|
|
symbol_table: &'a SymbolTable,
|
|
functions: HashMap<DefId, FunctionDefinition>,
|
|
type_context: &'b TypeContext,
|
|
}
|
|
|
|
impl<'a, 'b> Reducer<'a, 'b> {
|
|
fn new(symbol_table: &'a SymbolTable, type_context: &'b TypeContext) -> Self {
|
|
Self { symbol_table, functions: HashMap::new(), type_context }
|
|
}
|
|
|
|
fn reduce(mut self, ast: &ast::AST) -> ReducedIR {
|
|
// First reduce all functions
|
|
// TODO once this works, maybe rewrite it using the Visitor
|
|
for statement in ast.statements.statements.iter() {
|
|
self.top_level_statement(statement);
|
|
}
|
|
|
|
// Then compute the entrypoint statements (which may reference previously-computed
|
|
// functions by ID)
|
|
let mut entrypoint = vec![];
|
|
for statement in ast.statements.statements.iter() {
|
|
let ast::Statement { id: item_id, kind, .. } = statement;
|
|
match &kind {
|
|
ast::StatementKind::Expression(expr) => {
|
|
entrypoint.push(Statement::Expression(self.expression(expr)));
|
|
}
|
|
ast::StatementKind::Declaration(ast::Declaration::Binding {
|
|
name: _,
|
|
constant,
|
|
expr,
|
|
..
|
|
}) => {
|
|
let symbol = self.symbol_table.lookup_symbol(item_id).unwrap();
|
|
let def_id = symbol.def_id().unwrap();
|
|
entrypoint.push(Statement::Binding {
|
|
id: def_id,
|
|
constant: *constant,
|
|
expr: self.expression(expr),
|
|
});
|
|
}
|
|
_ => (),
|
|
}
|
|
}
|
|
|
|
ReducedIR { functions: self.functions, entrypoint }
|
|
}
|
|
|
|
/// First-pass handler for a top-level statement: registers function declarations
/// and does nothing for every other statement kind.
fn top_level_statement(&mut self, statement: &ast::Statement) {
    match &statement.kind {
        ast::StatementKind::Declaration(ast::Declaration::FuncDecl(_, body)) => {
            self.insert_function_definition(&statement.id, body);
        }
        // Non-function declarations are not handled in this pass.
        ast::StatementKind::Declaration(_) => (),
        //TODO expressions can in principle contain definitions, but I won't worry
        //about it now
        ast::StatementKind::Expression(_) => (),
        // Imports should have already been processed by the symbol table and are irrelevant
        // for this representation.
        ast::StatementKind::Import(..) => (),
        //TODO this should be an error
        ast::StatementKind::Flow(..) => (),
    }
}
|
|
|
|
fn function_internal_statement(&mut self, statement: &ast::Statement) -> Option<Statement> {
|
|
let ast::Statement { id: item_id, kind, .. } = statement;
|
|
match kind {
|
|
ast::StatementKind::Expression(expr) => Some(Statement::Expression(self.expression(expr))),
|
|
ast::StatementKind::Declaration(decl) => match decl {
|
|
ast::Declaration::FuncDecl(_, statements) => {
|
|
self.insert_function_definition(item_id, statements);
|
|
None
|
|
}
|
|
ast::Declaration::Binding { constant, expr, .. } => {
|
|
let symbol = self.symbol_table.lookup_symbol(item_id).unwrap();
|
|
let def_id = symbol.def_id().unwrap();
|
|
Some(Statement::Binding { id: def_id, constant: *constant, expr: self.expression(expr) })
|
|
}
|
|
_ => None,
|
|
},
|
|
ast::StatementKind::Import(_) => None,
|
|
ast::StatementKind::Flow(ast::FlowControl::Return(expr)) =>
|
|
if let Some(expr) = expr {
|
|
Some(Statement::Return(self.expression(expr)))
|
|
} else {
|
|
Some(Statement::Return(Expression::unit()))
|
|
},
|
|
ast::StatementKind::Flow(ast::FlowControl::Break) => Some(Statement::Break),
|
|
ast::StatementKind::Flow(ast::FlowControl::Continue) => Some(Statement::Continue),
|
|
}
|
|
}
|
|
|
|
fn insert_function_definition(&mut self, item_id: &ast::ItemId, statements: &ast::Block) {
|
|
let symbol = self.symbol_table.lookup_symbol(item_id).unwrap();
|
|
let def_id = symbol.def_id().unwrap();
|
|
let function_def = FunctionDefinition { body: self.function_internal_block(statements) };
|
|
self.functions.insert(def_id, function_def);
|
|
}
|
|
|
|
fn expression(&mut self, expr: &ast::Expression) -> Expression {
|
|
use crate::ast::ExpressionKind::*;
|
|
|
|
match &expr.kind {
|
|
NatLiteral(n) => Expression::Literal(Literal::Nat(*n)),
|
|
FloatLiteral(f) => Expression::Literal(Literal::Float(*f)),
|
|
StringLiteral(s) => Expression::Literal(Literal::StringLit(s.clone())),
|
|
BoolLiteral(b) => Expression::Literal(Literal::Bool(*b)),
|
|
BinExp(binop, lhs, rhs) => self.binop(binop, lhs, rhs),
|
|
PrefixExp(op, arg) => self.prefix(op, arg),
|
|
Value(qualified_name) => self.value(qualified_name),
|
|
Call { f, arguments } => Expression::Call {
|
|
f: Box::new(self.expression(f)),
|
|
args: arguments.iter().map(|arg| self.invocation_argument(arg)).collect(),
|
|
},
|
|
TupleLiteral(exprs) => Expression::Tuple(exprs.iter().map(|e| self.expression(e)).collect()),
|
|
IfExpression { discriminator, body } =>
|
|
self.reduce_if_expression(discriminator.as_ref().map(|x| x.as_ref()), body),
|
|
Lambda { params, body, .. } => Expression::Callable(Callable::Lambda {
|
|
arity: params.len() as u8,
|
|
body: self.function_internal_block(body),
|
|
}),
|
|
NamedStruct { name, fields } => {
|
|
let symbol = match self.symbol_table.lookup_symbol(&name.id) {
|
|
Some(symbol) => symbol,
|
|
None => return Expression::ReductionError(format!("No symbol found for {:?}", name)),
|
|
};
|
|
let (tag, type_id) = match symbol.spec() {
|
|
SymbolSpec::RecordConstructor { tag, type_id } => (tag, type_id),
|
|
e => return Expression::ReductionError(format!("Bad symbol for NamedStruct: {:?}", e)),
|
|
};
|
|
|
|
let field_order = compute_field_orderings(self.type_context, &type_id, tag).unwrap();
|
|
|
|
let mut field_map = HashMap::new();
|
|
for (name, expr) in fields.iter() {
|
|
field_map.insert(name.as_ref(), expr);
|
|
}
|
|
|
|
let mut ordered_args = vec![];
|
|
for field in field_order.iter() {
|
|
let expr = match field_map.get(&field) {
|
|
Some(expr) => expr,
|
|
None =>
|
|
return Expression::ReductionError(format!(
|
|
"Field {} not specified for record {}",
|
|
field, name
|
|
)),
|
|
};
|
|
ordered_args.push(self.expression(expr));
|
|
}
|
|
|
|
let constructor =
|
|
Expression::Callable(Callable::RecordConstructor { type_id, tag, field_order });
|
|
Expression::Call { f: Box::new(constructor), args: ordered_args }
|
|
}
|
|
Index { indexee, indexers } => self.reduce_index(indexee.as_ref(), indexers.as_slice()),
|
|
WhileExpression { condition, body } => {
|
|
let cond = Box::new(if let Some(condition) = condition {
|
|
self.expression(condition)
|
|
} else {
|
|
Expression::Literal(Literal::Bool(true))
|
|
});
|
|
let statements = self.function_internal_block(body);
|
|
Expression::Loop { cond, statements }
|
|
}
|
|
ForExpression { .. } => Expression::ReductionError("For expr not implemented".to_string()),
|
|
ListLiteral(items) => Expression::List(items.iter().map(|item| self.expression(item)).collect()),
|
|
Access { name, expr } =>
|
|
Expression::Access { name: name.as_ref().to_string(), expr: Box::new(self.expression(expr)) },
|
|
}
|
|
}
|
|
|
|
//TODO figure out the semantics of multiple indexers - for now, just ignore them
|
|
fn reduce_index(&mut self, indexee: &ast::Expression, indexers: &[ast::Expression]) -> Expression {
|
|
if indexers.len() != 1 {
|
|
return Expression::ReductionError("Invalid index expression".to_string());
|
|
}
|
|
let indexee = self.expression(indexee);
|
|
let indexer = self.expression(&indexers[0]);
|
|
Expression::Index { indexee: Box::new(indexee), indexer: Box::new(indexer) }
|
|
}
|
|
|
|
fn reduce_if_expression(
|
|
&mut self,
|
|
discriminator: Option<&ast::Expression>,
|
|
body: &ast::IfExpressionBody,
|
|
) -> Expression {
|
|
use ast::IfExpressionBody::*;
|
|
|
|
let cond = Box::new(match discriminator {
|
|
Some(expr) => self.expression(expr),
|
|
None => return Expression::ReductionError("blank cond if-expr not supported".to_string()),
|
|
});
|
|
match body {
|
|
SimpleConditional { then_case, else_case } => {
|
|
let then_clause = self.function_internal_block(then_case);
|
|
let else_clause = match else_case.as_ref() {
|
|
None => vec![],
|
|
Some(stmts) => self.function_internal_block(stmts),
|
|
};
|
|
Expression::Conditional { cond, then_clause, else_clause }
|
|
}
|
|
SimplePatternMatch { pattern, then_case, else_case } => {
|
|
let alternatives = vec![
|
|
Alternative {
|
|
pattern: match pattern.reduce(self.symbol_table) {
|
|
Ok(p) => p,
|
|
Err(e) => return Expression::ReductionError(format!("Bad pattern: {:?}", e)),
|
|
},
|
|
item: self.function_internal_block(then_case),
|
|
},
|
|
Alternative {
|
|
pattern: Pattern::Ignored,
|
|
item: match else_case.as_ref() {
|
|
Some(else_case) => self.function_internal_block(else_case),
|
|
None => vec![],
|
|
},
|
|
},
|
|
];
|
|
|
|
Expression::CaseMatch { cond, alternatives }
|
|
}
|
|
CondList(ref condition_arms) => {
|
|
let mut alternatives = vec![];
|
|
for arm in condition_arms {
|
|
match arm.condition {
|
|
ast::Condition::Expression(ref _expr) =>
|
|
return Expression::ReductionError("case-expression".to_string()),
|
|
ast::Condition::Pattern(ref pat) => {
|
|
let alt = Alternative {
|
|
pattern: match pat.reduce(self.symbol_table) {
|
|
Ok(p) => p,
|
|
Err(e) =>
|
|
return Expression::ReductionError(format!("Bad pattern: {:?}", e)),
|
|
},
|
|
item: self.function_internal_block(&arm.body),
|
|
};
|
|
alternatives.push(alt);
|
|
}
|
|
ast::Condition::TruncatedOp(_, _) =>
|
|
return Expression::ReductionError("case-expression-trunc-op".to_string()),
|
|
ast::Condition::Else =>
|
|
return Expression::ReductionError("case-expression-else".to_string()),
|
|
}
|
|
}
|
|
Expression::CaseMatch { cond, alternatives }
|
|
}
|
|
}
|
|
}
|
|
|
|
fn invocation_argument(&mut self, invoc: &ast::InvocationArgument) -> Expression {
|
|
use crate::ast::InvocationArgument::*;
|
|
match invoc {
|
|
Positional(ex) => self.expression(ex),
|
|
Keyword { .. } => Expression::ReductionError("Keyword arguments not supported".to_string()),
|
|
Ignored => Expression::ReductionError("Ignored arguments not supported".to_string()),
|
|
}
|
|
}
|
|
|
|
/// Reduces every statement of `block` in order, dropping the ones for which
/// `function_internal_statement` returns `None`.
fn function_internal_block(&mut self, block: &ast::Block) -> Vec<Statement> {
    let mut reduced = Vec::new();
    for stmt in block.statements.iter() {
        if let Some(statement) = self.function_internal_statement(stmt) {
            reduced.push(statement);
        }
    }
    reduced
}
|
|
|
|
fn prefix(&mut self, prefix: &ast::PrefixOp, arg: &ast::Expression) -> Expression {
|
|
let builtin: Option<Builtin> = TryFrom::try_from(prefix).ok();
|
|
match builtin {
|
|
Some(op) => Expression::Call {
|
|
f: Box::new(Expression::Callable(Callable::Builtin(op))),
|
|
args: vec![self.expression(arg)],
|
|
},
|
|
None => {
|
|
//TODO need this for custom prefix ops
|
|
Expression::ReductionError("User-defined prefix ops not supported".to_string())
|
|
}
|
|
}
|
|
}
|
|
|
|
fn binop(&mut self, binop: &ast::BinOp, lhs: &ast::Expression, rhs: &ast::Expression) -> Expression {
|
|
use Expression::ReductionError;
|
|
|
|
let operation = Builtin::from_str(binop.sigil()).ok();
|
|
match operation {
|
|
Some(Builtin::Assignment) => {
|
|
let lval = match &lhs.kind {
|
|
ast::ExpressionKind::Value(qualified_name) => {
|
|
if let Some(symbol) = self.symbol_table.lookup_symbol(&qualified_name.id) {
|
|
symbol.def_id().unwrap()
|
|
} else {
|
|
return ReductionError(format!("Couldn't look up name: {:?}", qualified_name));
|
|
}
|
|
}
|
|
_ => return ReductionError("Trying to assign to a non-name".to_string()),
|
|
};
|
|
|
|
Expression::Assign { lval, rval: Box::new(self.expression(rhs)) }
|
|
}
|
|
Some(op) => Expression::Call {
|
|
f: Box::new(Expression::Callable(Callable::Builtin(op))),
|
|
args: vec![self.expression(lhs), self.expression(rhs)],
|
|
},
|
|
//TODO handle a user-defined operation
|
|
None => ReductionError("User-defined operations not supported".to_string()),
|
|
}
|
|
}
|
|
|
|
fn value(&mut self, qualified_name: &ast::QualifiedName) -> Expression {
|
|
use SymbolSpec::*;
|
|
|
|
let symbol = match self.symbol_table.lookup_symbol(&qualified_name.id) {
|
|
Some(s) => s,
|
|
None =>
|
|
return Expression::ReductionError(format!("No symbol found for name: {:?}", qualified_name)),
|
|
};
|
|
|
|
let def_id = symbol.def_id();
|
|
|
|
match symbol.spec() {
|
|
Builtin(b) => Expression::Callable(Callable::Builtin(b)),
|
|
Func => Expression::Lookup(Lookup::Function(def_id.unwrap())),
|
|
GlobalBinding => Expression::Lookup(Lookup::GlobalVar(def_id.unwrap())),
|
|
LocalVariable => Expression::Lookup(Lookup::LocalVar(def_id.unwrap())),
|
|
FunctionParam(n) => Expression::Lookup(Lookup::Param(n)),
|
|
DataConstructor { tag, type_id } =>
|
|
Expression::Callable(Callable::DataConstructor { type_id, tag }),
|
|
RecordConstructor { .. } => Expression::ReductionError(format!(
|
|
"The symbol for value {:?} is unexpectdly a RecordConstructor",
|
|
qualified_name
|
|
)),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl ast::Pattern {
|
|
fn reduce(&self, symbol_table: &SymbolTable) -> Result<Pattern, PatternError> {
|
|
Ok(match self {
|
|
ast::Pattern::Ignored => Pattern::Ignored,
|
|
ast::Pattern::TuplePattern(subpatterns) => {
|
|
let items: Result<Vec<Pattern>, PatternError> =
|
|
subpatterns.iter().map(|pat| pat.reduce(symbol_table)).into_iter().collect();
|
|
let items = items?;
|
|
Pattern::Tuple { tag: None, subpatterns: items }
|
|
}
|
|
ast::Pattern::Literal(lit) => Pattern::Literal(match lit {
|
|
ast::PatternLiteral::NumPattern { neg, num } => match (neg, num) {
|
|
(false, ast::ExpressionKind::NatLiteral(n)) => Literal::Nat(*n),
|
|
(false, ast::ExpressionKind::FloatLiteral(f)) => Literal::Float(*f),
|
|
(true, ast::ExpressionKind::NatLiteral(n)) => Literal::Int(-(*n as i64)),
|
|
(true, ast::ExpressionKind::FloatLiteral(f)) => Literal::Float(-f),
|
|
(_, e) =>
|
|
return Err(format!("Internal error, unexpected pattern literal: {:?}", e).into()),
|
|
},
|
|
ast::PatternLiteral::StringPattern(s) => Literal::StringLit(s.clone()),
|
|
ast::PatternLiteral::BoolPattern(b) => Literal::Bool(*b),
|
|
}),
|
|
ast::Pattern::TupleStruct(name, subpatterns) => {
|
|
let symbol = symbol_table.lookup_symbol(&name.id).unwrap();
|
|
if let SymbolSpec::DataConstructor { tag, type_id: _ } = symbol.spec() {
|
|
let items: Result<Vec<Pattern>, PatternError> =
|
|
subpatterns.iter().map(|pat| pat.reduce(symbol_table)).into_iter().collect();
|
|
let items = items?;
|
|
Pattern::Tuple { tag: Some(tag), subpatterns: items }
|
|
} else {
|
|
return Err(
|
|
"Internal error, trying to match something that's not a DataConstructor".into()
|
|
);
|
|
}
|
|
}
|
|
ast::Pattern::VarOrName(name) => {
|
|
let symbol = symbol_table.lookup_symbol(&name.id).unwrap();
|
|
match symbol.spec() {
|
|
SymbolSpec::DataConstructor { tag, type_id: _ } =>
|
|
Pattern::Tuple { tag: Some(tag), subpatterns: vec![] },
|
|
SymbolSpec::LocalVariable => {
|
|
let def_id = symbol.def_id().unwrap();
|
|
Pattern::Binding(def_id)
|
|
}
|
|
spec => return Err(format!("Unexpected VarOrName symbol: {:?}", spec).into()),
|
|
}
|
|
}
|
|
ast::Pattern::Record(name, specified_members) => {
|
|
let symbol = symbol_table.lookup_symbol(&name.id).unwrap();
|
|
if let SymbolSpec::RecordConstructor { tag, type_id: _ } = symbol.spec() {
|
|
//TODO do this computation from the type_id
|
|
/*
|
|
if specified_members.iter().any(|(member, _)| !members.contains_key(member)) {
|
|
return Err(format!("Unknown key in record pattern").into());
|
|
}
|
|
*/
|
|
|
|
let subpatterns: Result<Vec<(Rc<String>, Pattern)>, PatternError> = specified_members
|
|
.iter()
|
|
.map(|(name, pat)| {
|
|
pat.reduce(symbol_table).map(|reduced_pat| (name.clone(), reduced_pat))
|
|
})
|
|
.into_iter()
|
|
.collect();
|
|
let subpatterns = subpatterns?;
|
|
Pattern::Record { tag, subpatterns }
|
|
} else {
|
|
return Err(format!("Unexpected Record pattern symbol: {:?}", symbol.spec()).into());
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Given the type context and a variant, compute what order the fields on it were stored.
|
|
/// This needs to be public until type-checking is fully implemented because the type information
|
|
/// is only available at runtime.
|
|
pub fn compute_field_orderings(
|
|
type_context: &TypeContext,
|
|
type_id: &TypeId,
|
|
tag: u32,
|
|
) -> Option<Vec<String>> {
|
|
// Eventually, the ReducedIR should decide what field ordering is optimal.
|
|
// For now, just do it alphabetically.
|
|
|
|
let record_members = type_context.lookup_record_members(type_id, tag)?;
|
|
let mut field_order: Vec<String> =
|
|
record_members.iter().map(|(field, _type_id)| field).cloned().collect();
|
|
field_order.sort_unstable();
|
|
Some(field_order)
|
|
}
|