diff --git a/crates/uplc/src/optimize/shrinker.rs b/crates/uplc/src/optimize/shrinker.rs index a0e1ce6c..a34bb151 100644 --- a/crates/uplc/src/optimize/shrinker.rs +++ b/crates/uplc/src/optimize/shrinker.rs @@ -9,6 +9,47 @@ use crate::{ ast::{Constant, Data, Name, Program, Term, Type}, builtins::DefaultFunction, }; +#[derive(PartialEq, Clone)] +pub enum BuiltinArgs { + TwoArgs(Term, Term), + ThreeArgs(Term, Term, Term), + TwoArgsAnyOrder(Term, Term), +} + +impl BuiltinArgs { + fn args_to_curried_tree(self, scope: &Scope) -> CurriedTree { + match self { + BuiltinArgs::TwoArgs(arg1, arg2) | BuiltinArgs::TwoArgsAnyOrder(arg1, arg2) => { + CurriedTree::Branch { + node: arg1, + multiple_occurrences: false, + children: vec![CurriedTree::Leaf { + node: arg2, + multiple_occurrences: false, + scope: scope.clone(), + }], + scope: scope.clone(), + } + } + + BuiltinArgs::ThreeArgs(arg1, arg2, arg3) => CurriedTree::Branch { + node: arg1, + multiple_occurrences: false, + children: vec![CurriedTree::Branch { + node: arg2, + multiple_occurrences: false, + children: vec![CurriedTree::Leaf { + node: arg3, + multiple_occurrences: false, + scope: scope.clone(), + }], + scope: scope.clone(), + }], + scope: scope.clone(), + }, + } + } +} #[derive(PartialEq, Clone)] pub enum CurriedTree { @@ -25,6 +66,131 @@ pub enum CurriedTree { }, } +impl CurriedTree { + pub fn node(&self) -> &Term { + match self { + CurriedTree::Branch { node, .. } => node, + CurriedTree::Leaf { node, .. } => node, + } + } + + pub fn node_mut(&mut self) -> &mut Term { + match self { + CurriedTree::Branch { node, .. } => node, + CurriedTree::Leaf { node, .. } => node, + } + } + + pub fn merge_node_by_path(self, path: BuiltinArgs, scope: &Scope) -> CurriedTree { + match (self, path) { + ( + CurriedTree::Branch { + node, + mut children, + scope: branch_scope, + .. + }, + BuiltinArgs::TwoArgs(_, arg2) | BuiltinArgs::TwoArgsAnyOrder(_, arg2), + ) => { + if let Some(CurriedTree::Leaf { + multiple_occurrences, + scope: leaf_scope, + .. + }) = children.iter_mut().find(|child| child.node() == &arg2) + { + // So here we mutate the found child of the branch to update the leaf + // Note this is a 2 arg builtin so the depth is 2 + *multiple_occurrences = true; + *leaf_scope = leaf_scope.common_ancestor(scope); + CurriedTree::Branch { + node, + multiple_occurrences: true, + children, + scope: branch_scope.common_ancestor(scope), + } + } else { + children.push(CurriedTree::Leaf { + node: arg2, + multiple_occurrences: false, + scope: scope.clone(), + }); + CurriedTree::Branch { + node, + multiple_occurrences: true, + children, + scope: branch_scope.common_ancestor(scope), + } + } + } + ( + CurriedTree::Branch { + node, + mut children, + scope: branch_scope, + .. + }, + BuiltinArgs::ThreeArgs(_, arg2, arg3), + ) => { + if let Some(CurriedTree::Branch { + multiple_occurrences: child_multiple_occurrences, + children: child_children, + scope: child_scope, + .. + }) = children.iter_mut().find(|child| child.node() == &arg2) + { + if let Some(CurriedTree::Leaf { + multiple_occurrences: leaf_multiple_occurrences, + scope: leaf_scope, + .. + }) = child_children + .iter_mut() + .find(|child| child.node() == &arg3) + { + // So here we mutate the found child of the branch to update the leaf + // Note this is a 3 arg builtin so the depth is 3 + *leaf_multiple_occurrences = true; + *leaf_scope = leaf_scope.common_ancestor(scope); + *child_multiple_occurrences = true; + *child_scope = child_scope.common_ancestor(scope); + CurriedTree::Branch { + node, + multiple_occurrences: true, + children, + scope: branch_scope.common_ancestor(scope), + } + } else { + child_children.push(CurriedTree::Leaf { + node: arg3, + multiple_occurrences: false, + scope: scope.clone(), + }); + *child_multiple_occurrences = true; + *child_scope = child_scope.common_ancestor(scope); + CurriedTree::Branch { + node, + multiple_occurrences: true, + children, + scope: branch_scope.common_ancestor(scope), + } + } + } else { + children.push(BuiltinArgs::TwoArgs(arg2, arg3).args_to_curried_tree(scope)); + CurriedTree::Branch { + node, + multiple_occurrences: true, + children, + scope: branch_scope.common_ancestor(scope), + } + } + } + // Since all args are always added to the tree. The minimum depth of a tree is 2 and max is 3 + // Therefore we can't have a leaf at the root level of the match + _ => unreachable!(), + } + } +} + +#[derive(PartialEq, Clone)] pub struct CurriedBuiltin { pub func: DefaultFunction, /// For use with subtract integer where we can flip the order of the arguments @@ -33,6 +199,68 @@ pub struct CurriedBuiltin { pub children: Vec, } +impl CurriedBuiltin { + pub fn merge_node_by_path(self, path: BuiltinArgs, scope: &Scope) -> CurriedBuiltin { + let mut children = self.children.clone(); + + match &path { + // First we just peak at the first arg to see if it was curried before + BuiltinArgs::TwoArgs(arg1, _) | BuiltinArgs::ThreeArgs(arg1, _, _) => { + if let Some(child) = children.iter_mut().find(|child| child.node() == arg1) { + // mutate child here so we don't have to recreate the whole tree + // We pass in the scope so we can get the common ancestor if the arg was curried before + *child = child.clone().merge_node_by_path(path, scope); + CurriedBuiltin { + func: self.func, + flipped: self.flipped, + children, + } + } else { + // We found a new arg so we add it to the list of children + children.push(path.args_to_curried_tree(scope)); + CurriedBuiltin { + func: self.func, + flipped: self.flipped, + children, + } + } + } + BuiltinArgs::TwoArgsAnyOrder(arg1, arg2) => { + // This is the special case where we search by both args before adding a new child + if let Some(child) = children.iter_mut().find(|child| child.node() == arg1) { + *child = child.clone().merge_node_by_path(path, scope); + CurriedBuiltin { + func: self.func, + flipped: self.flipped, + children, + } + } else if let Some(child) = children.iter_mut().find(|child| child.node() == arg2) { + // If we found a curried argument using arg2 then we flip the order of the args + // before merging into the tree + *child = child.clone().merge_node_by_path( + BuiltinArgs::TwoArgsAnyOrder(arg2.clone(), arg1.clone()), + scope, + ); + + CurriedBuiltin { + func: self.func, + flipped: self.flipped, + children, + } + } else { + // We found a new arg so we add it to the list of children + children.push(path.args_to_curried_tree(scope)); + CurriedBuiltin { + func: self.func, + flipped: self.flipped, + children, + } + } + } + } + } +} + #[derive(Eq, Hash, PartialEq, Clone)] pub enum ScopePath { FUNC, @@ -599,94 +827,81 @@ impl Program { let mut curried_terms = vec![]; let mut curry_applied_ids: Vec = vec![]; - self.traverse_uplc_with(&mut |_id, term, mut arg_stack, scope| match term { + self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term { Term::Builtin(func) => { if can_curry_builtin(*func) && arg_stack.len() == func.arity() { let mut scope = scope.clone(); // Get upper scope of the function plus args + // So for example if the scope is [.., ARG, ARG, FUNC] + // we want to pop off the last 3 to get the scope right above the function applications for _ in 0..func.arity() { scope = scope.pop(); } let is_order_agnostic = is_order_agnostic_builtin(*func); + // In the case of order agnostic builtins we want to sort the args by constant first + // This gives us the opportunity to curry constants that often pop up in the code + let mut ordered_arg_stack = arg_stack + .into_iter() + .map(|(_, arg)| arg) + .sorted_by(|arg1, arg2| { + // sort by constant first if the builtin is order agnostic + if is_order_agnostic { + if matches!(arg1, Term::Constant(_)) + && matches!(arg2, Term::Constant(_)) + { + std::cmp::Ordering::Equal + } else if matches!(arg1, Term::Constant(_)) { + std::cmp::Ordering::Less + } else if matches!(arg2, Term::Constant(_)) { + std::cmp::Ordering::Greater + } else { + std::cmp::Ordering::Equal + } + } else { + std::cmp::Ordering::Equal + } + }); + + let builtin_args = if ordered_arg_stack.len() == 2 && is_order_agnostic { + // This is the special case where the order of args is irrelevant to the builtin + // An example is addInteger or multiplyInteger + BuiltinArgs::TwoArgsAnyOrder( + ordered_arg_stack.next().unwrap(), + ordered_arg_stack.next().unwrap(), + ) + } else if ordered_arg_stack.len() == 2 { + BuiltinArgs::TwoArgs( + ordered_arg_stack.next().unwrap(), + ordered_arg_stack.next().unwrap(), + ) + } else { + BuiltinArgs::ThreeArgs( + ordered_arg_stack.next().unwrap(), + ordered_arg_stack.next().unwrap(), + ordered_arg_stack.next().unwrap(), + ) + }; + + // First we see if we have already curried this builtin before if let Some(curried_builtin) = curried_terms .iter_mut() .find(|curried_term: &&mut CurriedBuiltin| curried_term.func == *func) { - let mut current_children = &mut curried_builtin.children; - - let ordered_args = - arg_stack - .into_iter() - .map(|(_, arg)| arg) - .sorted_by(|arg1, arg2| { - if is_order_agnostic { - if matches!(arg1, Term::Constant(_)) - && matches!(arg2, Term::Constant(_)) - { - std::cmp::Ordering::Equal - } else if matches!(arg1, Term::Constant(_)) { - std::cmp::Ordering::Greater - } else if matches!(arg2, Term::Constant(_)) { - std::cmp::Ordering::Less - } else { - std::cmp::Ordering::Equal - } - } else { - std::cmp::Ordering::Equal - } - }); - - todo!("Finish this") + // We found it the builtin was curried before + // So now we merge the new args into the existing curried builtin + *curried_builtin = (*curried_builtin) + .clone() + .merge_node_by_path(builtin_args, &scope); } else { - let Some(curried_tree) = arg_stack - .into_iter() - .map(|(_, arg)| arg) - .sorted_by(|arg1, arg2| { - if is_order_agnostic { - if matches!(arg1, Term::Constant(_)) - && matches!(arg2, Term::Constant(_)) - { - std::cmp::Ordering::Equal - } else if matches!(arg1, Term::Constant(_)) { - std::cmp::Ordering::Greater - } else if matches!(arg2, Term::Constant(_)) { - std::cmp::Ordering::Less - } else { - std::cmp::Ordering::Equal - } - } else { - std::cmp::Ordering::Equal - } - }) - .fold(None, |acc, arg| match acc { - Some(curry) => Some(CurriedTree::Branch { - node: arg, - multiple_occurrences: false, - children: vec![curry], - scope: scope.clone(), - }), - None => Some(CurriedTree::Leaf { - node: arg, - multiple_occurrences: false, - scope: scope.clone(), - }), - }) - else { - return; - }; - - let curried_builtin = CurriedBuiltin { + // Brand new buitlin so we add it to the list + curried_terms.push(CurriedBuiltin { func: *func, - // TODO: handle the special case of subtract integer flipped: false, - - children: vec![curried_tree], - }; - - curried_terms.push(curried_builtin); + children: vec![builtin_args.args_to_curried_tree(&scope)], + }); } } }