feat: builtin wrapper reduction optimization

This commit is contained in:
microproofs 2024-04-26 19:24:04 +02:00
parent 4f99c81dee
commit 945a3f743b
1 changed files with 190 additions and 93 deletions

View File

@ -708,6 +708,7 @@ impl Term<Name> {
mut arg_stack: Vec<(usize, Term<Name>)>,
id_gen: &mut IdGen,
with: &mut impl FnMut(Option<usize>, &mut Term<Name>, Vec<(usize, Term<Name>)>, &Scope),
inline_lambda: bool,
) {
match self {
Term::Apply { function, argument } => {
@ -719,6 +720,7 @@ impl Term<Name> {
vec![],
id_gen,
with,
inline_lambda,
);
let apply_id = id_gen.next_id();
@ -732,6 +734,7 @@ impl Term<Name> {
arg_stack,
id_gen,
with,
inline_lambda,
);
scope.pop();
@ -741,31 +744,72 @@ impl Term<Name> {
Term::Delay(d) => {
let d = Rc::make_mut(d);
// First we recurse further to reduce the inner terms before coming back up to the Delay
Self::traverse_uplc_with_helper(d, scope, arg_stack, id_gen, with);
Self::traverse_uplc_with_helper(d, scope, arg_stack, id_gen, with, inline_lambda);
with(None, self, vec![], scope);
}
Term::Lambda {
parameter_name: p,
body,
parameter_name,
} => {
let body = Rc::make_mut(body);
let p = p.as_ref().clone();
// Lambda pops one item off the arg stack. If there is no item then it is a unsaturated lambda
// We also skip NO_INLINE lambdas since those are placeholder lambdas created by codegen
let args = if parameter_name.text == NO_INLINE {
let args = if p.text == NO_INLINE {
vec![]
} else {
arg_stack.pop().map(|arg| vec![arg]).unwrap_or_default()
};
// Pass in either one or zero args.
Self::traverse_uplc_with_helper(body, scope, arg_stack, id_gen, with);
with(None, self, args, scope);
if inline_lambda {
// Pass in either one or zero args.
// For lambda we run the function with first then recurse on the body or replaced term
with(None, self, args, scope);
match self {
Term::Lambda {
parameter_name,
body,
} if parameter_name.as_ref() == &p => {
let body = Rc::make_mut(body);
Self::traverse_uplc_with_helper(
body,
scope,
arg_stack,
id_gen,
with,
inline_lambda,
);
}
Term::Constr { .. } => todo!(),
Term::Case { .. } => todo!(),
other => Self::traverse_uplc_with_helper(
other,
scope,
arg_stack,
id_gen,
with,
inline_lambda,
),
}
} else {
let body = Rc::make_mut(body);
Self::traverse_uplc_with_helper(
body,
scope,
arg_stack,
id_gen,
with,
inline_lambda,
);
with(None, self, vec![], scope);
}
}
Term::Force(f) => {
let f = Rc::make_mut(f);
Self::traverse_uplc_with_helper(f, scope, arg_stack, id_gen, with);
Self::traverse_uplc_with_helper(f, scope, arg_stack, id_gen, with, inline_lambda);
with(None, self, vec![], scope);
}
Term::Case { .. } => todo!(),
@ -792,6 +836,7 @@ impl Term<Name> {
impl Program<Name> {
fn traverse_uplc_with(
self,
inline_lambda: bool,
with: &mut impl FnMut(Option<usize>, &mut Term<Name>, Vec<(usize, Term<Name>)>, &Scope),
) -> Self {
let mut term = self.term;
@ -799,7 +844,7 @@ impl Program<Name> {
let arg_stack = vec![];
let mut id_gen = IdGen::new();
term.traverse_uplc_with_helper(&scope, arg_stack, &mut id_gen, with);
term.traverse_uplc_with_helper(&scope, arg_stack, &mut id_gen, with, inline_lambda);
Program {
version: self.version,
term,
@ -809,7 +854,7 @@ impl Program<Name> {
pub fn lambda_reducer(self) -> Self {
let mut lambda_applied_ids = vec![];
self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| {
self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| {
match term {
Term::Apply { function, .. } => {
// We are applying some arg so now we unwrap the id of the applied arg
@ -827,7 +872,7 @@ impl Program<Name> {
} => {
// pops stack here no matter what
if let Some((arg_id, arg_term)) = arg_stack.pop() {
match arg_term {
match &arg_term {
Term::Constant(c) if matches!(c.as_ref(), Constant::String(_)) => {}
Term::Constant(_) | Term::Var(_) | Term::Builtin(_) => {
let body = Rc::make_mut(body);
@ -835,10 +880,20 @@ impl Program<Name> {
// creates new body that replaces all var occurrences with the arg
*term = substitute_var(body, parameter_name.clone(), &arg_term);
}
l @ Term::Lambda { .. } => {
if is_a_builtin_wrapper(l) {
let body = Rc::make_mut(body);
lambda_applied_ids.push(arg_id);
// creates new body that replaces all var occurrences with the arg
*term = substitute_var(body, parameter_name.clone(), &arg_term);
}
}
_ => {}
}
}
}
Term::Case { .. } => todo!(),
Term::Constr { .. } => todo!(),
_ => {}
@ -849,7 +904,7 @@ impl Program<Name> {
pub fn builtin_force_reducer(self) -> Self {
let mut builtin_map = IndexMap::new();
let program = self.traverse_uplc_with(&mut |_id, term, _arg_stack, _scope| {
let program = self.traverse_uplc_with(true, &mut |_id, term, _arg_stack, _scope| {
if let Term::Force(f) = term {
let f = Rc::make_mut(f);
match f {
@ -909,7 +964,7 @@ impl Program<Name> {
pub fn identity_reducer(self) -> Self {
let mut identity_applied_ids = vec![];
self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| {
self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| {
match term {
Term::Apply { function, .. } => {
// We are applying some arg so now we unwrap the id of the applied arg
@ -1018,7 +1073,7 @@ impl Program<Name> {
pub fn inline_reducer(self) -> Self {
let mut lambda_applied_ids = vec![];
self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| match term {
self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| match term {
Term::Apply { function, .. } => {
// We are applying some arg so now we unwrap the id of the applied arg
let id = id.unwrap();
@ -1084,7 +1139,7 @@ impl Program<Name> {
}
pub fn force_delay_reducer(self) -> Self {
self.traverse_uplc_with(&mut |_id, term, _arg_stack, _scope| {
self.traverse_uplc_with(true, &mut |_id, term, _arg_stack, _scope| {
if let Term::Force(f) = term {
let f = f.as_ref();
@ -1096,7 +1151,7 @@ impl Program<Name> {
}
pub fn remove_no_inlines(self) -> Self {
self.traverse_uplc_with(&mut |_, term, _, _| match term {
self.traverse_uplc_with(true, &mut |_, term, _, _| match term {
Term::Lambda {
parameter_name,
body,
@ -1106,7 +1161,7 @@ impl Program<Name> {
}
pub fn inline_constr_ops(self) -> Self {
self.traverse_uplc_with(&mut |_, term, _, _| {
self.traverse_uplc_with(true, &mut |_, term, _, _| {
if let Term::Apply { function, argument } = term {
if let Term::Var(name) = function.as_ref() {
if name.text == CONSTR_FIELDS_EXPOSER {
@ -1128,7 +1183,7 @@ impl Program<Name> {
pub fn cast_data_reducer(self) -> Self {
let mut applied_ids = vec![];
self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| {
self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| {
match term {
Term::Apply { function, .. } => {
// We are apply some arg so now we unwrap the id of the applied arg
@ -1256,7 +1311,7 @@ impl Program<Name> {
pub fn convert_arithmetic_ops(self) -> Self {
let mut constants_to_flip = vec![];
self.traverse_uplc_with(&mut |id, term, arg_stack, _scope| match term {
self.traverse_uplc_with(true, &mut |id, term, arg_stack, _scope| match term {
Term::Apply { argument, .. } => {
let id = id.unwrap();
@ -1300,92 +1355,93 @@ impl Program<Name> {
let mut final_ids: IndexMap<Vec<usize>, ()> = IndexMap::new();
let step_a = self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term {
Term::Builtin(func) => {
if func.can_curry_builtin() && arg_stack.len() == func.arity() {
let is_order_agnostic = func.is_order_agnostic_builtin();
let step_a =
self.traverse_uplc_with(false, &mut |_id, term, arg_stack, scope| match term {
Term::Builtin(func) => {
if func.can_curry_builtin() && arg_stack.len() == func.arity() {
let is_order_agnostic = func.is_order_agnostic_builtin();
// In the case of order agnostic builtins we want to sort the args by constant first
// This gives us the opportunity to curry constants that often pop up in the code
// In the case of order agnostic builtins we want to sort the args by constant first
// This gives us the opportunity to curry constants that often pop up in the code
let builtin_args =
BuiltinArgs::args_from_arg_stack(arg_stack, is_order_agnostic);
let builtin_args =
BuiltinArgs::args_from_arg_stack(arg_stack, is_order_agnostic);
// First we see if we have already curried this builtin before
let mut id_vec = if let Some((index, _)) =
curried_terms.iter_mut().find_position(
|curried_term: &&mut CurriedBuiltin| curried_term.func == *func,
) {
// We found it the builtin was curried before
// So now we merge the new args into the existing curried builtin
// First we see if we have already curried this builtin before
let mut id_vec = if let Some((index, _)) =
curried_terms.iter_mut().find_position(
|curried_term: &&mut CurriedBuiltin| curried_term.func == *func,
) {
// We found it the builtin was curried before
// So now we merge the new args into the existing curried builtin
let curried_builtin = curried_terms.swap_remove(index);
let curried_builtin = curried_terms.swap_remove(index);
let curried_builtin =
curried_builtin.merge_node_by_path(builtin_args.clone());
let curried_builtin =
curried_builtin.merge_node_by_path(builtin_args.clone());
let Some(id_vec) = curried_builtin.get_id_args(&builtin_args) else {
unreachable!();
};
let Some(id_vec) = curried_builtin.get_id_args(&builtin_args) else {
unreachable!();
};
flipped_terms
.insert(scope.clone(), curried_builtin.is_flipped(&builtin_args));
flipped_terms
.insert(scope.clone(), curried_builtin.is_flipped(&builtin_args));
curried_terms.push(curried_builtin);
curried_terms.push(curried_builtin);
id_vec
} else {
// Brand new buitlin so we add it to the list
let curried_builtin = builtin_args.clone().args_to_curried_args(*func);
let Some(id_vec) = curried_builtin.get_id_args(&builtin_args) else {
unreachable!();
};
curried_terms.push(curried_builtin);
id_vec
};
while let Some(node) = id_vec.pop() {
let mut id_only_vec =
id_vec.iter().map(|item| item.curried_id).collect_vec();
id_only_vec.push(node.curried_id);
let curry_name = CurriedName {
func_name: func.aiken_name(),
id_vec: id_only_vec,
};
if let Some((map_scope, _, occurrences)) =
id_mapped_curry_terms.get_mut(&curry_name)
{
*map_scope = map_scope.common_ancestor(scope);
*occurrences += 1;
} else if id_vec.is_empty() {
id_mapped_curry_terms.insert(
curry_name,
(scope.clone(), Term::Builtin(*func).apply(node.term), 1),
);
id_vec
} else {
let var_name = id_vec_function_to_var(
&func.aiken_name(),
&id_vec.iter().map(|item| item.curried_id).collect_vec(),
);
// Brand new buitlin so we add it to the list
let curried_builtin = builtin_args.clone().args_to_curried_args(*func);
id_mapped_curry_terms.insert(
curry_name,
(scope.clone(), Term::var(var_name).apply(node.term), 1),
);
let Some(id_vec) = curried_builtin.get_id_args(&builtin_args) else {
unreachable!();
};
curried_terms.push(curried_builtin);
id_vec
};
while let Some(node) = id_vec.pop() {
let mut id_only_vec =
id_vec.iter().map(|item| item.curried_id).collect_vec();
id_only_vec.push(node.curried_id);
let curry_name = CurriedName {
func_name: func.aiken_name(),
id_vec: id_only_vec,
};
if let Some((map_scope, _, occurrences)) =
id_mapped_curry_terms.get_mut(&curry_name)
{
*map_scope = map_scope.common_ancestor(scope);
*occurrences += 1;
} else if id_vec.is_empty() {
id_mapped_curry_terms.insert(
curry_name,
(scope.clone(), Term::Builtin(*func).apply(node.term), 1),
);
} else {
let var_name = id_vec_function_to_var(
&func.aiken_name(),
&id_vec.iter().map(|item| item.curried_id).collect_vec(),
);
id_mapped_curry_terms.insert(
curry_name,
(scope.clone(), Term::var(var_name).apply(node.term), 1),
);
}
}
}
}
}
Term::Constr { .. } => todo!(),
Term::Case { .. } => todo!(),
_ => {}
});
Term::Constr { .. } => todo!(),
Term::Case { .. } => todo!(),
_ => {}
});
id_mapped_curry_terms
.into_iter()
@ -1410,7 +1466,7 @@ impl Program<Name> {
});
let mut step_b =
step_a.traverse_uplc_with(&mut |id, term, mut arg_stack, scope| match term {
step_a.traverse_uplc_with(false, &mut |id, term, mut arg_stack, scope| match term {
Term::Builtin(func) => {
if func.can_curry_builtin() && arg_stack.len() == func.arity() {
let Some(curried_builtin) =
@ -1658,6 +1714,47 @@ fn replace_identity_usage(term: &Term<Name>, original: Rc<Name>) -> Term<Name> {
}
}
fn is_a_builtin_wrapper(term: &Term<Name>) -> bool {
let (names, term) = pop_lambdas_and_get_names(term);
let mut arg_names = vec![];
let mut term = term;
while let Term::Apply { function, argument } = term {
match argument.as_ref() {
Term::Var(name) => arg_names.push(name),
Term::Constant(_) => {}
_ => {
return false;
}
}
term = function.as_ref();
}
arg_names.iter().all(|item| names.contains(item)) && matches!(term, Term::Builtin(_))
}
fn pop_lambdas_and_get_names(term: &Term<Name>) -> (Vec<Rc<Name>>, &Term<Name>) {
let mut names = vec![];
let mut term = term;
while let Term::Lambda {
parameter_name,
body,
} = term
{
if parameter_name.text != NO_INLINE {
names.push(parameter_name.clone());
}
term = body.as_ref();
}
(names, term)
}
#[cfg(test)]
mod tests {