From 4ea6fdffe8f75645fb9fde01a3ac5dbf35cfbfd3 Mon Sep 17 00:00:00 2001 From: Kasey <49739331+MicroProofs@users.noreply.github.com> Date: Wed, 13 Nov 2024 15:08:36 -0500 Subject: [PATCH] Aiken UPLC Optimization overhaul (#1052) * Refactor and structuring optimizations to be less computationally heavy * Forgot to commit the new file containing the optimization do over * Point to correct functions in shrinker2 * Split out inline_constr_ops since it adds in builtins that can then be swept up by the builtin force reduction * Fix: issue where identity reducer was always returning true * Forward inlining on lambdas produces better results. This is due to a forward pass being able to apply an argument that may have no_inline at the top where as vice-versa would reduce the arg first. * Clippy and test fixes * Clear no_inlines when inlining a function * Convert shrinker2 to replace shrinker and update tests --- crates/aiken-lang/src/gen_uplc.rs | 10 +- crates/aiken-project/src/tests/gen_uplc.rs | 26 +- crates/uplc/src/optimize.rs | 54 +- crates/uplc/src/optimize/shrinker.rs | 1770 +++++++++++--------- 4 files changed, 1035 insertions(+), 825 deletions(-) diff --git a/crates/aiken-lang/src/gen_uplc.rs b/crates/aiken-lang/src/gen_uplc.rs index 94eea06a..e03e6407 100644 --- a/crates/aiken-lang/src/gen_uplc.rs +++ b/crates/aiken-lang/src/gen_uplc.rs @@ -3782,7 +3782,7 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); let eval_program: Program = - program.remove_no_inlines().try_into().unwrap(); + program.clean_up().try_into().unwrap(); Some( eval_program @@ -3892,7 +3892,7 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); let eval_program: Program = - program.remove_no_inlines().try_into().unwrap(); + program.clean_up().try_into().unwrap(); let evaluated_term: Term = eval_program .eval(ExBudget::default()) @@ -4434,7 +4434,7 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); let eval_program: Program = - program.remove_no_inlines().try_into().unwrap(); + program.clean_up().try_into().unwrap(); let evaluated_term: Term = eval_program .eval(ExBudget::default()) @@ -4459,7 +4459,7 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); let eval_program: Program = - program.remove_no_inlines().try_into().unwrap(); + program.clean_up().try_into().unwrap(); let evaluated_term: Term = eval_program .eval(ExBudget::default()) @@ -4880,7 +4880,7 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); let eval_program: Program = - program.remove_no_inlines().try_into().unwrap(); + program.clean_up().try_into().unwrap(); let evaluated_term: Term = eval_program .eval(ExBudget::default()) diff --git a/crates/aiken-project/src/tests/gen_uplc.rs b/crates/aiken-project/src/tests/gen_uplc.rs index dc82f84b..8a5506a5 100644 --- a/crates/aiken-project/src/tests/gen_uplc.rs +++ b/crates/aiken-project/src/tests/gen_uplc.rs @@ -3690,12 +3690,12 @@ fn always_true_validator() { Term::snd_pair() .apply(Term::unconstr_data().apply(Term::var(context))) .as_var("tail_id_2", |tail_id_2| { - Term::head_list() + Term::tail_list() .apply(Term::Var(tail_id_2.clone())) - .as_var("__transaction__", |_transaction| { - Term::tail_list().apply(Term::Var(tail_id_2)).as_var( - "tail_id_3", - |tail_id_3| { + .as_var("tail_id_3", |tail_id_3| { + Term::head_list() + .apply(Term::Var(tail_id_2.clone())) + .as_var("__transaction__", |_transaction| { Term::head_list() .apply(Term::Var(tail_id_3.clone())) .as_var("__redeemer__", |_redeemer| { @@ -3708,8 +3708,7 @@ fn always_true_validator() { ) }) }) - }, - ) + }) }) }) .delayed_if_then_else( @@ -4187,12 +4186,12 @@ fn generic_validator_type_test() { Term::snd_pair() .apply(Term::unconstr_data().apply(Term::var(context))) .as_var("tail_id_13", |tail_id_13| { - Term::head_list() + Term::tail_list() .apply(Term::Var(tail_id_13.clone())) - .as_var("__transaction__", |_transaction| { - Term::tail_list().apply(Term::Var(tail_id_13)).as_var( - "tail_id_14", - |tail_id_14| { + .as_var("tail_id_14", |tail_id_14| { + Term::head_list() + .apply(Term::Var(tail_id_13.clone())) + .as_var("__transaction__", |_transaction| { Term::head_list() .apply(Term::Var(tail_id_14.clone())) .as_var("__redeemer__", |redeemer| { @@ -4202,8 +4201,7 @@ fn generic_validator_type_test() { choose_purpose(redeemer, purpose, trace) }) }) - }, - ) + }) }) }) .delayed_if_then_else( diff --git a/crates/uplc/src/optimize.rs b/crates/uplc/src/optimize.rs index ef60c20a..fd00a2c9 100644 --- a/crates/uplc/src/optimize.rs +++ b/crates/uplc/src/optimize.rs @@ -4,25 +4,39 @@ pub mod interner; pub mod shrinker; pub fn aiken_optimize_and_intern(program: Program) -> Program { - program - .inline_constr_ops() - .bls381_compressor() - .builtin_force_reducer() - .lambda_reducer() - .inline_reducer() - .identity_reducer() - .lambda_reducer() - .inline_reducer() - .force_delay_reducer() - .cast_data_reducer() - .builtin_eval_reducer() - .convert_arithmetic_ops() + let mut prog = program.run_once_pass(); + + let mut prev_count = 0; + + loop { + let (current_program, context) = prog.multi_pass(); + + if context.node_count == prev_count { + prog = current_program; + break; + } else { + prog = current_program; + prev_count = context.node_count; + } + } + + prog = prog .builtin_curry_reducer() - .lambda_reducer() - .inline_reducer() - .identity_reducer() - .builtin_curry_reducer() - .lambda_reducer() - .inline_reducer() - .remove_no_inlines() + .multi_pass() + .0 + .builtin_curry_reducer(); + + loop { + let (current_program, context) = prog.multi_pass(); + + if context.node_count == prev_count { + prog = current_program; + break; + } else { + prog = current_program; + prev_count = context.node_count; + } + } + + prog.clean_up() } diff --git a/crates/uplc/src/optimize/shrinker.rs b/crates/uplc/src/optimize/shrinker.rs index 9a96283e..44c15412 100644 --- a/crates/uplc/src/optimize/shrinker.rs +++ b/crates/uplc/src/optimize/shrinker.rs @@ -5,10 +5,11 @@ use crate::{ builtins::DefaultFunction, machine::{cost_model::ExBudget, runtime::Compressable}, }; +use blst::{blst_p1, blst_p2}; use indexmap::IndexMap; use itertools::Itertools; use pallas_primitives::conway::{BigInt, PlutusData}; -use std::{cmp::Ordering, iter, ops::Neg, rc::Rc, vec}; +use std::{cmp::Ordering, iter, ops::Neg, rc::Rc}; #[derive(Eq, Hash, PartialEq, Clone, Debug, PartialOrd)] pub enum ScopePath { @@ -852,134 +853,874 @@ impl CurriedBuiltin { } } +#[derive(Debug, Clone)] +pub struct Context { + pub inlined_apply_ids: Vec, + pub constants_to_flip: Vec, + pub builtins_map: IndexMap, + pub blst_p1_list: Vec, + pub blst_p2_list: Vec, + pub node_count: usize, +} + +#[derive(Clone, Debug)] +pub enum Args { + Force(usize), + Apply(usize, Term), +} + impl Term { fn traverse_uplc_with_helper( &mut self, scope: &Scope, - mut arg_stack: Vec<(usize, Term)>, + mut arg_stack: Vec, id_gen: &mut IdGen, - with: &mut impl FnMut(Option, &mut Term, Vec<(usize, Term)>, &Scope), + with: &mut impl FnMut(Option, &mut Term, Vec, &Scope, &mut Context), + context: &mut Context, inline_lambda: bool, ) { match self { Term::Apply { function, argument } => { let arg = Rc::make_mut(argument); - Self::traverse_uplc_with_helper( - arg, + arg.traverse_uplc_with_helper( &scope.push(ScopePath::ARG), vec![], id_gen, with, + context, inline_lambda, ); let apply_id = id_gen.next_id(); - arg_stack.push((apply_id, arg.clone())); + arg_stack.push(Args::Apply(apply_id, arg.clone())); let func = Rc::make_mut(function); - Self::traverse_uplc_with_helper( - func, + func.traverse_uplc_with_helper( &scope.push(ScopePath::FUNC), arg_stack, id_gen, with, + context, inline_lambda, ); - scope.pop(); + with(Some(apply_id), self, vec![], scope, context); + } + Term::Force(f) => { + let f = Rc::make_mut(f); + let force_id = id_gen.next_id(); - with(Some(apply_id), self, vec![], scope); + arg_stack.push(Args::Force(force_id)); + + f.traverse_uplc_with_helper(scope, arg_stack, id_gen, with, context, inline_lambda); + + with(Some(force_id), self, vec![], scope, context); } Term::Delay(d) => { let d = Rc::make_mut(d); - // First we recurse further to reduce the inner terms before coming back up to the Delay - Self::traverse_uplc_with_helper(d, scope, arg_stack, id_gen, with, inline_lambda); - with(None, self, vec![], scope); + let delay_arg = arg_stack + .pop() + .map(|arg| { + assert!(matches!(arg, Args::Force(_))); + vec![arg] + }) + .unwrap_or_default(); + + d.traverse_uplc_with_helper(scope, arg_stack, id_gen, with, context, inline_lambda); + + with(None, self, delay_arg, scope, context); } Term::Lambda { - parameter_name: p, + parameter_name, body, } => { - let p = p.as_ref().clone(); + let p = parameter_name.clone(); + // Lambda pops one item off the arg stack. If there is no item then it is a unsaturated lambda - // We also skip NO_INLINE lambdas since those are placeholder lambdas created by codegen - let args = if p.text == NO_INLINE { + // NO_INLINE lambdas come in with 0 arguments on the arg stack + let args = if parameter_name.text == NO_INLINE { vec![] } else { - arg_stack.pop().map(|arg| vec![arg]).unwrap_or_default() + arg_stack + .pop() + .map(|arg| { + assert!(matches!(arg, Args::Apply(_, _))); + vec![arg] + }) + .unwrap_or_default() }; if inline_lambda { // Pass in either one or zero args. // For lambda we run the function with first then recurse on the body or replaced term - with(None, self, args, scope); + + with(None, self, args, scope, context); match self { Term::Lambda { parameter_name, body, - } if parameter_name.as_ref() == &p => { + } if parameter_name.text == p.text && parameter_name.unique == p.unique => { let body = Rc::make_mut(body); - Self::traverse_uplc_with_helper( - body, + body.traverse_uplc_with_helper( scope, arg_stack, id_gen, with, + context, inline_lambda, ); } Term::Constr { .. } => todo!(), Term::Case { .. } => todo!(), - other => Self::traverse_uplc_with_helper( - other, + other => other.traverse_uplc_with_helper( scope, arg_stack, id_gen, with, + context, inline_lambda, ), } } else { let body = Rc::make_mut(body); - Self::traverse_uplc_with_helper( - body, + body.traverse_uplc_with_helper( scope, arg_stack, id_gen, with, + context, inline_lambda, ); - with(None, self, args, scope); + with(None, self, args, scope, context); } } - Term::Force(f) => { - let f = Rc::make_mut(f); - Self::traverse_uplc_with_helper(f, scope, arg_stack, id_gen, with, inline_lambda); - with(None, self, vec![], scope); - } Term::Case { .. } => todo!(), Term::Constr { .. } => todo!(), Term::Builtin(func) => { let mut args = vec![]; - for _ in 0..func.arity() { + for _ in 0..(func.arity() + usize::try_from(func.force_count()).unwrap()) { if let Some(arg) = arg_stack.pop() { args.push(arg); } } // Pass in args up to function arity. - with(None, self, args, scope); + with(None, self, args, scope, context); } term => { - with(None, term, vec![], scope); + with(None, term, vec![], scope, context); + } + } + context.node_count += 1; + } + + fn substitute_var(&mut self, original: Rc, replace_with: &Term) { + match self { + Term::Var(name) if name.text == original.text && name.unique == original.unique => { + *self = replace_with.clone(); + } + Term::Delay(body) => Rc::make_mut(body).substitute_var(original, replace_with), + Term::Lambda { + parameter_name, + body, + } if parameter_name.text != original.text + || parameter_name.unique != original.unique => + { + Rc::make_mut(body).substitute_var(original, replace_with); + } + Term::Apply { function, argument } => { + Rc::make_mut(function).substitute_var(original.clone(), replace_with); + Rc::make_mut(argument).substitute_var(original, replace_with); + } + Term::Force(f) => { + Rc::make_mut(f).substitute_var(original, replace_with); + } + Term::Case { .. } => todo!(), + Term::Constr { .. } => todo!(), + _ => (), + } + } + + fn replace_identity_usage(&mut self, original: Rc) { + match self { + Term::Delay(body) => { + Rc::make_mut(body).replace_identity_usage(original.clone()); + } + Term::Lambda { + parameter_name, + body, + } => { + if parameter_name.text != original.text || parameter_name.unique != original.unique + { + Rc::make_mut(body).replace_identity_usage(original.clone()); + } + } + Term::Apply { function, argument } => { + let func = Rc::make_mut(function); + let arg = Rc::make_mut(argument); + + func.replace_identity_usage(original.clone()); + arg.replace_identity_usage(original.clone()); + + let Term::Var(name) = &func else { + return; + }; + + if name.text == original.text && name.unique == original.unique { + *self = std::mem::replace(arg, Term::Error.force()); + } + } + Term::Force(f) => { + Rc::make_mut(f).replace_identity_usage(original.clone()); + } + Term::Case { .. } => todo!(), + Term::Constr { .. } => todo!(), + _ => (), + } + } + + fn var_occurrences( + &self, + search_for: Rc, + mut arg_stack: Vec<()>, + mut force_stack: Vec<()>, + ) -> VarLookup { + match self { + Term::Var(name) => { + if name.text == search_for.text && name.unique == search_for.unique { + VarLookup::new_found() + } else { + VarLookup::new() + } + } + Term::Delay(body) => { + let not_forced: isize = isize::from(force_stack.pop().is_none()); + + body.var_occurrences(search_for, arg_stack, force_stack) + .delay_if_found(not_forced) + } + Term::Lambda { + parameter_name, + body, + } => { + if parameter_name.text == NO_INLINE { + body.var_occurrences(search_for, arg_stack, force_stack) + .no_inline_if_found() + } else if parameter_name.text == search_for.text + && parameter_name.unique == search_for.unique + { + VarLookup::new() + } else { + let not_applied: isize = isize::from(arg_stack.pop().is_none()); + body.var_occurrences(search_for, arg_stack, force_stack) + .delay_if_found(not_applied) + } + } + Term::Apply { function, argument } => { + arg_stack.push(()); + + function + .var_occurrences(search_for.clone(), arg_stack, force_stack) + .combine(argument.var_occurrences(search_for, vec![], vec![])) + } + Term::Force(x) => { + force_stack.push(()); + x.var_occurrences(search_for, arg_stack, force_stack) + } + Term::Case { .. } => todo!(), + Term::Constr { .. } => todo!(), + _ => VarLookup::new(), + } + } + + fn lambda_reducer( + &mut self, + _id: Option, + mut arg_stack: Vec, + _scope: &Scope, + context: &mut Context, + ) -> bool { + let mut changed = false; + match self { + Term::Lambda { + parameter_name, + body, + } => { + // pops stack here no matter what + if let Some(Args::Apply(arg_id, arg_term)) = arg_stack.pop() { + let replace = match &arg_term { + // Do nothing for String consts + Term::Constant(c) if matches!(c.as_ref(), Constant::String(_)) => false, + // Inline Delay Error terms since total size is only 1 byte + // So it costs more size to have them hoisted + Term::Delay(e) if matches!(e.as_ref(), Term::Error) => true, + // If it wraps a builtin with consts or arguments passed in then inline + l @ Term::Lambda { .. } if is_a_builtin_wrapper(l) => true, + // Inline smaller terms too + Term::Constant(_) | Term::Var(_) | Term::Builtin(_) => true, + + _ => false, + }; + changed = replace; + + if replace { + let body = Rc::make_mut(body); + context.inlined_apply_ids.push(arg_id); + + body.substitute_var(parameter_name.clone(), arg_term.pierce_no_inlines()); + // creates new body that replaces all var occurrences with the arg + *self = std::mem::replace(body, Term::Error.force()); + } + } + } + + Term::Case { .. } => todo!(), + Term::Constr { .. } => todo!(), + _ => (), + }; + + changed + } + + // IMPORTANT: RUNS ONE TIME + fn builtin_force_reducer( + &mut self, + _id: Option, + mut arg_stack: Vec, + _scope: &Scope, + context: &mut Context, + ) { + if let Term::Builtin(func) = self { + arg_stack.reverse(); + let has_forces = func.force_count() > 0; + while let Some(Args::Force(id)) = arg_stack.pop() { + context.inlined_apply_ids.push(id); + } + + if has_forces { + context.builtins_map.insert(*func as u8, ()); + *self = Term::var(format!("__{}_wrapped", func.aiken_name())); + } + } + } + + // IMPORTANT: RUNS ONE TIME + fn bls381_compressor( + &mut self, + _id: Option, + _arg_stack: Vec, + _scope: &Scope, + context: &mut Context, + ) { + if let Term::Constant(con) = self { + match con.as_ref() { + Constant::Bls12_381G1Element(blst_p1) => { + if let Some(index) = context + .blst_p1_list + .iter() + .position(|item| item == blst_p1.as_ref()) + { + *self = Term::var(format!("blst_p1_index_{}", index)); + } else { + context.blst_p1_list.push(*blst_p1.as_ref()); + *self = + Term::var(format!("blst_p1_index_{}", context.blst_p1_list.len() - 1)); + } + } + Constant::Bls12_381G2Element(blst_p2) => { + if let Some(index) = context + .blst_p2_list + .iter() + .position(|item| item == blst_p2.as_ref()) + { + *self = Term::var(format!("blst_p2_index_{}", index)); + } else { + context.blst_p2_list.push(*blst_p2.as_ref()); + *self = + Term::var(format!("blst_p2_index_{}", context.blst_p2_list.len() - 1)); + } + } + _ => (), + } + } + } + + fn identity_reducer( + &mut self, + _id: Option, + mut arg_stack: Vec, + _scope: &Scope, + context: &mut Context, + ) -> bool { + let mut changed = false; + + match self { + Term::Lambda { + parameter_name, + body, + } => { + let body = Rc::make_mut(body); + // pops stack here no matter what + let temp = Term::Error; + + if let ( + arg_id, + Term::Lambda { + parameter_name: identity_name, + body: identity_body, + }, + ) = match &arg_stack.pop() { + Some(Args::Apply( + arg_id, + Term::Lambda { + parameter_name: inline_name, + body, + }, + )) if inline_name.text == NO_INLINE => (*arg_id, body.as_ref()), + Some(Args::Apply(arg_id, term)) => (*arg_id, term), + _ => (0, &temp), + } { + let Term::Var(identity_var) = identity_body.as_ref() else { + return false; + }; + + if identity_var.text == identity_name.text + && identity_var.unique == identity_name.unique + { + // Replace all applied usages of identity with the arg + body.replace_identity_usage(parameter_name.clone()); + // Have to check if the body still has any occurrences of the parameter + // After attempting replacement + if !body + .var_occurrences(parameter_name.clone(), vec![], vec![]) + .found + { + changed = true; + context.inlined_apply_ids.push(arg_id); + *self = std::mem::replace(body, Term::Error.force()); + } + } + } + } + Term::Constr { .. } => todo!(), + Term::Case { .. } => todo!(), + _ => (), + }; + + changed + } + + fn inline_reducer( + &mut self, + _id: Option, + mut arg_stack: Vec, + _scope: &Scope, + context: &mut Context, + ) -> bool { + let mut changed = false; + + match self { + Term::Lambda { + parameter_name, + body, + } => { + // pops stack here no matter what + if let Some(Args::Apply(arg_id, arg_term)) = arg_stack.pop() { + let arg_term = match &arg_term { + Term::Lambda { + parameter_name, + body, + } if parameter_name.text == NO_INLINE => body.as_ref().clone(), + _ => arg_term, + }; + + let body = Rc::make_mut(body); + + let var_lookup = body.var_occurrences(parameter_name.clone(), vec![], vec![]); + + let substitute_condition = (var_lookup.delays == 0 && !var_lookup.no_inline) + || matches!( + &arg_term, + Term::Var(_) + | Term::Constant(_) + | Term::Delay(_) + | Term::Lambda { .. } + | Term::Builtin(_), + ); + + if var_lookup.occurrences == 1 && substitute_condition { + changed = true; + body.substitute_var(parameter_name.clone(), arg_term.pierce_no_inlines()); + + context.inlined_apply_ids.push(arg_id); + *self = std::mem::replace(body, Term::Error.force()); + + // This will strip out unused terms that can't throw an error by themselves + } else if !var_lookup.found + && matches!( + arg_term, + Term::Var(_) + | Term::Constant(_) + | Term::Delay(_) + | Term::Lambda { .. } + | Term::Builtin(_) + ) + { + changed = true; + context.inlined_apply_ids.push(arg_id); + *self = std::mem::replace(body, Term::Error.force()); + } + } + } + Term::Constr { .. } => todo!(), + Term::Case { .. } => todo!(), + _ => {} + }; + changed + } + + fn force_delay_reducer( + &mut self, + _id: Option, + mut arg_stack: Vec, + _scope: &Scope, + context: &mut Context, + ) -> bool { + let mut changed = false; + if let Term::Delay(d) = self { + if let Some(Args::Force(id)) = arg_stack.pop() { + changed = true; + context.inlined_apply_ids.push(id); + *self = std::mem::replace(Rc::make_mut(d), Term::Error.force()) + } else if let Term::Force(var) = d.as_ref() { + if let Term::Var(_) = var.as_ref() { + changed = true; + *self = var.as_ref().clone(); + } + } + } + changed + } + + fn remove_no_inlines( + &mut self, + _id: Option, + _arg_stack: Vec, + _scope: &Scope, + _context: &mut Context, + ) { + match self { + Term::Lambda { + parameter_name, + body, + } if parameter_name.text == NO_INLINE => { + *self = std::mem::replace(Rc::make_mut(body), Term::Error.force()); + } + _ => (), + } + } + + // IMPORTANT: RUNS ONE TIME + fn inline_constr_ops( + &mut self, + _id: Option, + _arg_stack: Vec, + _scope: &Scope, + _context: &mut Context, + ) { + if let Term::Apply { function, argument } = self { + if let Term::Var(name) = function.as_ref() { + let arg = Rc::make_mut(argument); + if name.text == CONSTR_FIELDS_EXPOSER { + *self = Term::snd_pair().apply( + Term::unconstr_data().apply(std::mem::replace(arg, Term::Error.force())), + ) + } else if name.text == CONSTR_INDEX_EXPOSER { + *self = Term::fst_pair().apply( + Term::unconstr_data().apply(std::mem::replace(arg, Term::Error.force())), + ) + } + } + } + } + + fn cast_data_reducer( + &mut self, + _id: Option, + mut arg_stack: Vec, + _scope: &Scope, + context: &mut Context, + ) -> bool { + let mut changed = false; + + match self { + Term::Builtin(first_function) => { + let Some(Args::Apply(arg_id, mut arg_term)) = arg_stack.pop() else { + return false; + }; + + match &mut arg_term { + Term::Apply { function, argument } => { + if let Term::Builtin(second_function) = function.as_ref() { + match (first_function, second_function) { + (DefaultFunction::UnIData, DefaultFunction::IData) + | (DefaultFunction::IData, DefaultFunction::UnIData) + | (DefaultFunction::BData, DefaultFunction::UnBData) + | (DefaultFunction::UnBData, DefaultFunction::BData) + | (DefaultFunction::ListData, DefaultFunction::UnListData) + | (DefaultFunction::UnListData, DefaultFunction::ListData) + | (DefaultFunction::MapData, DefaultFunction::UnMapData) + | (DefaultFunction::UnMapData, DefaultFunction::MapData) => { + changed = true; + context.inlined_apply_ids.push(arg_id); + *self = std::mem::replace( + Rc::make_mut(argument), + Term::Error.force(), + ); + } + _ => {} + } + } + } + Term::Constant(c) => match (first_function, c.as_ref()) { + ( + DefaultFunction::UnIData, + Constant::Data(PlutusData::BigInt(BigInt::Int(i))), + ) => { + changed = true; + context.inlined_apply_ids.push(arg_id); + *self = Term::integer(i128::from(*i).into()); + } + (DefaultFunction::IData, Constant::Integer(i)) => { + changed = true; + context.inlined_apply_ids.push(arg_id); + *self = Term::data(Data::integer(i.clone())); + } + (DefaultFunction::UnBData, Constant::Data(PlutusData::BoundedBytes(b))) => { + changed = true; + context.inlined_apply_ids.push(arg_id); + *self = Term::byte_string(b.clone().into()); + } + (DefaultFunction::BData, Constant::ByteString(b)) => { + changed = true; + context.inlined_apply_ids.push(arg_id); + *self = Term::data(Data::bytestring(b.clone())); + } + (DefaultFunction::UnListData, Constant::Data(PlutusData::Array(l))) => { + changed = true; + context.inlined_apply_ids.push(arg_id); + *self = Term::list_values( + l.iter() + .map(|item| Constant::Data(item.clone())) + .collect_vec(), + ); + } + (DefaultFunction::ListData, Constant::ProtoList(_, l)) => { + changed = true; + context.inlined_apply_ids.push(arg_id); + *self = Term::data(Data::list( + l.iter() + .map(|item| match item { + Constant::Data(d) => d.clone(), + _ => unreachable!(), + }) + .collect_vec(), + )); + } + (DefaultFunction::MapData, Constant::ProtoList(_, m)) => { + changed = true; + context.inlined_apply_ids.push(arg_id); + *self = Term::data(Data::map( + m.iter() + .map(|m| match m { + Constant::ProtoPair(_, _, f, s) => { + match (f.as_ref(), s.as_ref()) { + (Constant::Data(d), Constant::Data(d2)) => { + (d.clone(), d2.clone()) + } + _ => unreachable!(), + } + } + _ => unreachable!(), + }) + .collect_vec(), + )); + } + (DefaultFunction::UnMapData, Constant::Data(PlutusData::Map(m))) => { + changed = true; + context.inlined_apply_ids.push(arg_id); + *self = Term::map_values( + m.iter() + .map(|item| { + Constant::ProtoPair( + Type::Data, + Type::Data, + Constant::Data(item.0.clone()).into(), + Constant::Data(item.1.clone()).into(), + ) + }) + .collect_vec(), + ); + } + _ => {} + }, + _ => {} + } + } + Term::Constr { .. } => todo!(), + Term::Case { .. } => todo!(), + _ => {} + } + changed + } + + // Converts subtract integer with a constant to add integer with a negative constant + fn convert_arithmetic_ops( + &mut self, + _id: Option, + arg_stack: Vec, + _scope: &Scope, + context: &mut Context, + ) -> bool { + let mut changed = false; + match self { + Term::Builtin(d @ DefaultFunction::SubtractInteger) => { + if arg_stack.len() == d.arity() { + let Some(Args::Apply(apply_id, Term::Constant(_))) = arg_stack.last() else { + return false; + }; + changed = true; + context.constants_to_flip.push(*apply_id); + + *self = Term::Builtin(DefaultFunction::AddInteger); + } + } + Term::Constr { .. } => todo!(), + Term::Case { .. } => todo!(), + _ => {} + } + changed + } + + fn builtin_eval_reducer( + &mut self, + _id: Option, + mut arg_stack: Vec, + _scope: &Scope, + context: &mut Context, + ) -> bool { + let mut changed = false; + + match self { + Term::Builtin(func) => { + arg_stack = arg_stack + .into_iter() + .filter(|args| matches!(args, Args::Apply(_, _))) + .collect_vec(); + + let args = arg_stack + .iter() + .map(|args| { + let Args::Apply(_, term) = args else { + unreachable!() + }; + + term.pierce_no_inlines() + }) + .collect_vec(); + if func.can_curry_builtin() + && arg_stack.len() == func.arity() + && func.is_error_safe(&args) + { + changed = true; + let applied_term = + arg_stack + .into_iter() + .fold(Term::Builtin(*func), |acc, item| { + let Args::Apply(arg_id, arg) = item else { + unreachable!() + }; + + context.inlined_apply_ids.push(arg_id); + acc.apply(arg.pierce_no_inlines().clone()) + }); + + // Check above for is error safe + let eval_term: Term = Program { + version: (1, 0, 0), + term: applied_term, + } + .to_named_debruijn() + .unwrap() + .eval(ExBudget::max()) + .result() + .unwrap() + .try_into() + .unwrap(); + + *self = eval_term; + } + } + Term::Constr { .. } => todo!(), + Term::Case { .. } => todo!(), + _ => (), + } + changed + } + + fn remove_inlined_ids( + &mut self, + id: Option, + _arg_stack: Vec, + _scope: &Scope, + context: &mut Context, + ) { + match self { + Term::Apply { function, .. } | Term::Force(function) => { + // We inlined the arg so now we remove the application of it + let Some(id) = id else { + return; + }; + + if context.inlined_apply_ids.contains(&id) { + let func = Rc::make_mut(function); + *self = std::mem::replace(func, Term::Error.force()); + } + } + _ => (), + } + } + + fn flip_constants( + &mut self, + id: Option, + _arg_stack: Vec, + _scope: &Scope, + context: &mut Context, + ) { + if let Term::Apply { argument, .. } = self { + let Some(id) = id else { + return; + }; + + if context.constants_to_flip.contains(&id) { + let Term::Constant(c) = Rc::make_mut(argument) else { + unreachable!(); + }; + + let Constant::Integer(i) = c.as_ref() else { + unreachable!(); + }; + + *c = Constant::Integer(i.neg()).into(); } } } @@ -1007,112 +1748,72 @@ impl Program { fn traverse_uplc_with( self, inline_lambda: bool, - with: &mut impl FnMut(Option, &mut Term, Vec<(usize, Term)>, &Scope), - ) -> Self { + with: &mut impl FnMut(Option, &mut Term, Vec, &Scope, &mut Context), + ) -> (Self, Context) { let mut term = self.term; let scope = Scope { scope: vec![] }; let arg_stack = vec![]; let mut id_gen = IdGen::new(); - term.traverse_uplc_with_helper(&scope, arg_stack, &mut id_gen, with, inline_lambda); - Program { - version: self.version, - term, - } + let mut context = Context { + inlined_apply_ids: vec![], + constants_to_flip: vec![], + builtins_map: IndexMap::new(), + blst_p1_list: vec![], + blst_p2_list: vec![], + node_count: 0, + }; + + term.traverse_uplc_with_helper( + &scope, + arg_stack, + &mut id_gen, + with, + &mut context, + inline_lambda, + ); + ( + Program { + version: self.version, + term, + }, + context, + ) } + // This one runs the optimizations that are only done a single time + pub fn run_once_pass(self) -> Self { + let program = self + .traverse_uplc_with(false, &mut |id, term, _arg_stack, scope, context| { + term.inline_constr_ops(id, vec![], scope, context); + }) + .0; - pub fn lambda_reducer(self) -> Self { - let mut lambda_applied_ids = vec![]; + let (program, context) = + program.traverse_uplc_with(false, &mut |id, term, arg_stack, scope, context| { + term.bls381_compressor(id, vec![], scope, context); + term.builtin_force_reducer(id, arg_stack, scope, context); + term.remove_inlined_ids(id, vec![], scope, context); + }); - self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| { - match term { - Term::Apply { function, .. } => { - // We are applying some arg so now we unwrap the id of the applied arg - let id = id.unwrap(); - - if lambda_applied_ids.contains(&id) { - let func = Rc::make_mut(function); - // we inlined the arg so now remove the apply and arg from the program - *term = func.clone(); - } - } - Term::Lambda { - parameter_name, - body, - } => { - // pops stack here no matter what - if let Some((arg_id, arg_term)) = arg_stack.pop() { - match &arg_term { - Term::Constant(c) if matches!(c.as_ref(), Constant::String(_)) => {} - Term::Delay(e) if matches!(e.as_ref(), Term::Error) => { - let body = Rc::make_mut(body); - lambda_applied_ids.push(arg_id); - // creates new body that replaces all var occurrences with the arg - *term = substitute_var(body, parameter_name.clone(), &arg_term); - } - Term::Constant(_) | Term::Var(_) | Term::Builtin(_) => { - let body = Rc::make_mut(body); - lambda_applied_ids.push(arg_id); - // creates new body that replaces all var occurrences with the arg - *term = substitute_var(body, parameter_name.clone(), &arg_term); - } - l @ Term::Lambda { .. } => { - if is_a_builtin_wrapper(l) { - let body = Rc::make_mut(body); - lambda_applied_ids.push(arg_id); - // creates new body that replaces all var occurrences with the arg - *term = substitute_var(body, parameter_name.clone(), &arg_term); - } - } - - _ => {} - } - } - } - - Term::Case { .. } => todo!(), - Term::Constr { .. } => todo!(), - _ => {} - } - }) - } - - pub fn builtin_force_reducer(self) -> Self { - let mut builtin_map = IndexMap::new(); - - let program = self.traverse_uplc_with(true, &mut |_id, term, _arg_stack, _scope| { - if let Term::Force(f) = term { - let f = Rc::make_mut(f); - match f { - Term::Force(inner_f) => { - if let Term::Builtin(func) = inner_f.as_ref() { - builtin_map.insert(*func as u8, ()); - *term = Term::Var( - Name { - text: format!("__{}_wrapped", func.aiken_name()), - unique: 0.into(), - } - .into(), - ); - } - } - Term::Builtin(func) if func.force_count() == 1 => { - builtin_map.insert(*func as u8, ()); - *term = Term::Var( - Name { - text: format!("__{}_wrapped", func.aiken_name()), - unique: 0.into(), - } - .into(), - ); - } - _ => {} - } - } - }); let mut term = program.term; - for default_func_index in builtin_map.keys().sorted().cloned() { + for (index, blst_p1) in context.blst_p1_list.into_iter().enumerate() { + let compressed = blst_p1.compress(); + + term = term + .lambda(format!("blst_p1_index_{}", index)) + .apply(Term::bls12_381_g1_uncompress().apply(Term::byte_string(compressed))); + } + + for (index, blst_p2) in context.blst_p2_list.into_iter().enumerate() { + let compressed = blst_p2.compress(); + + term = term + .lambda(format!("blst_p2_index_{}", index)) + .apply(Term::bls12_381_g2_uncompress().apply(Term::byte_string(compressed))); + } + + for default_func_index in context.builtins_map.keys().sorted().cloned() { let default_func: DefaultFunction = default_func_index.try_into().unwrap(); term = term @@ -1138,452 +1839,67 @@ impl Program { Program::::try_from(program).unwrap() } - pub fn bls381_compressor(self) -> Self { - let mut blst_p1_list = vec![]; - let mut blst_p2_list = vec![]; + pub fn multi_pass(self) -> (Self, Context) { + self.traverse_uplc_with(true, &mut |id, term, arg_stack, scope, context| { + let mut changed; - let program = self.traverse_uplc_with(true, &mut |_id, term, _arg_stack, _scope| { - if let Term::Constant(con) = term { - match con.as_ref() { - Constant::Bls12_381G1Element(blst_p1) => { - if let Some(index) = blst_p1_list - .iter() - .position(|item| item == blst_p1.as_ref()) - { - *term = Term::var(format!("blst_p1_index_{}", index)); - } else { - blst_p1_list.push(*blst_p1.as_ref()); - *term = Term::var(format!("blst_p1_index_{}", blst_p1_list.len() - 1)); - } - } - Constant::Bls12_381G2Element(blst_p2) => { - if let Some(index) = blst_p2_list - .iter() - .position(|item| item == blst_p2.as_ref()) - { - *term = Term::var(format!("blst_p2_index_{}", index)); - } else { - blst_p2_list.push(*blst_p2.as_ref()); - *term = Term::var(format!("blst_p2_index_{}", blst_p2_list.len() - 1)); - } - } - _ => (), - } + changed = term.lambda_reducer(id, arg_stack.clone(), scope, context); + if changed { + term.remove_inlined_ids(id, vec![], scope, context); + return; } - }); - let mut term = program.term; - - for (index, blst_p1) in blst_p1_list.into_iter().enumerate() { - let compressed = blst_p1.compress(); - - term = term - .lambda(format!("blst_p1_index_{}", index)) - .apply(Term::bls12_381_g1_uncompress().apply(Term::byte_string(compressed))); - } - - for (index, blst_p2) in blst_p2_list.into_iter().enumerate() { - let compressed = blst_p2.compress(); - - term = term - .lambda(format!("blst_p2_index_{}", index)) - .apply(Term::bls12_381_g2_uncompress().apply(Term::byte_string(compressed))); - } - - let mut program = Program { - version: program.version, - term, - }; - - let mut interner = CodeGenInterner::new(); - - interner.program(&mut program); - - let program = Program::::try_from(program).unwrap(); - - Program::::try_from(program).unwrap() - } - - pub fn identity_reducer(self) -> Self { - let mut identity_applied_ids = vec![]; - self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| { - match term { - Term::Apply { function, .. } => { - // We are applying some arg so now we unwrap the id of the applied arg - let id = id.unwrap(); - - if identity_applied_ids.contains(&id) { - let func = Rc::make_mut(function); - // we inlined the arg so now remove the apply and arg from the program - *term = func.clone(); - } - } - Term::Lambda { - parameter_name, - body, - } => { - // pops stack here no matter what - - match arg_stack.pop() { - Some(( - arg_id, - Term::Lambda { - parameter_name: inline_name, - body: identity_body, - }, - )) if inline_name.text == NO_INLINE => { - if let Term::Lambda { - parameter_name: identity_name, - body: identity_body, - } = identity_body.as_ref() - { - if let Term::Var(identity_var) = identity_body.as_ref() { - if identity_var.text == identity_name.text - && identity_var.unique == identity_name.unique - { - // Replace all applied usages of identity with the arg - let temp_term = replace_identity_usage( - body.as_ref(), - parameter_name.clone(), - ); - // Have to check if the body still has any occurrences of the parameter - // After attempting replacement - if var_occurrences( - &temp_term, - parameter_name.clone(), - vec![], - vec![], - ) - .found - { - let body = Rc::make_mut(body); - *body = temp_term; - } else { - identity_applied_ids.push(arg_id); - *term = temp_term; - } - } - } - } - } - - Some(( - arg_id, - Term::Lambda { - parameter_name: identity_name, - body: identity_body, - }, - )) => { - if let Term::Var(identity_var) = identity_body.as_ref() { - if identity_var.text == identity_name.text - && identity_var.unique == identity_name.unique - { - // Replace all applied usages of identity with the arg - let temp_term = replace_identity_usage( - body.as_ref(), - parameter_name.clone(), - ); - // Have to check if the body still has any occurrences of the parameter - // After attempting replacement - if var_occurrences( - &temp_term, - parameter_name.clone(), - vec![], - vec![], - ) - .found - { - let body = Rc::make_mut(body); - *body = temp_term; - } else { - identity_applied_ids.push(arg_id); - *term = temp_term; - } - } - } - } - _ => {} - } - } - Term::Constr { .. } => todo!(), - Term::Case { .. } => todo!(), - _ => {} + changed = term.identity_reducer(id, arg_stack.clone(), scope, context); + if changed { + term.remove_inlined_ids(id, vec![], scope, context); + return; } + changed = term.inline_reducer(id, arg_stack.clone(), scope, context); + if changed { + term.remove_inlined_ids(id, vec![], scope, context); + return; + } + changed = term.force_delay_reducer(id, arg_stack.clone(), scope, context); + if changed { + term.remove_inlined_ids(id, vec![], scope, context); + return; + } + changed = term.cast_data_reducer(id, arg_stack.clone(), scope, context); + if changed { + term.remove_inlined_ids(id, vec![], scope, context); + return; + } + changed = term.builtin_eval_reducer(id, arg_stack.clone(), scope, context); + if changed { + term.remove_inlined_ids(id, vec![], scope, context); + return; + } + term.convert_arithmetic_ops(id, arg_stack, scope, context); + term.flip_constants(id, vec![], scope, context); + term.remove_inlined_ids(id, vec![], scope, context); }) } - pub fn inline_reducer(self) -> Self { - let mut lambda_applied_ids = vec![]; - - self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| match term { - Term::Apply { function, .. } => { - // We are applying some arg so now we unwrap the id of the applied arg - let id = id.unwrap(); - - if lambda_applied_ids.contains(&id) { - // we inlined the arg so now remove the apply and arg from the program - *term = function.as_ref().clone(); - } - } - Term::Lambda { - parameter_name, - body, - } => { - // pops stack here no matter what - if let Some((arg_id, arg_term)) = arg_stack.pop() { - let arg_term = match &arg_term { - Term::Lambda { - parameter_name, - body, - } if parameter_name.text == NO_INLINE => body.as_ref().clone(), - _ => arg_term, - }; - - let body = Rc::make_mut(body); - let var_lookup = var_occurrences(body, parameter_name.clone(), vec![], vec![]); - - let substitute_condition = (var_lookup.delays == 0 && !var_lookup.no_inline) - || matches!( - &arg_term, - Term::Var(_) - | Term::Constant(_) - | Term::Delay(_) - | Term::Lambda { .. } - | Term::Builtin(_), - ); - - if var_lookup.occurrences == 1 && substitute_condition { - *body = substitute_var(body, parameter_name.clone(), &arg_term); - - lambda_applied_ids.push(arg_id); - *term = body.clone(); - - // This will strip out unused terms that can't throw an error by themselves - } else if !var_lookup.found - && matches!( - arg_term, - Term::Var(_) - | Term::Constant(_) - | Term::Delay(_) - | Term::Lambda { .. } - | Term::Builtin(_) - ) - { - lambda_applied_ids.push(arg_id); - *term = body.clone(); - } - } - } - Term::Constr { .. } => todo!(), - Term::Case { .. } => todo!(), - _ => {} + pub fn run_one_opt( + self, + inline_lambda: bool, + with: &mut impl FnMut(Option, &mut Term, Vec, &Scope, &mut Context), + ) -> Self { + self.traverse_uplc_with(inline_lambda, &mut |id, term, arg_stack, scope, context| { + with(id, term, arg_stack, scope, context); + term.flip_constants(id, vec![], scope, context); + term.remove_inlined_ids(id, vec![], scope, context); }) + .0 } - pub fn force_delay_reducer(self) -> Self { - self.traverse_uplc_with(true, &mut |_id, term, _arg_stack, _scope| { - if let Term::Force(f) = term { - let f = f.as_ref(); - - if let Term::Delay(body) = f { - *term = body.as_ref().clone(); - } - } - }) - } - - pub fn remove_no_inlines(self) -> Self { - self.traverse_uplc_with(true, &mut |_, term, _, _| match term { - Term::Lambda { - parameter_name, - body, - } if parameter_name.text == NO_INLINE => *term = body.as_ref().clone(), - _ => {} - }) - } - - pub fn inline_constr_ops(self) -> Self { - self.traverse_uplc_with(true, &mut |_, term, _, _| { - if let Term::Apply { function, argument } = term { - if let Term::Var(name) = function.as_ref() { - if name.text == CONSTR_FIELDS_EXPOSER { - *term = Term::snd_pair().apply(Term::Apply { - function: Term::unconstr_data().into(), - argument: argument.clone(), - }) - } else if name.text == CONSTR_INDEX_EXPOSER { - *term = Term::fst_pair().apply(Term::Apply { - function: Term::unconstr_data().into(), - argument: argument.clone(), - }) - } - } - } - }) - } - - pub fn cast_data_reducer(self) -> Self { - let mut applied_ids = vec![]; - - self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| { - match term { - Term::Apply { function, .. } => { - // We are apply some arg so now we unwrap the id of the applied arg - let id = id.unwrap(); - - if applied_ids.contains(&id) { - let func = Rc::make_mut(function); - // we inlined the arg so now remove the apply and arg from the program - *term = func.clone(); - } - } - - Term::Builtin(first_function) => { - let Some((arg_id, arg_term)) = arg_stack.pop() else { - return; - }; - - match arg_term { - Term::Apply { function, argument } => { - if let Term::Builtin(second_function) = function.as_ref() { - match (first_function, second_function) { - (DefaultFunction::UnIData, DefaultFunction::IData) - | (DefaultFunction::IData, DefaultFunction::UnIData) - | (DefaultFunction::BData, DefaultFunction::UnBData) - | (DefaultFunction::UnBData, DefaultFunction::BData) - | (DefaultFunction::ListData, DefaultFunction::UnListData) - | (DefaultFunction::UnListData, DefaultFunction::ListData) - | (DefaultFunction::MapData, DefaultFunction::UnMapData) - | (DefaultFunction::UnMapData, DefaultFunction::MapData) => { - applied_ids.push(arg_id); - *term = argument.as_ref().clone(); - } - _ => {} - } - } - } - Term::Constant(c) => match (first_function, c.as_ref()) { - ( - DefaultFunction::UnIData, - Constant::Data(PlutusData::BigInt(BigInt::Int(i))), - ) => { - applied_ids.push(arg_id); - *term = Term::integer(i128::from(*i).into()); - } - (DefaultFunction::IData, Constant::Integer(i)) => { - applied_ids.push(arg_id); - *term = Term::data(Data::integer(i.clone())); - } - ( - DefaultFunction::UnBData, - Constant::Data(PlutusData::BoundedBytes(b)), - ) => { - applied_ids.push(arg_id); - *term = Term::byte_string(b.clone().into()); - } - (DefaultFunction::BData, Constant::ByteString(b)) => { - applied_ids.push(arg_id); - *term = Term::data(Data::bytestring(b.clone())); - } - (DefaultFunction::UnListData, Constant::Data(PlutusData::Array(l))) => { - applied_ids.push(arg_id); - *term = Term::list_values( - l.iter() - .map(|item| Constant::Data(item.clone())) - .collect_vec(), - ); - } - (DefaultFunction::ListData, Constant::ProtoList(_, l)) => { - applied_ids.push(arg_id); - *term = Term::data(Data::list( - l.iter() - .map(|item| match item { - Constant::Data(d) => d.clone(), - _ => unreachable!(), - }) - .collect_vec(), - )); - } - (DefaultFunction::MapData, Constant::ProtoList(_, m)) => { - applied_ids.push(arg_id); - *term = Term::data(Data::map( - m.iter() - .map(|m| match m { - Constant::ProtoPair(_, _, f, s) => { - match (f.as_ref(), s.as_ref()) { - (Constant::Data(d), Constant::Data(d2)) => { - (d.clone(), d2.clone()) - } - _ => unreachable!(), - } - } - _ => unreachable!(), - }) - .collect_vec(), - )); - } - (DefaultFunction::UnMapData, Constant::Data(PlutusData::Map(m))) => { - applied_ids.push(arg_id); - *term = Term::map_values( - m.iter() - .map(|item| { - Constant::ProtoPair( - Type::Data, - Type::Data, - Constant::Data(item.0.clone()).into(), - Constant::Data(item.1.clone()).into(), - ) - }) - .collect_vec(), - ); - } - _ => {} - }, - _ => {} - } - } - Term::Constr { .. } => todo!(), - Term::Case { .. } => todo!(), - _ => {} - } - }) - } - - // Converts subtract integer with a constant to add integer with a negative constant - pub fn convert_arithmetic_ops(self) -> Self { - let mut constants_to_flip = vec![]; - - self.traverse_uplc_with(true, &mut |id, term, arg_stack, _scope| match term { - Term::Apply { argument, .. } => { - let id = id.unwrap(); - - if constants_to_flip.contains(&id) { - let Term::Constant(c) = Rc::make_mut(argument) else { - unreachable!(); - }; - - let Constant::Integer(i) = c.as_ref() else { - unreachable!(); - }; - - *c = Constant::Integer(i.neg()).into(); - } - } - Term::Builtin(d @ DefaultFunction::SubtractInteger) => { - if arg_stack.len() == d.arity() { - let Some((apply_id, Term::Constant(_))) = arg_stack.last() else { - return; - }; - constants_to_flip.push(*apply_id); - - *term = Term::Builtin(DefaultFunction::AddInteger); - } - } - Term::Constr { .. } => todo!(), - Term::Case { .. } => todo!(), - _ => {} + pub fn clean_up(self) -> Self { + self.traverse_uplc_with(true, &mut |id, term, _arg_stack, scope, context| { + term.remove_no_inlines(id, vec![], scope, context); }) + .0 } + // This one doesn't use the context since it's complicated and traverses the ast twice pub fn builtin_curry_reducer(self) -> Self { let mut curried_terms = vec![]; let mut id_mapped_curry_terms: IndexMap, usize)> = @@ -1596,10 +1912,20 @@ impl Program { let mut final_ids: IndexMap, ()> = IndexMap::new(); - let step_a = - self.traverse_uplc_with(false, &mut |_id, term, arg_stack, scope| match term { + let (step_a, _) = self.traverse_uplc_with( + false, + &mut |_id, term, arg_stack, scope, _context| match term { Term::Builtin(func) => { if func.can_curry_builtin() && arg_stack.len() == func.arity() { + let arg_stack = arg_stack + .into_iter() + .map(|item| { + let Args::Apply(arg_id, arg) = item else { + unreachable!() + }; + (arg_id, arg) + }) + .collect_vec(); // In the case of order agnostic builtins we want to sort the args by constant first // This gives us the opportunity to curry constants that often pop up in the code @@ -1679,7 +2005,8 @@ impl Program { Term::Constr { .. } => todo!(), Term::Case { .. } => todo!(), _ => {} - }); + }, + ); id_mapped_curry_terms .into_iter() @@ -1703,10 +2030,21 @@ impl Program { } }); - let mut step_b = - step_a.traverse_uplc_with(false, &mut |id, term, mut arg_stack, scope| match term { + let (mut step_b, _) = step_a.traverse_uplc_with( + false, + &mut |id, term, arg_stack, scope, _context| match term { Term::Builtin(func) => { if func.can_curry_builtin() && arg_stack.len() == func.arity() { + let mut arg_stack = arg_stack + .into_iter() + .map(|item| { + let Args::Apply(arg_id, arg) = item else { + unreachable!() + }; + (arg_id, arg) + }) + .collect_vec(); + let Some(curried_builtin) = curried_terms.iter().find(|curry| curry.func == *func) else { @@ -1759,7 +2097,9 @@ impl Program { for (key, val) in insert_list.into_iter().rev() { let name = id_vec_function_to_var(&key.func_name, &key.id_vec); - if var_occurrences(term, Name::text(&name).into(), vec![], vec![]).found + if term + .var_occurrences(Name::text(&name).into(), vec![], vec![]) + .found { *term = term.clone().lambda(name).apply(val); } @@ -1773,14 +2113,17 @@ impl Program { for (key, val) in insert_list.into_iter().rev() { let name = id_vec_function_to_var(&key.func_name, &key.id_vec); - if var_occurrences(term, Name::text(&name).into(), vec![], vec![]).found + if term + .var_occurrences(Name::text(&name).into(), vec![], vec![]) + .found { *term = term.clone().lambda(name).apply(val); } } } } - }); + }, + ); let mut interner = CodeGenInterner::new(); @@ -1788,56 +2131,6 @@ impl Program { step_b } - - pub fn builtin_eval_reducer(self) -> Self { - let mut applied_ids = vec![]; - - self.traverse_uplc_with(false, &mut |id, term, arg_stack, _scope| match term { - Term::Builtin(func) => { - let args = arg_stack - .iter() - .map(|(_, term)| term.pierce_no_inlines()) - .collect_vec(); - if func.can_curry_builtin() - && arg_stack.len() == func.arity() - && func.is_error_safe(&args) - { - let applied_term = - arg_stack - .into_iter() - .fold(Term::Builtin(*func), |acc, item| { - applied_ids.push(item.0); - acc.apply(item.1.pierce_no_inlines().clone()) - }); - - // Check above for is error safe - let eval_term: Term = Program { - version: (1, 0, 0), - term: applied_term, - } - .to_named_debruijn() - .unwrap() - .eval(ExBudget::max()) - .result() - .unwrap() - .try_into() - .unwrap(); - - *term = eval_term; - } - } - Term::Apply { function, .. } => { - let id = id.unwrap(); - - if applied_ids.contains(&id) { - *term = function.as_ref().clone(); - } - } - Term::Constr { .. } => todo!(), - Term::Case { .. } => todo!(), - _ => {} - }) - } } fn id_vec_function_to_var(func_name: &str, id_vec: &[usize]) -> String { @@ -1852,153 +2145,6 @@ fn id_vec_function_to_var(func_name: &str, id_vec: &[usize]) -> String { ) } -fn var_occurrences( - term: &Term, - search_for: Rc, - mut arg_stack: Vec<()>, - mut force_stack: Vec<()>, -) -> VarLookup { - match term { - Term::Var(name) => { - if name.text == search_for.text && name.unique == search_for.unique { - VarLookup::new_found() - } else { - VarLookup::new() - } - } - Term::Delay(body) => { - let not_forced: isize = isize::from(force_stack.pop().is_none()); - - var_occurrences(body, search_for, arg_stack, force_stack).delay_if_found(not_forced) - } - Term::Lambda { - parameter_name, - body, - } => { - if parameter_name.text == NO_INLINE { - var_occurrences(body.as_ref(), search_for, arg_stack, force_stack) - .no_inline_if_found() - } else if parameter_name.text == search_for.text - && parameter_name.unique == search_for.unique - { - VarLookup::new() - } else { - let not_applied: isize = isize::from(arg_stack.pop().is_none()); - var_occurrences(body.as_ref(), search_for, arg_stack, force_stack) - .delay_if_found(not_applied) - } - } - Term::Apply { function, argument } => { - arg_stack.push(()); - - var_occurrences( - function.as_ref(), - search_for.clone(), - arg_stack, - force_stack, - ) - .combine(var_occurrences( - argument.as_ref(), - search_for, - vec![], - vec![], - )) - } - Term::Force(x) => { - force_stack.push(()); - var_occurrences(x.as_ref(), search_for, arg_stack, force_stack) - } - Term::Case { .. } => todo!(), - Term::Constr { .. } => todo!(), - _ => VarLookup::new(), - } -} - -fn substitute_var(term: &Term, original: Rc, replace_with: &Term) -> Term { - match term { - Term::Var(name) => { - if name.text == original.text && name.unique == original.unique { - replace_with.clone() - } else { - Term::Var(name.clone()) - } - } - Term::Delay(body) => { - Term::Delay(substitute_var(body.as_ref(), original, replace_with).into()) - } - Term::Lambda { - parameter_name, - body, - } => { - if parameter_name.text == original.text && parameter_name.unique == original.unique { - Term::Lambda { - parameter_name: parameter_name.clone(), - body: body.clone(), - } - } else { - Term::Lambda { - parameter_name: parameter_name.clone(), - body: substitute_var(body.as_ref(), original, replace_with).into(), - } - } - } - Term::Apply { function, argument } => Term::Apply { - function: substitute_var(function.as_ref(), original.clone(), replace_with).into(), - argument: substitute_var(argument.as_ref(), original, replace_with).into(), - }, - Term::Force(f) => Term::Force(substitute_var(f.as_ref(), original, replace_with).into()), - Term::Case { .. } => todo!(), - Term::Constr { .. } => todo!(), - x => x.clone(), - } -} - -fn replace_identity_usage(term: &Term, original: Rc) -> Term { - match term { - Term::Delay(body) => Term::Delay(replace_identity_usage(body.as_ref(), original).into()), - Term::Lambda { - parameter_name, - body, - } => { - if parameter_name.text == original.text && parameter_name.unique == original.unique { - Term::Lambda { - parameter_name: parameter_name.clone(), - body: body.clone(), - } - } else { - Term::Lambda { - parameter_name: parameter_name.clone(), - body: Rc::new(replace_identity_usage(body.as_ref(), original)), - } - } - } - Term::Apply { function, argument } => { - let func = replace_identity_usage(function.as_ref(), original.clone()); - let arg = replace_identity_usage(argument.as_ref(), original.clone()); - - let Term::Var(name) = &func else { - return Term::Apply { - function: func.into(), - argument: arg.into(), - }; - }; - - if name.text == original.text && name.unique == original.unique { - arg - } else { - Term::Apply { - function: func.into(), - argument: arg.into(), - } - } - } - Term::Force(f) => Term::Force(Rc::new(replace_identity_usage(f.as_ref(), original))), - Term::Case { .. } => todo!(), - Term::Constr { .. } => todo!(), - x => x.clone(), - } -} - fn is_a_builtin_wrapper(term: &Term) -> bool { let (names, term) = pop_lambdas_and_get_names(term); @@ -2097,7 +2243,11 @@ mod tests { ), }; - compare_optimization(expected, program, |p| p.lambda_reducer()); + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.lambda_reducer(id, arg_stack, scope, context); + }) + }); } #[test] @@ -2114,7 +2264,11 @@ mod tests { term: Term::integer(6.into()), }; - compare_optimization(expected, program, |p| p.lambda_reducer()); + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.lambda_reducer(id, arg_stack, scope, context); + }) + }); } #[test] @@ -2129,7 +2283,11 @@ mod tests { term: Term::add_integer(), }; - compare_optimization(expected, program, |p| p.lambda_reducer()); + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.lambda_reducer(id, arg_stack, scope, context); + }) + }); } #[test] @@ -2166,7 +2324,11 @@ mod tests { .apply(Term::bool(false).lambda("x")), }; - compare_optimization(expected, program, |p| p.lambda_reducer()); + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.lambda_reducer(id, arg_stack, scope, context); + }) + }); } #[test] @@ -2199,7 +2361,7 @@ mod tests { .apply(Term::tail_list()), }; - compare_optimization(expected, program, |p| p.builtin_force_reducer()); + compare_optimization(expected, program, |p| p.run_once_pass()); } #[test] @@ -2244,7 +2406,7 @@ mod tests { .apply(Term::Builtin(DefaultFunction::IfThenElse).force()), }; - compare_optimization(expected, program, |p| p.builtin_force_reducer()); + compare_optimization(expected, program, |p| p.run_once_pass()); } #[test] @@ -2285,7 +2447,7 @@ mod tests { .apply(Term::snd_pair()), }; - compare_optimization(expected, program, |p| p.builtin_force_reducer()); + compare_optimization(expected, program, |p| p.run_once_pass()); } #[test] @@ -2308,7 +2470,11 @@ mod tests { .apply(Term::byte_string(vec![]).delay()), }; - compare_optimization(expected, program, |p| p.identity_reducer()); + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.identity_reducer(id, arg_stack, scope, context); + }) + }); } #[test] @@ -2331,7 +2497,11 @@ mod tests { .apply(Term::byte_string(vec![]).delay()), }; - compare_optimization(expected, program, |p| p.identity_reducer()); + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.identity_reducer(id, arg_stack, scope, context); + }) + }); } #[test] @@ -2378,7 +2548,11 @@ mod tests { ), }; - compare_optimization(expected, program, |p| p.identity_reducer()); + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.identity_reducer(id, arg_stack, scope, context); + }) + }); } #[test] @@ -2401,7 +2575,11 @@ mod tests { .apply(Term::byte_string(vec![]).delay()), }; - compare_optimization(expected, program, |p| p.identity_reducer()); + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.identity_reducer(id, arg_stack, scope, context); + }) + }); } #[test] @@ -2426,7 +2604,11 @@ mod tests { .lambda(NO_INLINE), }; - compare_optimization(expected, program, |p| p.identity_reducer()); + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.identity_reducer(id, arg_stack, scope, context); + }) + }); } #[test] @@ -2444,7 +2626,11 @@ mod tests { term: Term::sha2_256().apply(Term::byte_string(vec![]).delay()), }; - compare_optimization(expected, program, |p| p.inline_reducer()); + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.inline_reducer(id, arg_stack, scope, context); + }) + }); } #[test] @@ -2461,7 +2647,11 @@ mod tests { term: Term::sha2_256(), }; - compare_optimization(expected, program, |p| p.inline_reducer()); + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.inline_reducer(id, arg_stack, scope, context); + }) + }); } #[test] @@ -2486,7 +2676,11 @@ mod tests { .lambda("x"), }; - compare_optimization(expected, program, |p| p.cast_data_reducer()); + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.cast_data_reducer(id, arg_stack, scope, context); + }) + }); } #[test] @@ -2509,7 +2703,11 @@ mod tests { .lambda("x"), }; - compare_optimization(expected, program, |p| p.cast_data_reducer()); + compare_optimization(expected, program, |p| { + p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| { + term.cast_data_reducer(id, arg_stack, scope, context); + }) + }); } #[test]