diff --git a/crates/uplc/src/optimize/shrinker.rs b/crates/uplc/src/optimize/shrinker.rs
index d1998a72..10a0180e 100644
--- a/crates/uplc/src/optimize/shrinker.rs
+++ b/crates/uplc/src/optimize/shrinker.rs
@@ -77,6 +77,346 @@ impl Default for IdGen {
     }
 }
 
+// NOTE(review): the generic parameters in this patch (`<Name>`, `<CurriedTree>`,
+// `<usize>`) were missing from the recovered text and have been filled back in;
+// confirm them against the real file.
+/// The arguments collected for one builtin call by `args_from_arg_stack`.
+#[derive(PartialEq, Clone, Debug)]
+pub enum BuiltinArgs {
+    TwoArgs(Term<Name>, Term<Name>),
+    ThreeArgs(Term<Name>, Term<Name>, Term<Name>),
+    /// Two arguments of a builtin for which the argument order is irrelevant.
+    TwoArgsAnyOrder(Term<Name>, Term<Name>),
+}
+
+impl BuiltinArgs {
+    /// Builds the argument list of a builtin from the traversal's argument stack.
+    ///
+    /// The stack is reversed first. When `is_order_agnostic` is set, constants are then
+    /// moved in front of the other terms; the sort is stable, so the order inside each
+    /// group is kept.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `stack` holds fewer than two arguments. Arguments past the third are
+    /// ignored.
+    fn args_from_arg_stack(stack: Vec<(usize, Term<Name>)>, is_order_agnostic: bool) -> Self {
+        let mut args = stack.into_iter().rev().map(|(_, arg)| arg).collect_vec();
+
+        if is_order_agnostic {
+            // `false` orders before `true`, so constants end up first
+            args.sort_by_key(|arg| !matches!(arg, Term::Constant(_)));
+        }
+
+        let mut args = args.into_iter();
+
+        match (args.next(), args.next(), args.next()) {
+            // This is the special case where the order of args is irrelevant to the builtin
+            // An example is addInteger or multiplyInteger
+            (Some(arg1), Some(arg2), None) if is_order_agnostic => {
+                BuiltinArgs::TwoArgsAnyOrder(arg1, arg2)
+            }
+            (Some(arg1), Some(arg2), None) => BuiltinArgs::TwoArgs(arg1, arg2),
+            (Some(arg1), Some(arg2), Some(arg3)) => BuiltinArgs::ThreeArgs(arg1, arg2, arg3),
+            _ => panic!("a curried builtin takes two or three arguments"),
+        }
+    }
+
+    /// Turns the arguments into a tree with a single path: a branch for every argument
+    /// but the last, which becomes the leaf. Each node starts with
+    /// `multiple_occurrences` unset and its own clone of `scope`.
+    fn args_to_curried_tree(self, scope: &Scope) -> CurriedTree {
+        let leaf = |node: Term<Name>| CurriedTree::Leaf {
+            node,
+            multiple_occurrences: false,
+            scope: scope.clone(),
+        };
+        let branch = |node: Term<Name>, child: CurriedTree| CurriedTree::Branch {
+            node,
+            multiple_occurrences: false,
+            children: vec![child],
+            scope: scope.clone(),
+        };
+
+        match self {
+            BuiltinArgs::TwoArgs(arg1, arg2) | BuiltinArgs::TwoArgsAnyOrder(arg1, arg2) => {
+                branch(arg1, leaf(arg2))
+            }
+            BuiltinArgs::ThreeArgs(arg1, arg2, arg3) => branch(arg1, branch(arg2, leaf(arg3))),
+        }
+    }
+}
+
+/// A tree of builtin arguments in which calls starting with the same arguments share
+/// nodes. A `Branch` holds an argument that is followed by further arguments; a `Leaf`
+/// holds the last argument of a call.
+#[derive(PartialEq, Clone, Debug)]
+pub enum CurriedTree {
+    Branch {
+        node: Term<Name>,
+        multiple_occurrences: bool,
+        children: Vec<CurriedTree>,
+        scope: Scope,
+    },
+    Leaf {
+        node: Term<Name>,
+        multiple_occurrences: bool,
+        scope: Scope,
+    },
+}
+
+impl CurriedTree {
+    /// The argument stored at this node.
+    pub fn node(&self) -> &Term<Name> {
+        match self {
+            CurriedTree::Branch { node, .. } | CurriedTree::Leaf { node, .. } => node,
+        }
+    }
+
+    /// Mutable access to the argument stored at this node.
+    pub fn node_mut(&mut self) -> &mut Term<Name> {
+        match self {
+            CurriedTree::Branch { node, .. } | CurriedTree::Leaf { node, .. } => node,
+        }
+    }
+
+    /// Whether a later call has been merged through this node.
+    pub fn multiple_occurrences(&self) -> bool {
+        match self {
+            CurriedTree::Branch {
+                multiple_occurrences,
+                ..
+            }
+            | CurriedTree::Leaf {
+                multiple_occurrences,
+                ..
+            } => *multiple_occurrences,
+        }
+    }
+
+    /// Merges `path` into this tree. The first argument of `path` is not inspected here.
+    ///
+    /// Nodes already present along the path are flagged with `multiple_occurrences` and
+    /// get their scope replaced by its common ancestor with `scope`. Missing nodes are
+    /// appended unflagged with a clone of `scope`. The root is always flagged.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `self` is a `Leaf`.
+    pub fn merge_node_by_path(self, path: BuiltinArgs, scope: &Scope) -> CurriedTree {
+        // Since all args are always added to the tree, the minimum depth of a tree is 2
+        // and the max is 3. Therefore we can't have a leaf at the root level.
+        let CurriedTree::Branch {
+            node,
+            mut children,
+            scope: branch_scope,
+            ..
+        } = self
+        else {
+            unreachable!()
+        };
+
+        match path {
+            BuiltinArgs::TwoArgs(_, arg2) | BuiltinArgs::TwoArgsAnyOrder(_, arg2) => {
+                if let Some(CurriedTree::Leaf {
+                    multiple_occurrences,
+                    scope: leaf_scope,
+                    ..
+                }) = children.iter_mut().find(|child| child.node() == &arg2)
+                {
+                    // So here we mutate the found child of the branch to update the leaf
+                    // Note this is a 2 arg builtin so the depth is 2
+                    *multiple_occurrences = true;
+                    *leaf_scope = leaf_scope.common_ancestor(scope);
+                } else {
+                    children.push(CurriedTree::Leaf {
+                        node: arg2,
+                        multiple_occurrences: false,
+                        scope: scope.clone(),
+                    });
+                }
+            }
+            BuiltinArgs::ThreeArgs(_, arg2, arg3) => {
+                if let Some(CurriedTree::Branch {
+                    multiple_occurrences: child_multiple_occurrences,
+                    children: child_children,
+                    scope: child_scope,
+                    ..
+                }) = children.iter_mut().find(|child| child.node() == &arg2)
+                {
+                    // The second argument was seen before, whatever happens to the third
+                    *child_multiple_occurrences = true;
+                    *child_scope = child_scope.common_ancestor(scope);
+
+                    if let Some(CurriedTree::Leaf {
+                        multiple_occurrences: leaf_multiple_occurrences,
+                        scope: leaf_scope,
+                        ..
+                    }) = child_children
+                        .iter_mut()
+                        .find(|child| child.node() == &arg3)
+                    {
+                        // So here we mutate the found child of the branch to update the leaf
+                        // Note this is a 3 arg builtin so the depth is 3
+                        *leaf_multiple_occurrences = true;
+                        *leaf_scope = leaf_scope.common_ancestor(scope);
+                    } else {
+                        child_children.push(CurriedTree::Leaf {
+                            node: arg3,
+                            multiple_occurrences: false,
+                            scope: scope.clone(),
+                        });
+                    }
+                } else {
+                    children.push(BuiltinArgs::TwoArgs(arg2, arg3).args_to_curried_tree(scope));
+                }
+            }
+        }
+
+        CurriedTree::Branch {
+            node,
+            multiple_occurrences: true,
+            children,
+            scope: branch_scope.common_ancestor(scope),
+        }
+    }
+
+    /// Removes every descendant that was not flagged with `multiple_occurrences`,
+    /// recursing into the branches that are kept.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `self` is a `Leaf`.
+    pub fn prune_single_occurrences(mut self) -> Self {
+        match &mut self {
+            CurriedTree::Branch { children, .. } => {
+                // Move the children out rather than cloning the whole subtree
+                *children = std::mem::take(children)
+                    .into_iter()
+                    .filter(|child| child.multiple_occurrences())
+                    .map(|child| match child {
+                        CurriedTree::Branch { .. } => child.prune_single_occurrences(),
+                        CurriedTree::Leaf { .. } => child,
+                    })
+                    .collect_vec();
+            }
+            CurriedTree::Leaf { .. } => unreachable!(),
+        }
+        self
+    }
+}
+
+/// The argument trees recorded for one builtin function.
+#[derive(PartialEq, Clone, Debug)]
+pub struct CurriedBuiltin {
+    pub func: DefaultFunction,
+    /// One tree per distinct first argument merged so far.
+    pub children: Vec<CurriedTree>,
+}
+
+impl CurriedBuiltin {
+    /// Merges one more set of arguments into the trees recorded for this builtin.
+    ///
+    /// The tree whose root equals the first argument is extended. For
+    /// `TwoArgsAnyOrder` a root equal to the second argument is tried next, with the
+    /// two arguments flipped. If no root matches, a new tree is added.
+    pub fn merge_node_by_path(self, path: BuiltinArgs, scope: &Scope) -> CurriedBuiltin {
+        // `self` is owned, so the children are moved out instead of cloned
+        let CurriedBuiltin { func, mut children } = self;
+
+        match &path {
+            // First we just peek at the first arg to see if it was curried before
+            BuiltinArgs::TwoArgs(arg1, _) | BuiltinArgs::ThreeArgs(arg1, _, _) => {
+                if let Some(child) = children.iter_mut().find(|child| child.node() == arg1) {
+                    // mutate child here so we don't have to recreate the whole tree
+                    // We pass in the scope so we can get the common ancestor if the arg
+                    // was curried before
+                    *child = child.clone().merge_node_by_path(path, scope);
+                } else {
+                    // We found a new arg so we add it to the list of children
+                    children.push(path.args_to_curried_tree(scope));
+                }
+            }
+            BuiltinArgs::TwoArgsAnyOrder(arg1, arg2) => {
+                // This is the special case where we search by both args before adding a new child
+                if let Some(child) = children.iter_mut().find(|child| child.node() == arg1) {
+                    *child = child.clone().merge_node_by_path(path, scope);
+                } else if let Some(child) = children.iter_mut().find(|child| child.node() == arg2) {
+                    // If we found a curried argument using arg2 then we flip the order of the args
+                    // before merging into the tree
+                    *child = child.clone().merge_node_by_path(
+                        BuiltinArgs::TwoArgsAnyOrder(arg2.clone(), arg1.clone()),
+                        scope,
+                    );
+                } else {
+                    // We found a new arg so we add it to the list of children
+                    children.push(path.args_to_curried_tree(scope));
+                }
+            }
+        }
+
+        CurriedBuiltin { func, children }
+    }
+
+    /// Drops the trees, and the nodes inside the remaining ones, that were not flagged
+    /// with `multiple_occurrences`.
+    pub fn prune_single_occurrences(mut self) -> Self {
+        self.children = self
+            .children
+            .into_iter()
+            .filter(|child| child.multiple_occurrences())
+            .map(CurriedTree::prune_single_occurrences)
+            .collect_vec();
+        self
+    }
+}
+
+/// Whether the two arguments of `func` may be swapped; the variants listed below are
+/// the ones treated that way.
+pub fn is_order_agnostic_builtin(func: DefaultFunction) -> bool {
+    matches!(
+        func,
+        DefaultFunction::AddInteger
+            | DefaultFunction::MultiplyInteger
+            | DefaultFunction::EqualsInteger
+            | DefaultFunction::EqualsByteString
+            | DefaultFunction::EqualsString
+            | DefaultFunction::EqualsData
+            | DefaultFunction::Bls12_381_G1_Equal
+            | DefaultFunction::Bls12_381_G2_Equal
+            | DefaultFunction::Bls12_381_G1_Add
+            | DefaultFunction::Bls12_381_G2_Add
+    )
+}
+
+/// Whether `func` is one of the builtins this pass will try to curry.
+///
+/// For now all of the curry builtins are not forceable
+pub fn can_curry_builtin(func: DefaultFunction) -> bool {
+    matches!(
+        func,
+        DefaultFunction::AddInteger
+            | DefaultFunction::SubtractInteger
+            | DefaultFunction::MultiplyInteger
+            | DefaultFunction::EqualsInteger
+            | DefaultFunction::EqualsByteString
+            | DefaultFunction::EqualsString
+            | DefaultFunction::EqualsData
+            | DefaultFunction::Bls12_381_G1_Equal
+            | DefaultFunction::Bls12_381_G2_Equal
+            | DefaultFunction::LessThanInteger
+            | DefaultFunction::LessThanEqualsInteger
+            | DefaultFunction::AppendByteString
+            | DefaultFunction::ConsByteString
+            | DefaultFunction::SliceByteString
+            | DefaultFunction::IndexByteString
+            | DefaultFunction::LessThanEqualsByteString
+            | DefaultFunction::LessThanByteString
+            | DefaultFunction::Bls12_381_G1_Add
+            | DefaultFunction::Bls12_381_G2_Add
+    )
+}
+
 impl Program<Name> {
     fn traverse_uplc_with(
         self,
@@ -530,6 +870,112 @@ impl Program<Name> {
             }
         })
     }
