diff --git a/crates/uplc/src/optimize/shrinker.rs b/crates/uplc/src/optimize/shrinker.rs index 18ee7029..8aa95e74 100644 --- a/crates/uplc/src/optimize/shrinker.rs +++ b/crates/uplc/src/optimize/shrinker.rs @@ -10,6 +10,7 @@ use indexmap::IndexMap; use itertools::{FoldWhile, Itertools}; use pallas_primitives::conway::{BigInt, PlutusData}; use std::{cmp::Ordering, iter, ops::Neg, rc::Rc}; +use strum::IntoEnumIterator; #[derive(Eq, Hash, PartialEq, Clone, Debug, PartialOrd)] pub enum ScopePath { @@ -348,6 +349,12 @@ impl DefaultFunction { format!("__{}_wrapped", self.aiken_name()) } } +pub fn forceable_wrapped_names() -> Vec { + DefaultFunction::iter() + .filter(|df| df.force_count() > 0) + .map(|df| df.wrapped_name()) + .collect_vec() +} #[derive(PartialEq, Clone, Debug)] pub enum BuiltinArgs { @@ -1316,7 +1323,7 @@ impl Term { // So it costs more size to have them hoisted Term::Delay(e) if matches!(e.as_ref(), Term::Error) => true, // If it wraps a builtin with consts or arguments passed in then inline - Term::Lambda { .. } => arg_term.is_a_builtin_wrapper(context), + Term::Lambda { .. } => arg_term.is_a_builtin_wrapper(), // Inline smaller terms too Term::Constant(_) | Term::Var(_) | Term::Builtin(_) => true, @@ -1430,7 +1437,7 @@ impl Term { arg.split_body_lambda(); - arg_stack.push(std::mem::replace(arg, Term::Error.force())); + arg_stack.push(Args::Apply(0, std::mem::replace(arg, Term::Error.force()))); } Term::Lambda { parameter_name, @@ -1438,7 +1445,7 @@ impl Term { } => { current_term = Rc::make_mut(body); - if let Some(arg) = arg_stack.pop() { + if let Some(Args::Apply(_, arg)) = arg_stack.pop() { let names = arg.get_var_names(); let func = (parameter_name.clone(), arg); @@ -1467,7 +1474,12 @@ impl Term { unsat_lams.push(parameter_name.clone()); } } - Term::Delay(term) | Term::Force(term) => { + Term::Force(term) => { + current_term = Rc::make_mut(term); + + arg_stack.push(Args::Force(0)); + } + Term::Delay(term) => { Rc::make_mut(term).split_body_lambda(); break; } @@ -1522,7 +1534,10 @@ impl Term { // Replace args that weren't consumed let term = arg_stack .into_iter() - .rfold(term_to_build_on, |term, arg| term.apply(arg)); + .rfold(term_to_build_on, |term, arg| match arg { + Args::Force(_) => term.force(), + Args::Apply(_, arg) => term.apply(arg), + }); let term = function_groups.into_iter().rfold(term, |term, group| { let term = group.iter().rfold(term, |term, (name, _)| Term::Lambda { @@ -2212,7 +2227,7 @@ impl Term { std::mem::replace(term, Term::Error.force()) } - fn is_a_builtin_wrapper(&self, context: &Context) -> bool { + fn is_a_builtin_wrapper(&self) -> bool { let (names, term) = self.pop_lambdas_and_get_names(); let mut arg_names = vec![]; @@ -2233,12 +2248,7 @@ impl Term { } let func_is_builtin = match term { - Term::Var(name) => context - .builtins_map - .keys() - .map(|func| func.wrapped_name()) - .any(|func| func == name.text), - + Term::Var(name) => forceable_wrapped_names().contains(&name.text), Term::Builtin(_) => true, _ => false, }; @@ -2508,7 +2518,7 @@ impl Program { id_vec } else { - // Brand new buitlin so we add it to the list + // Brand new builtin so we add it to the list let curried_builtin = builtin_args.clone().args_to_curried_args(*func); let Some(id_vec) = curried_builtin.get_id_args(builtin_args) else {