From 9a52258e1478ad0a9885f7356c069dabdd2c524d Mon Sep 17 00:00:00 2001 From: microproofs Date: Tue, 30 Jan 2024 09:04:59 -0500 Subject: [PATCH] chugging along with a small refactor and some more work toward currying --- crates/uplc/src/optimize/shrinker.rs | 490 ++++++++++++++++++++------- 1 file changed, 368 insertions(+), 122 deletions(-) diff --git a/crates/uplc/src/optimize/shrinker.rs b/crates/uplc/src/optimize/shrinker.rs index 10a0180e..cf94b0d1 100644 --- a/crates/uplc/src/optimize/shrinker.rs +++ b/crates/uplc/src/optimize/shrinker.rs @@ -1,4 +1,4 @@ -use std::{rc::Rc, vec}; +use std::{cmp::Ordering, iter::Peekable, rc::Rc, vec}; use indexmap::IndexMap; use itertools::Itertools; @@ -49,6 +49,10 @@ impl Scope { .collect_vec(), } } + + pub fn len(&self) -> usize { + self.scope.len() + } } impl Default for Scope { @@ -79,34 +83,33 @@ impl Default for IdGen { #[derive(PartialEq, Clone, Debug)] pub enum BuiltinArgs { - TwoArgs(Term, Term), - ThreeArgs(Term, Term, Term), - TwoArgsAnyOrder(Term, Term), + TwoArgs((usize, Term), (usize, Term)), + ThreeArgs( + (usize, Term), + (usize, Term), + (usize, Term), + ), + TwoArgsAnyOrder((usize, Term), (usize, Term)), } impl BuiltinArgs { fn args_from_arg_stack(stack: Vec<(usize, Term)>, is_order_agnostic: bool) -> Self { - let mut ordered_arg_stack = - stack - .into_iter() - .rev() - .map(|(_, arg)| arg) - .sorted_by(|arg1, arg2| { - // sort by constant first if the builtin is order agnostic - if is_order_agnostic { - if matches!(arg1, Term::Constant(_)) && matches!(arg2, Term::Constant(_)) { - std::cmp::Ordering::Equal - } else if matches!(arg1, Term::Constant(_)) { - std::cmp::Ordering::Less - } else if matches!(arg2, Term::Constant(_)) { - std::cmp::Ordering::Greater - } else { - std::cmp::Ordering::Equal - } - } else { - std::cmp::Ordering::Equal - } - }); + let mut ordered_arg_stack = stack.into_iter().rev().sorted_by(|(_, arg1), (_, arg2)| { + // sort by constant first if the builtin is order agnostic + if 
is_order_agnostic { + if matches!(arg1, Term::Constant(_)) && matches!(arg2, Term::Constant(_)) { + std::cmp::Ordering::Equal + } else if matches!(arg1, Term::Constant(_)) { + std::cmp::Ordering::Less + } else if matches!(arg2, Term::Constant(_)) { + std::cmp::Ordering::Greater + } else { + std::cmp::Ordering::Equal + } + } else { + std::cmp::Ordering::Equal + } + }); if ordered_arg_stack.len() == 2 && is_order_agnostic { // This is the special case where the order of args is irrelevant to the builtin @@ -134,10 +137,12 @@ impl BuiltinArgs { match self { BuiltinArgs::TwoArgs(arg1, arg2) | BuiltinArgs::TwoArgsAnyOrder(arg1, arg2) => { CurriedTree::Branch { - node: arg1, + id: arg1.0, + node: arg1.1, multiple_occurrences: false, children: vec![CurriedTree::Leaf { - node: arg2, + id: arg2.0, + node: arg2.1, multiple_occurrences: false, scope: scope.clone(), }], @@ -146,13 +151,16 @@ impl BuiltinArgs { } BuiltinArgs::ThreeArgs(arg1, arg2, arg3) => CurriedTree::Branch { - node: arg1, + id: arg1.0, + node: arg1.1, multiple_occurrences: false, children: vec![CurriedTree::Branch { - node: arg2, + id: arg2.0, + node: arg2.1, multiple_occurrences: false, children: vec![CurriedTree::Leaf { - node: arg3, + id: arg3.0, + node: arg3.1, multiple_occurrences: false, scope: scope.clone(), }], @@ -162,6 +170,16 @@ impl BuiltinArgs { }, } } + + fn to_id_vec(&self) -> Vec { + match self { + BuiltinArgs::TwoArgs(arg1, arg2) | BuiltinArgs::TwoArgsAnyOrder(arg1, arg2) => { + vec![arg1.0, arg2.0] + } + + BuiltinArgs::ThreeArgs(arg1, arg2, arg3) => vec![arg1.0, arg2.0, arg3.0], + } + } } #[derive(PartialEq, Clone, Debug)] @@ -171,19 +189,32 @@ pub enum CurriedTree { multiple_occurrences: bool, children: Vec, scope: Scope, + id: usize, }, Leaf { node: Term, multiple_occurrences: bool, scope: Scope, + id: usize, }, } impl CurriedTree { pub fn node(&self) -> &Term { match self { - CurriedTree::Branch { node, .. } => node, - CurriedTree::Leaf { node, .. 
} => node, + CurriedTree::Branch { node, .. } | CurriedTree::Leaf { node, .. } => node, + } + } + + pub fn id(&self) -> usize { + match self { + CurriedTree::Branch { id, .. } | CurriedTree::Leaf { id, .. } => *id, + } + } + + pub fn scope(&self) -> &Scope { + match self { + CurriedTree::Branch { scope, .. } | CurriedTree::Leaf { scope, .. } => scope, } } @@ -207,6 +238,48 @@ impl CurriedTree { } } + pub fn find_leaf_id_path(&self, path: &BuiltinArgs) -> Vec { + match (self, path) { + ( + CurriedTree::Branch { children, .. }, + BuiltinArgs::TwoArgs(_, (_, arg2)) | BuiltinArgs::TwoArgsAnyOrder(_, (_, arg2)), + ) => { + if let Some(CurriedTree::Leaf { id: leaf_id, .. }) = + children.iter().find(|child| child.node() == arg2) + { + vec![*leaf_id] + } else { + vec![] + } + } + + ( + CurriedTree::Branch { children, .. }, + BuiltinArgs::ThreeArgs(_, (_, arg2), (_, arg3)), + ) => { + if let Some(CurriedTree::Branch { + children: child_children, + id: mid_id, + .. + }) = children.iter().find(|child| child.node() == arg2) + { + if let Some(CurriedTree::Leaf { id: leaf_id, .. }) = + child_children.iter().find(|child| child.node() == arg3) + { + vec![*mid_id, *leaf_id] + } else { + vec![*mid_id] + } + } else { + vec![] + } + } + // Since all args are always added to the tree. The minimum depth of a tree is 2 and max is 3 + // Therefore we can't have a leaf at the root level of the match + _ => unreachable!(), + } + } + pub fn merge_node_by_path(self, path: BuiltinArgs, scope: &Scope) -> CurriedTree { match (self, path) { ( @@ -214,9 +287,11 @@ impl CurriedTree { node, mut children, scope: branch_scope, + id, .. 
}, - BuiltinArgs::TwoArgs(_, arg2) | BuiltinArgs::TwoArgsAnyOrder(_, arg2), + BuiltinArgs::TwoArgs(_, (new_leaf_id, arg2)) + | BuiltinArgs::TwoArgsAnyOrder(_, (new_leaf_id, arg2)), ) => { if let Some(CurriedTree::Leaf { multiple_occurrences, @@ -229,6 +304,7 @@ impl CurriedTree { *multiple_occurrences = true; *leaf_scope = leaf_scope.common_ancestor(scope); CurriedTree::Branch { + id, node, multiple_occurrences: true, children, @@ -236,11 +312,13 @@ impl CurriedTree { } } else { children.push(CurriedTree::Leaf { + id: new_leaf_id, node: arg2, multiple_occurrences: false, scope: scope.clone(), }); CurriedTree::Branch { + id, node, multiple_occurrences: true, children, @@ -253,9 +331,10 @@ impl CurriedTree { node, mut children, scope: branch_scope, + id: top_id, .. }, - BuiltinArgs::ThreeArgs(_, arg2, arg3), + BuiltinArgs::ThreeArgs(_, (new_branch_id, arg2), (new_leaf_id, arg3)), ) => { if let Some(CurriedTree::Branch { multiple_occurrences: child_multiple_occurrences, @@ -279,6 +358,7 @@ impl CurriedTree { *child_multiple_occurrences = true; *child_scope = child_scope.common_ancestor(scope); CurriedTree::Branch { + id: top_id, node, multiple_occurrences: true, children, @@ -286,6 +366,7 @@ impl CurriedTree { } } else { child_children.push(CurriedTree::Leaf { + id: new_leaf_id, node: arg3, multiple_occurrences: false, scope: scope.clone(), @@ -293,6 +374,7 @@ impl CurriedTree { *child_multiple_occurrences = true; *child_scope = child_scope.common_ancestor(scope); CurriedTree::Branch { + id: top_id, node, multiple_occurrences: true, children, @@ -300,8 +382,12 @@ impl CurriedTree { } } } else { - children.push(BuiltinArgs::TwoArgs(arg2, arg3).args_to_curried_tree(scope)); + children.push( + BuiltinArgs::TwoArgs((new_branch_id, arg2), (new_leaf_id, arg3)) + .args_to_curried_tree(scope), + ); CurriedTree::Branch { + id: top_id, node, multiple_occurrences: true, children, @@ -335,6 +421,49 @@ impl CurriedTree { } self } + + fn to_scope_map( + &self, + mut acc: 
Vec<(Scope, Vec<Term>)>, + current_term: &Term, + ) -> Vec<(Scope, Vec<Term>)> { + if let CurriedTree::Branch { node, children, .. } = self { + acc = children.iter().fold(acc, |acc, child| { + child.to_scope_map(acc, &current_term.clone().apply(node.clone())) + }); + } + + let insert_index = acc.iter().enumerate().find_map(|(index, (item_scope, _))| { + if item_scope.len() > self.scope().len() { + Some((Ordering::Less, index)) + } else if item_scope == self.scope() { + // If scopes are exactly equal we keep them in the same scope grouping + Some((Ordering::Equal, index)) + } else { + None + } + }); + + let term = current_term.clone().apply(self.node().clone()); + + if let Some(insert_index) = insert_index { + match insert_index.0 { + Ordering::Less => { + acc.insert(insert_index.1, (self.scope().clone(), vec![term])); + acc + } + Ordering::Equal => { + let item = acc.get_mut(insert_index.1).unwrap(); + item.1.push(term); + acc + } + Ordering::Greater => unreachable!(), + } + } else { + acc.push((self.scope().clone(), vec![term])); + acc + } + } } #[derive(PartialEq, Clone, Debug)] @@ -352,7 +481,7 @@ impl CurriedBuiltin { match &path { // First we just peak at the first arg to see if it was curried before BuiltinArgs::TwoArgs(arg1, _) | BuiltinArgs::ThreeArgs(arg1, _, _) => { - if let Some(child) = children.iter_mut().find(|child| child.node() == arg1) { + if let Some(child) = children.iter_mut().find(|child| child.node() == &arg1.1) { // mutate child here so we don't have to recreate the whole tree // We pass in the scope so we can get the common ancestor if the arg was curried before *child = child.clone().merge_node_by_path(path, scope); @@ -371,13 +500,15 @@ impl CurriedBuiltin { } BuiltinArgs::TwoArgsAnyOrder(arg1, arg2) => { // This is the special case where we search by both args before adding a new child - if let Some(child) = children.iter_mut().find(|child| child.node() == arg1) { + if let Some(child) = children.iter_mut().find(|child| child.node() == &arg1.1) { *child = 
child.clone().merge_node_by_path(path, scope); CurriedBuiltin { func: self.func, children, } - } else if let Some(child) = children.iter_mut().find(|child| child.node() == arg2) { + } else if let Some(child) = + children.iter_mut().find(|child| child.node() == &arg2.1) + { // If we found a curried argument using arg2 then we flip the order of the args // before merging into the tree *child = child.clone().merge_node_by_path( @@ -410,74 +541,114 @@ impl CurriedBuiltin { .collect_vec(); self } -} -pub fn is_order_agnostic_builtin(func: DefaultFunction) -> bool { - matches!( - func, - DefaultFunction::AddInteger - | DefaultFunction::MultiplyInteger - | DefaultFunction::EqualsInteger - | DefaultFunction::EqualsByteString - | DefaultFunction::EqualsString - | DefaultFunction::EqualsData - | DefaultFunction::Bls12_381_G1_Equal - | DefaultFunction::Bls12_381_G2_Equal - | DefaultFunction::Bls12_381_G1_Add - | DefaultFunction::Bls12_381_G2_Add - ) -} -/// For now all of the curry builtins are not forceable -pub fn can_curry_builtin(func: DefaultFunction) -> bool { - matches!( - func, - DefaultFunction::AddInteger - | DefaultFunction::SubtractInteger - | DefaultFunction::MultiplyInteger - | DefaultFunction::EqualsInteger - | DefaultFunction::EqualsByteString - | DefaultFunction::EqualsString - | DefaultFunction::EqualsData - | DefaultFunction::Bls12_381_G1_Equal - | DefaultFunction::Bls12_381_G2_Equal - | DefaultFunction::LessThanInteger - | DefaultFunction::LessThanEqualsInteger - | DefaultFunction::AppendByteString - | DefaultFunction::ConsByteString - | DefaultFunction::SliceByteString - | DefaultFunction::IndexByteString - | DefaultFunction::LessThanEqualsByteString - | DefaultFunction::LessThanByteString - | DefaultFunction::Bls12_381_G1_Add - | DefaultFunction::Bls12_381_G2_Add - ) -} + pub fn find_leaf_id_path(&self, path: &BuiltinArgs) -> Option<Vec<usize>> { + let children = &self.children; + let mut id_vec = vec![]; -impl Program { - fn traverse_uplc_with( - self, - with: 
&mut impl FnMut(Option, &mut Term, Vec<(usize, Term)>, &Scope), - ) -> Self { - let mut term = self.term; - let scope = Scope { scope: vec![] }; - let arg_stack = vec![]; - let mut id_gen = IdGen::new(); + match path { + // First we just peek at the first arg to see if it was curried before + BuiltinArgs::TwoArgs(arg1, _) | BuiltinArgs::ThreeArgs(arg1, _, _) => { + if let Some(child) = children.iter().find(|child| child.node() == &arg1.1) { + // Read-only lookup: record the id of the matching root child, + // then append the ids found deeper in that child's tree for this path + id_vec.push(child.id()); - Self::traverse_uplc_with_helper(&mut term, &scope, arg_stack, &mut id_gen, with); - Program { - version: self.version, - term, + id_vec.extend(child.find_leaf_id_path(path)); + Some(id_vec) + } else { + None + } + } + BuiltinArgs::TwoArgsAnyOrder(arg1, arg2) => { + // This is the special case where we search by both args before reporting no match + if let Some(child) = children.iter().find(|child| child.node() == &arg1.1) { + id_vec.push(child.id()); + + id_vec.extend(child.find_leaf_id_path(path)); + Some(id_vec) + } else if let Some(child) = children.iter().find(|child| child.node() == &arg2.1) { + // If we found a curried argument using arg2 then we flip the order of the args + // before looking up the rest of the path + + id_vec.push(child.id()); + + id_vec.extend(child.find_leaf_id_path(&BuiltinArgs::TwoArgsAnyOrder( + arg2.clone(), + arg1.clone(), + ))); + + Some(id_vec) + } else { + None + } + } } } + fn to_scope_map(&self) -> Vec<(Scope, Vec<Term>)> { + let scope_map = vec![]; + let current_term = Term::Builtin(self.func); + + self.children.iter().fold(scope_map, |acc, child| { + child.to_scope_map(acc, &current_term) + }) + } +} + +impl DefaultFunction { + pub fn is_order_agnostic_builtin(self) -> bool { + matches!( + self, + DefaultFunction::AddInteger + | DefaultFunction::MultiplyInteger + | DefaultFunction::EqualsInteger + | 
DefaultFunction::EqualsByteString + | DefaultFunction::EqualsString + | DefaultFunction::EqualsData + | DefaultFunction::Bls12_381_G1_Equal + | DefaultFunction::Bls12_381_G2_Equal + | DefaultFunction::Bls12_381_G1_Add + | DefaultFunction::Bls12_381_G2_Add + ) + } + /// For now all of the curry builtins are not forceable + pub fn can_curry_builtin(self) -> bool { + matches!( + self, + DefaultFunction::AddInteger + | DefaultFunction::SubtractInteger + | DefaultFunction::MultiplyInteger + | DefaultFunction::EqualsInteger + | DefaultFunction::EqualsByteString + | DefaultFunction::EqualsString + | DefaultFunction::EqualsData + | DefaultFunction::Bls12_381_G1_Equal + | DefaultFunction::Bls12_381_G2_Equal + | DefaultFunction::LessThanInteger + | DefaultFunction::LessThanEqualsInteger + | DefaultFunction::AppendByteString + | DefaultFunction::ConsByteString + | DefaultFunction::SliceByteString + | DefaultFunction::IndexByteString + | DefaultFunction::LessThanEqualsByteString + | DefaultFunction::LessThanByteString + | DefaultFunction::Bls12_381_G1_Add + | DefaultFunction::Bls12_381_G2_Add + | DefaultFunction::ConstrData + ) + } +} + +impl Term { fn traverse_uplc_with_helper( - term: &mut Term, + &mut self, scope: &Scope, mut arg_stack: Vec<(usize, Term)>, id_gen: &mut IdGen, with: &mut impl FnMut(Option, &mut Term, Vec<(usize, Term)>, &Scope), ) { - match term { + match self { Term::Apply { function, argument } => { let arg = Rc::make_mut(argument); let argument_arg_stack = vec![]; @@ -502,13 +673,13 @@ impl Program { with, ); - with(Some(apply_id), term, vec![], scope); + with(Some(apply_id), self, vec![], scope); } Term::Delay(d) => { let d = Rc::make_mut(d); // First we recurse further to reduce the inner terms before coming back up to the Delay Self::traverse_uplc_with_helper(d, scope, arg_stack, id_gen, with); - with(None, term, vec![], scope); + with(None, self, vec![], scope); } Term::Lambda { body, .. 
} => { let body = Rc::make_mut(body); @@ -517,13 +688,13 @@ impl Program { // Pass in either one or zero args. Self::traverse_uplc_with_helper(body, scope, arg_stack, id_gen, with); - with(None, term, args, scope); + with(None, self, args, scope); } Term::Force(f) => { let f = Rc::make_mut(f); Self::traverse_uplc_with_helper(f, scope, arg_stack, id_gen, with); - with(None, term, vec![], scope); + with(None, self, vec![], scope); } Term::Case { .. } => todo!(), Term::Constr { .. } => todo!(), @@ -537,7 +708,7 @@ impl Program { } } // Pass in args up to function arity. - with(None, term, args, scope); + with(None, self, args, scope); } term => { with(None, term, vec![], scope); @@ -545,6 +716,70 @@ impl Program { } } + fn traverse_to_scope<'a, 'b, I>(&'a mut self, scope: &mut Peekable) -> &'a mut Self + where + I: Iterator, + { + if scope.peek().is_none() { + self + } else { + match self { + Term::Apply { function, argument } => { + let scope_path = scope.next(); + + match scope_path.unwrap() { + ScopePath::FUNC => Self::traverse_to_scope(Rc::make_mut(function), scope), + ScopePath::ARG => Self::traverse_to_scope(Rc::make_mut(argument), scope), + } + } + Term::Delay(d) => { + let d = Rc::make_mut(d); + // First we recurse further to reduce the inner terms before coming back up to the Delay + Self::traverse_to_scope(d, scope) + } + Term::Lambda { body, .. } => { + let body = Rc::make_mut(body); + // First we recurse further to reduce the inner terms before coming back up to the Delay + Self::traverse_to_scope(body, scope) + } + + Term::Force(f) => { + let f = Rc::make_mut(f); + // First we recurse further to reduce the inner terms before coming back up to the Delay + Self::traverse_to_scope(f, scope) + } + Term::Case { .. } => todo!(), + Term::Constr { .. 
} => todo!(), + + _ => unreachable!("Incorrect scope path"), + } + } + } +} + +impl Program { + fn traverse_uplc_with( + self, + with: &mut impl FnMut(Option, &mut Term, Vec<(usize, Term)>, &Scope), + ) -> Self { + let mut term = self.term; + let scope = Scope { scope: vec![] }; + let arg_stack = vec![]; + let mut id_gen = IdGen::new(); + + term.traverse_uplc_with_helper(&scope, arg_stack, &mut id_gen, with); + Program { + version: self.version, + term, + } + } + + fn find_node_by_scope(&mut self, scope: &Scope) -> &mut Term { + let term = &mut self.term; + let iter = &mut scope.scope.iter().peekable(); + term.traverse_to_scope(iter) + } + pub fn lambda_reducer(self) -> Self { let mut lambda_applied_ids = vec![]; @@ -585,7 +820,7 @@ impl Program { }) } - pub fn builtin_force_reducer(self) -> Program { + pub fn builtin_force_reducer(self) -> Self { let mut builtin_map = IndexMap::new(); let program = self.traverse_uplc_with(&mut |_id, term, _arg_stack, _scope| { @@ -638,7 +873,7 @@ impl Program { } } - pub fn inline_reducer(self) -> Program { + pub fn inline_reducer(self) -> Self { let mut lambda_applied_ids = vec![]; let mut identity_applied_ids = vec![]; // TODO: Remove extra traversals @@ -768,7 +1003,7 @@ impl Program { }) } - pub fn force_delay_reducer(self) -> Program { + pub fn force_delay_reducer(self) -> Self { self.traverse_uplc_with(&mut |_id, term, _arg_stack, _scope| { if let Term::Force(f) = term { let f = Rc::make_mut(f); @@ -780,7 +1015,7 @@ impl Program { }) } - pub fn cast_data_reducer(self) -> Program { + pub fn cast_data_reducer(self) -> Self { let mut applied_ids = vec![]; self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| { @@ -908,13 +1143,13 @@ impl Program { } // WIP - pub fn builtin_curry_reducer(self) -> Program { + pub fn builtin_curry_reducer(self) -> Self { let mut curried_terms = vec![]; let mut curry_applied_ids: Vec = vec![]; let a = self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term { 
Term::Builtin(func) => { - if can_curry_builtin(*func) && arg_stack.len() == func.arity() { + if func.can_curry_builtin() && arg_stack.len() == func.arity() { let mut scope = scope.clone(); // Get upper scope of the function plus args @@ -924,7 +1159,7 @@ impl Program { scope = scope.pop(); } - let is_order_agnostic = is_order_agnostic_builtin(*func); + let is_order_agnostic = func.is_order_agnostic_builtin(); // In the case of order agnostic builtins we want to sort the args by constant first // This gives us the opportunity to curry constants that often pop up in the code @@ -968,39 +1203,50 @@ impl Program { } // TODO: add function to generate names for curried_terms for generating vars to insert - a.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term { + let b = a.traverse_uplc_with(&mut |id, term, arg_stack, _scope| match term { Term::Builtin(func) => { - if can_curry_builtin(*func) { + if func.can_curry_builtin() { let Some(curried_builtin) = curried_terms.iter().find(|curry| curry.func == *func) else { return; }; - let arg_stack_ids = arg_stack.iter().map(|(id, _)| *id).collect_vec(); - let builtin_args = BuiltinArgs::args_from_arg_stack( arg_stack, - is_order_agnostic_builtin(*func), + func.is_order_agnostic_builtin(), ); - if let Some(_) = curried_builtin.children.iter().find(|child| { - let x = (*child) - .clone() - .merge_node_by_path(builtin_args.clone(), scope); + let Some(id_vec) = curried_builtin.find_leaf_id_path(&builtin_args) else { + return; + }; - *child == &x - }) { - curry_applied_ids.extend(arg_stack_ids); - } else { - } + let id_str = id_vec + .iter() + .map(|id| id.to_string()) + .collect::<Vec<String>>() + .join("_"); + + let name = format!("{}_{}", func.aiken_name(), id_str); + + curry_applied_ids.extend(builtin_args.to_id_vec().iter().take(id_vec.len())); + + *term = Term::var(name); + } + } + Term::Apply { function, .. 
} => { + let id = id.unwrap(); + + if curry_applied_ids.contains(&id) { + *term = (**function).clone(); } } - Term::Apply { function, argument } => todo!(), Term::Constr { .. } => todo!(), Term::Case { .. } => todo!(), _ => {} - }) + }); + + todo!() } }