diff --git a/crates/uplc/src/optimize/shrinker.rs b/crates/uplc/src/optimize/shrinker.rs index e0a754cc..4aa13546 100644 --- a/crates/uplc/src/optimize/shrinker.rs +++ b/crates/uplc/src/optimize/shrinker.rs @@ -35,17 +35,17 @@ impl Scope { pub fn push(&self, path: ScopePath) -> Self { let mut new_scope = self.scope.clone(); new_scope.push(path); - Scope { scope: new_scope } + Self { scope: new_scope } } pub fn pop(&self) -> Self { let mut new_scope = self.scope.clone(); new_scope.pop(); - Scope { scope: new_scope } + Self { scope: new_scope } } pub fn common_ancestor(&self, other: &Scope) -> Self { - Scope { + Self { scope: self .scope .iter() @@ -90,6 +90,51 @@ impl Default for IdGen { } } +#[derive(PartialEq, PartialOrd, Default, Debug, Clone)] +pub struct VarLookup { + found: bool, + occurrences: isize, + delays: isize, +} + +impl VarLookup { + pub fn new() -> Self { + Self { + found: false, + occurrences: 0, + delays: 0, + } + } + + pub fn new_found() -> Self { + Self { + found: true, + occurrences: 1, + delays: 0, + } + } + + pub fn combine(self, other: Self) -> Self { + Self { + found: self.found || other.found, + occurrences: self.occurrences + other.occurrences, + delays: self.delays + other.delays, + } + } + + pub fn delay_if_found(self, delay_amount: isize) -> Self { + if self.found { + Self { + found: self.found, + occurrences: self.occurrences, + delays: self.delays + delay_amount, + } + } else { + self + } + } +} + impl DefaultFunction { pub fn is_order_agnostic_builtin(self) -> bool { matches!( @@ -606,7 +651,7 @@ pub struct CurriedBuiltin { impl CurriedBuiltin { pub fn merge_node_by_path(self, path: BuiltinArgs) -> Self { - CurriedBuiltin { + Self { func: self.func, args: self.args.merge_node_by_path(path), } @@ -808,25 +853,7 @@ impl Program { let mut lambda_applied_ids = vec![]; let mut identity_applied_ids = vec![]; // TODO: Remove extra traversals - self.traverse_uplc_with(&mut |_id, term, _arg_stack, _scope| { - // Since this one just inlines single occurrences. It's probably not needed - if let Term::Apply { function, argument } = term { - let func = Rc::make_mut(function); - - if let Term::Lambda { - parameter_name, - body, - } = func - { - if let Term::Var(name) = body.as_ref() { - if name.as_ref() == parameter_name.as_ref() { - *term = argument.as_ref().clone(); - } - } - } - } - }) - .traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| { + self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| { match term { Term::Apply { function, .. } => { // We are applying some arg so now we unwrap the id of the applied arg @@ -858,7 +885,7 @@ impl Program { replace_identity_usage(body.as_ref(), parameter_name.clone()); // Have to check if the body still has any occurrences of the parameter // After attempting replacement - if var_occurrences(body.as_ref(), parameter_name.clone()) > 0 { + if var_occurrences(body.as_ref(), parameter_name.clone()).found { let body = Rc::make_mut(body); *body = temp_term; } else { @@ -881,9 +908,8 @@ impl Program { let id = id.unwrap(); if lambda_applied_ids.contains(&id) { - let func = Rc::make_mut(function); // we inlined the arg so now remove the apply and arg from the program - *term = func.clone(); + *term = function.as_ref().clone(); } } Term::Lambda { @@ -893,11 +919,10 @@ impl Program { // pops stack here no matter what if let Some((arg_id, arg_term)) = arg_stack.pop() { let body = Rc::make_mut(body); - let occurrences = var_occurrences(body, parameter_name.clone()); - let delays = delayed_execution(body); + let var_lookup = var_occurrences(body, parameter_name.clone()); - if occurrences == 1 - && (delays == 0 + if var_lookup.occurrences == 1 + && (var_lookup.delays == 0 || matches!( &arg_term, Term::Var(_) @@ -913,7 +938,7 @@ impl Program { *term = body.clone(); // This will strip out unused terms that can't throw an error by themselves - } else if occurrences == 0 + } else if !var_lookup.found && matches!( arg_term, Term::Var(_) @@ -1082,7 +1107,7 @@ impl Program { IndexMap::new(); let mut final_ids: IndexMap, ()> = IndexMap::new(); - let a = self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term { + let step_a = self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term { Term::Builtin(func) => { if func.can_curry_builtin() && arg_stack.len() == func.arity() { let is_order_agnostic = func.is_order_agnostic_builtin(); @@ -1191,7 +1216,7 @@ impl Program { } }); - let mut b = a.traverse_uplc_with(&mut |id, term, arg_stack, scope| match term { + let mut step_b = step_a.traverse_uplc_with(&mut |id, term, arg_stack, scope| match term { Term::Builtin(func) => { if func.can_curry_builtin() { let Some(curried_builtin) = @@ -1255,7 +1280,9 @@ impl Program { .join("_") ); - *term = term.clone().lambda(name).apply(val); + if var_occurrences(term, Name::text(&name).into()).found { + *term = term.clone().lambda(name).apply(val); + } } } } @@ -1273,7 +1300,9 @@ impl Program { .join("_") ); - *term = term.clone().lambda(name).apply(val); + if var_occurrences(term, Name::text(&name).into()).found { + *term = term.clone().lambda(name).apply(val); + } } } } @@ -1281,68 +1310,49 @@ impl Program { let mut interner = Interner::new(); - interner.program(&mut b); + interner.program(&mut step_b); - b + step_b } } -fn var_occurrences(term: &Term, search_for: Rc) -> usize { +fn var_occurrences(term: &Term, search_for: Rc) -> VarLookup { match term { Term::Var(name) => { - if name.as_ref() == search_for.as_ref() { - 1 + if name.text == search_for.text && name.unique == search_for.unique { + VarLookup::new_found() } else { - 0 + VarLookup::new() } } - Term::Delay(body) => var_occurrences(body.as_ref(), search_for), + Term::Delay(body) => var_occurrences(body.as_ref(), search_for).delay_if_found(1), Term::Lambda { parameter_name, body, } => { - if parameter_name.clone() != search_for { - var_occurrences(body.as_ref(), search_for) + if parameter_name.text != search_for.text || parameter_name.unique != search_for.unique + { + var_occurrences(body.as_ref(), search_for).delay_if_found(1) } else { - 0 + VarLookup::new() } } Term::Apply { function, argument } => { var_occurrences(function.as_ref(), search_for.clone()) - + var_occurrences(argument.as_ref(), search_for) + .delay_if_found(-1) + .combine(var_occurrences(argument.as_ref(), search_for)) } - Term::Force(x) => var_occurrences(x.as_ref(), search_for), + Term::Force(x) => var_occurrences(x.as_ref(), search_for).delay_if_found(-1), Term::Case { .. } => todo!(), Term::Constr { .. } => todo!(), - _ => 0, - } -} - -fn delayed_execution(term: &Term) -> usize { - match term { - Term::Delay(body) => 1 + delayed_execution(body.as_ref()), - Term::Lambda { body, .. } => 1 + delayed_execution(body.as_ref()), - Term::Apply { function, argument } => { - delayed_execution(function.as_ref()) + delayed_execution(argument.as_ref()) - } - Term::Force(x) => delayed_execution(x.as_ref()), - Term::Case { constr, branches } => { - 1 + delayed_execution(constr.as_ref()) - + branches - .iter() - .fold(0, |acc, branch| acc + delayed_execution(branch)) - } - Term::Constr { fields, .. } => fields - .iter() - .fold(0, |acc, field| acc + delayed_execution(field)), - _ => 0, + _ => VarLookup::new(), } } fn substitute_var(term: &Term, original: Rc, replace_with: &Term) -> Term { match term { Term::Var(name) => { - if name.as_ref() == original.as_ref() { + if name.text == original.text && name.unique == original.unique { replace_with.clone() } else { Term::Var(name.clone()) @@ -1355,10 +1365,10 @@ fn substitute_var(term: &Term, original: Rc, replace_with: &Term { - if parameter_name.as_ref() != original.as_ref() { + if parameter_name.text != original.text || parameter_name.unique != original.unique { Term::Lambda { parameter_name: parameter_name.clone(), - body: Rc::new(substitute_var(body.as_ref(), original, replace_with)), + body: substitute_var(body.as_ref(), original, replace_with).into(), } } else { Term::Lambda { @@ -1368,14 +1378,10 @@ fn substitute_var(term: &Term, original: Rc, replace_with: &Term Term::Apply { - function: Rc::new(substitute_var( - function.as_ref(), - original.clone(), - replace_with, - )), - argument: Rc::new(substitute_var(argument.as_ref(), original, replace_with)), + function: substitute_var(function.as_ref(), original.clone(), replace_with).into(), + argument: substitute_var(argument.as_ref(), original, replace_with).into(), }, - Term::Force(f) => Term::Force(Rc::new(substitute_var(f.as_ref(), original, replace_with))), + Term::Force(f) => Term::Force(substitute_var(f.as_ref(), original, replace_with).into()), Term::Case { .. } => todo!(), Term::Constr { .. } => todo!(), x => x.clone(), @@ -1389,7 +1395,7 @@ fn replace_identity_usage(term: &Term, original: Rc) -> Term { parameter_name, body, } => { - if parameter_name.as_ref() != original.as_ref() { + if parameter_name.text != original.text || parameter_name.unique != original.unique { Term::Lambda { parameter_name: parameter_name.clone(), body: Rc::new(replace_identity_usage(body.as_ref(), original)), @@ -1412,7 +1418,7 @@ fn replace_identity_usage(term: &Term, original: Rc) -> Term { }; }; - if name.as_ref() == original.as_ref() { + if name.text == original.text && name.unique == original.unique { arg } else { Term::Apply {