+
+    // WIP
+    /// Finds builtin calls whose arguments show up more than once, as candidates for
+    /// currying.
+    ///
+    /// The first traversal records the arguments of every fully applied, curriable
+    /// builtin; entries seen only once are then pruned. The second traversal collects
+    /// the argument ids of calls that merge into a remaining entry without changing it.
+    /// Neither traversal callback modifies the term it is given, and nothing reads the
+    /// collected ids yet.
+    pub fn builtin_curry_reducer(self) -> Program<Name> {
+        let mut curried_terms: Vec<CurriedBuiltin> = vec![];
+        let mut curry_applied_ids: Vec<usize> = vec![];
+
+        let a = self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term {
+            Term::Builtin(func) => {
+                if can_curry_builtin(*func) && arg_stack.len() == func.arity() {
+                    let mut scope = scope.clone();
+
+                    // Get upper scope of the function plus args
+                    // So for example if the scope is [.., ARG, ARG, FUNC]
+                    // we want to pop off the last 3 to get the scope right above the function applications
+                    for _ in 0..func.arity() {
+                        scope = scope.pop();
+                    }
+
+                    let is_order_agnostic = is_order_agnostic_builtin(*func);
+
+                    // In the case of order agnostic builtins we want to sort the args by constant first
+                    // This gives us the opportunity to curry constants that often pop up in the code
+                    let builtin_args =
+                        BuiltinArgs::args_from_arg_stack(arg_stack, is_order_agnostic);
+
+                    // First we see if we have already curried this builtin before
+                    if let Some(curried_builtin) = curried_terms
+                        .iter_mut()
+                        .find(|curried_term| curried_term.func == *func)
+                    {
+                        // We found it: the builtin was curried before
+                        // So now we merge the new args into the existing curried builtin
+                        *curried_builtin = (*curried_builtin)
+                            .clone()
+                            .merge_node_by_path(builtin_args, &scope);
+                    } else {
+                        // Brand new builtin so we add it to the list
+                        curried_terms.push(CurriedBuiltin {
+                            func: *func,
+                            children: vec![builtin_args.args_to_curried_tree(&scope)],
+                        });
+                    }
+                }
+            }
+
+            Term::Constr { .. } => todo!(),
+            Term::Case { .. } => todo!(),
+            _ => {}
+        });
+
+        curried_terms = curried_terms
+            .into_iter()
+            .map(|func| func.prune_single_occurrences())
+            .filter(|func| !func.children.is_empty())
+            .collect_vec();
+
+        // TODO: add function to generate names for curried_terms for generating vars to insert
+        a.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term {
+            Term::Builtin(func) => {
+                // Same guard as in the first traversal: a builtin that has not received all of
+                // its arguments is skipped (`args_from_arg_stack` panics on fewer than two)
+                if !can_curry_builtin(*func) || arg_stack.len() != func.arity() {
+                    return;
+                }
+
+                let Some(curried_builtin) =
+                    curried_terms.iter().find(|curry| curry.func == *func)
+                else {
+                    return;
+                };
+
+                let arg_stack_ids = arg_stack.iter().map(|(id, _)| *id).collect_vec();
+
+                let builtin_args = BuiltinArgs::args_from_arg_stack(
+                    arg_stack,
+                    is_order_agnostic_builtin(*func),
+                );
+
+                // Keep the ids when merging the args into one of the trees leaves that
+                // tree unchanged
+                let is_covered = curried_builtin.children.iter().any(|child| {
+                    let merged = child
+                        .clone()
+                        .merge_node_by_path(builtin_args.clone(), scope);
+
+                    *child == merged
+                });
+
+                if is_covered {
+                    curry_applied_ids.extend(arg_stack_ids);
+                }
+            }
+            // TODO: handle `Term::Apply`; for now applications fall through untouched
+            Term::Constr { .. } => todo!(),
+            Term::Case { .. } => todo!(),
+            _ => {}
+        })
+    }
 }
 
 fn var_occurrences(term: &Term<Name>, search_for: Rc<Name>) -> usize {