diff --git a/crates/uplc/src/optimize/shrinker.rs b/crates/uplc/src/optimize/shrinker.rs index ec29a7a3..d69e5160 100644 --- a/crates/uplc/src/optimize/shrinker.rs +++ b/crates/uplc/src/optimize/shrinker.rs @@ -708,6 +708,7 @@ impl Term { mut arg_stack: Vec<(usize, Term)>, id_gen: &mut IdGen, with: &mut impl FnMut(Option, &mut Term, Vec<(usize, Term)>, &Scope), + inline_lambda: bool, ) { match self { Term::Apply { function, argument } => { @@ -719,6 +720,7 @@ impl Term { vec![], id_gen, with, + inline_lambda, ); let apply_id = id_gen.next_id(); @@ -732,6 +734,7 @@ impl Term { arg_stack, id_gen, with, + inline_lambda, ); scope.pop(); @@ -741,31 +744,72 @@ impl Term { Term::Delay(d) => { let d = Rc::make_mut(d); // First we recurse further to reduce the inner terms before coming back up to the Delay - Self::traverse_uplc_with_helper(d, scope, arg_stack, id_gen, with); + Self::traverse_uplc_with_helper(d, scope, arg_stack, id_gen, with, inline_lambda); with(None, self, vec![], scope); } Term::Lambda { + parameter_name: p, body, - parameter_name, } => { - let body = Rc::make_mut(body); + let p = p.as_ref().clone(); // Lambda pops one item off the arg stack. If there is no item then it is a unsaturated lambda // We also skip NO_INLINE lambdas since those are placeholder lambdas created by codegen - - let args = if parameter_name.text == NO_INLINE { + let args = if p.text == NO_INLINE { vec![] } else { arg_stack.pop().map(|arg| vec![arg]).unwrap_or_default() }; - // Pass in either one or zero args. - Self::traverse_uplc_with_helper(body, scope, arg_stack, id_gen, with); - with(None, self, args, scope); + if inline_lambda { + // Pass in either one or zero args. + // For lambda we run the function with first then recurse on the body or replaced term + with(None, self, args, scope); + + match self { + Term::Lambda { + parameter_name, + body, + } if parameter_name.as_ref() == &p => { + let body = Rc::make_mut(body); + Self::traverse_uplc_with_helper( + body, + scope, + arg_stack, + id_gen, + with, + inline_lambda, + ); + } + + Term::Constr { .. } => todo!(), + Term::Case { .. } => todo!(), + other => Self::traverse_uplc_with_helper( + other, + scope, + arg_stack, + id_gen, + with, + inline_lambda, + ), + } + } else { + let body = Rc::make_mut(body); + + Self::traverse_uplc_with_helper( + body, + scope, + arg_stack, + id_gen, + with, + inline_lambda, + ); + with(None, self, vec![], scope); + } } Term::Force(f) => { let f = Rc::make_mut(f); - Self::traverse_uplc_with_helper(f, scope, arg_stack, id_gen, with); + Self::traverse_uplc_with_helper(f, scope, arg_stack, id_gen, with, inline_lambda); with(None, self, vec![], scope); } Term::Case { .. } => todo!(), @@ -792,6 +836,7 @@ impl Term { impl Program { fn traverse_uplc_with( self, + inline_lambda: bool, with: &mut impl FnMut(Option, &mut Term, Vec<(usize, Term)>, &Scope), ) -> Self { let mut term = self.term; @@ -799,7 +844,7 @@ impl Program { let arg_stack = vec![]; let mut id_gen = IdGen::new(); - term.traverse_uplc_with_helper(&scope, arg_stack, &mut id_gen, with); + term.traverse_uplc_with_helper(&scope, arg_stack, &mut id_gen, with, inline_lambda); Program { version: self.version, term, @@ -809,7 +854,7 @@ impl Program { pub fn lambda_reducer(self) -> Self { let mut lambda_applied_ids = vec![]; - self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| { + self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| { match term { Term::Apply { function, .. } => { // We are applying some arg so now we unwrap the id of the applied arg @@ -827,7 +872,7 @@ impl Program { } => { // pops stack here no matter what if let Some((arg_id, arg_term)) = arg_stack.pop() { - match arg_term { + match &arg_term { Term::Constant(c) if matches!(c.as_ref(), Constant::String(_)) => {} Term::Constant(_) | Term::Var(_) | Term::Builtin(_) => { let body = Rc::make_mut(body); @@ -835,10 +880,20 @@ impl Program { // creates new body that replaces all var occurrences with the arg *term = substitute_var(body, parameter_name.clone(), &arg_term); } + l @ Term::Lambda { .. } => { + if is_a_builtin_wrapper(l) { + let body = Rc::make_mut(body); + lambda_applied_ids.push(arg_id); + // creates new body that replaces all var occurrences with the arg + *term = substitute_var(body, parameter_name.clone(), &arg_term); + } + } + _ => {} } } } + Term::Case { .. } => todo!(), Term::Constr { .. } => todo!(), _ => {} @@ -849,7 +904,7 @@ impl Program { pub fn builtin_force_reducer(self) -> Self { let mut builtin_map = IndexMap::new(); - let program = self.traverse_uplc_with(&mut |_id, term, _arg_stack, _scope| { + let program = self.traverse_uplc_with(true, &mut |_id, term, _arg_stack, _scope| { if let Term::Force(f) = term { let f = Rc::make_mut(f); match f { @@ -909,7 +964,7 @@ impl Program { pub fn identity_reducer(self) -> Self { let mut identity_applied_ids = vec![]; - self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| { + self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| { match term { Term::Apply { function, .. } => { // We are applying some arg so now we unwrap the id of the applied arg @@ -1018,7 +1073,7 @@ impl Program { pub fn inline_reducer(self) -> Self { let mut lambda_applied_ids = vec![]; - self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| match term { + self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| match term { Term::Apply { function, .. } => { // We are applying some arg so now we unwrap the id of the applied arg let id = id.unwrap(); @@ -1084,7 +1139,7 @@ impl Program { } pub fn force_delay_reducer(self) -> Self { - self.traverse_uplc_with(&mut |_id, term, _arg_stack, _scope| { + self.traverse_uplc_with(true, &mut |_id, term, _arg_stack, _scope| { if let Term::Force(f) = term { let f = f.as_ref(); @@ -1096,7 +1151,7 @@ impl Program { } pub fn remove_no_inlines(self) -> Self { - self.traverse_uplc_with(&mut |_, term, _, _| match term { + self.traverse_uplc_with(true, &mut |_, term, _, _| match term { Term::Lambda { parameter_name, body, @@ -1106,7 +1161,7 @@ impl Program { } pub fn inline_constr_ops(self) -> Self { - self.traverse_uplc_with(&mut |_, term, _, _| { + self.traverse_uplc_with(true, &mut |_, term, _, _| { if let Term::Apply { function, argument } = term { if let Term::Var(name) = function.as_ref() { if name.text == CONSTR_FIELDS_EXPOSER { @@ -1128,7 +1183,7 @@ impl Program { pub fn cast_data_reducer(self) -> Self { let mut applied_ids = vec![]; - self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| { + self.traverse_uplc_with(true, &mut |id, term, mut arg_stack, _scope| { match term { Term::Apply { function, .. } => { // We are apply some arg so now we unwrap the id of the applied arg @@ -1256,7 +1311,7 @@ impl Program { pub fn convert_arithmetic_ops(self) -> Self { let mut constants_to_flip = vec![]; - self.traverse_uplc_with(&mut |id, term, arg_stack, _scope| match term { + self.traverse_uplc_with(true, &mut |id, term, arg_stack, _scope| match term { Term::Apply { argument, .. } => { let id = id.unwrap(); @@ -1300,92 +1355,93 @@ impl Program { let mut final_ids: IndexMap, ()> = IndexMap::new(); - let step_a = self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term { - Term::Builtin(func) => { - if func.can_curry_builtin() && arg_stack.len() == func.arity() { - let is_order_agnostic = func.is_order_agnostic_builtin(); + let step_a = + self.traverse_uplc_with(false, &mut |_id, term, arg_stack, scope| match term { + Term::Builtin(func) => { + if func.can_curry_builtin() && arg_stack.len() == func.arity() { + let is_order_agnostic = func.is_order_agnostic_builtin(); - // In the case of order agnostic builtins we want to sort the args by constant first - // This gives us the opportunity to curry constants that often pop up in the code + // In the case of order agnostic builtins we want to sort the args by constant first + // This gives us the opportunity to curry constants that often pop up in the code - let builtin_args = - BuiltinArgs::args_from_arg_stack(arg_stack, is_order_agnostic); + let builtin_args = + BuiltinArgs::args_from_arg_stack(arg_stack, is_order_agnostic); - // First we see if we have already curried this builtin before - let mut id_vec = if let Some((index, _)) = - curried_terms.iter_mut().find_position( - |curried_term: &&mut CurriedBuiltin| curried_term.func == *func, - ) { - // We found it the builtin was curried before - // So now we merge the new args into the existing curried builtin + // First we see if we have already curried this builtin before + let mut id_vec = if let Some((index, _)) = + curried_terms.iter_mut().find_position( + |curried_term: &&mut CurriedBuiltin| curried_term.func == *func, + ) { + // We found it the builtin was curried before + // So now we merge the new args into the existing curried builtin - let curried_builtin = curried_terms.swap_remove(index); + let curried_builtin = curried_terms.swap_remove(index); - let curried_builtin = - curried_builtin.merge_node_by_path(builtin_args.clone()); + let curried_builtin = + curried_builtin.merge_node_by_path(builtin_args.clone()); - let Some(id_vec) = curried_builtin.get_id_args(&builtin_args) else { - unreachable!(); - }; + let Some(id_vec) = curried_builtin.get_id_args(&builtin_args) else { + unreachable!(); + }; - flipped_terms - .insert(scope.clone(), curried_builtin.is_flipped(&builtin_args)); + flipped_terms + .insert(scope.clone(), curried_builtin.is_flipped(&builtin_args)); - curried_terms.push(curried_builtin); + curried_terms.push(curried_builtin); - id_vec - } else { - // Brand new buitlin so we add it to the list - let curried_builtin = builtin_args.clone().args_to_curried_args(*func); - - let Some(id_vec) = curried_builtin.get_id_args(&builtin_args) else { - unreachable!(); - }; - - curried_terms.push(curried_builtin); - - id_vec - }; - - while let Some(node) = id_vec.pop() { - let mut id_only_vec = - id_vec.iter().map(|item| item.curried_id).collect_vec(); - - id_only_vec.push(node.curried_id); - - let curry_name = CurriedName { - func_name: func.aiken_name(), - id_vec: id_only_vec, - }; - - if let Some((map_scope, _, occurrences)) = - id_mapped_curry_terms.get_mut(&curry_name) - { - *map_scope = map_scope.common_ancestor(scope); - *occurrences += 1; - } else if id_vec.is_empty() { - id_mapped_curry_terms.insert( - curry_name, - (scope.clone(), Term::Builtin(*func).apply(node.term), 1), - ); + id_vec } else { - let var_name = id_vec_function_to_var( - &func.aiken_name(), - &id_vec.iter().map(|item| item.curried_id).collect_vec(), - ); + // Brand new buitlin so we add it to the list + let curried_builtin = builtin_args.clone().args_to_curried_args(*func); - id_mapped_curry_terms.insert( - curry_name, - (scope.clone(), Term::var(var_name).apply(node.term), 1), - ); + let Some(id_vec) = curried_builtin.get_id_args(&builtin_args) else { + unreachable!(); + }; + + curried_terms.push(curried_builtin); + + id_vec + }; + + while let Some(node) = id_vec.pop() { + let mut id_only_vec = + id_vec.iter().map(|item| item.curried_id).collect_vec(); + + id_only_vec.push(node.curried_id); + + let curry_name = CurriedName { + func_name: func.aiken_name(), + id_vec: id_only_vec, + }; + + if let Some((map_scope, _, occurrences)) = + id_mapped_curry_terms.get_mut(&curry_name) + { + *map_scope = map_scope.common_ancestor(scope); + *occurrences += 1; + } else if id_vec.is_empty() { + id_mapped_curry_terms.insert( + curry_name, + (scope.clone(), Term::Builtin(*func).apply(node.term), 1), + ); + } else { + let var_name = id_vec_function_to_var( + &func.aiken_name(), + &id_vec.iter().map(|item| item.curried_id).collect_vec(), + ); + + id_mapped_curry_terms.insert( + curry_name, + (scope.clone(), Term::var(var_name).apply(node.term), 1), + ); + } } } } - } - Term::Constr { .. } => todo!(), - Term::Case { .. } => todo!(), - _ => {} - }); + Term::Constr { .. } => todo!(), + Term::Case { .. } => todo!(), + _ => {} + }); id_mapped_curry_terms .into_iter() @@ -1410,7 +1466,7 @@ impl Program { }); let mut step_b = - step_a.traverse_uplc_with(&mut |id, term, mut arg_stack, scope| match term { + step_a.traverse_uplc_with(false, &mut |id, term, mut arg_stack, scope| match term { Term::Builtin(func) => { if func.can_curry_builtin() && arg_stack.len() == func.arity() { let Some(curried_builtin) = @@ -1658,6 +1714,47 @@ fn replace_identity_usage(term: &Term, original: Rc) -> Term { } } +fn is_a_builtin_wrapper(term: &Term) -> bool { + let (names, term) = pop_lambdas_and_get_names(term); + + let mut arg_names = vec![]; + + let mut term = term; + + while let Term::Apply { function, argument } = term { + match argument.as_ref() { + Term::Var(name) => arg_names.push(name), + + Term::Constant(_) => {} + _ => { + return false; + } + } + term = function.as_ref(); + } + + arg_names.iter().all(|item| names.contains(item)) && matches!(term, Term::Builtin(_)) +} + +fn pop_lambdas_and_get_names(term: &Term) -> (Vec>, &Term) { + let mut names = vec![]; + + let mut term = term; + + while let Term::Lambda { + parameter_name, + body, + } = term + { + if parameter_name.text != NO_INLINE { + names.push(parameter_name.clone()); + } + term = body.as_ref(); + } + + (names, term) +} + #[cfg(test)] mod tests {