chore: continuing progress on implementing currying optimization for builtins

Introduced some new abstractions to make a different number of args easier to deal with
This commit is contained in:
microproofs 2023-12-06 10:06:02 -05:00 committed by Kasey
parent 8fdedb754e
commit 249581e1bc
1 changed files with 285 additions and 70 deletions

View File

@ -9,6 +9,47 @@ use crate::{
ast::{Constant, Data, Name, Program, Term, Type},
builtins::DefaultFunction,
};
#[derive(PartialEq, Clone)]
pub enum BuiltinArgs {
TwoArgs(Term<Name>, Term<Name>),
ThreeArgs(Term<Name>, Term<Name>, Term<Name>),
TwoArgsAnyOrder(Term<Name>, Term<Name>),
}
impl BuiltinArgs {
fn args_to_curried_tree(self, scope: &Scope) -> CurriedTree {
match self {
BuiltinArgs::TwoArgs(arg1, arg2) | BuiltinArgs::TwoArgsAnyOrder(arg1, arg2) => {
CurriedTree::Branch {
node: arg1,
multiple_occurrences: false,
children: vec![CurriedTree::Leaf {
node: arg2,
multiple_occurrences: false,
scope: scope.clone(),
}],
scope: scope.clone(),
}
}
BuiltinArgs::ThreeArgs(arg1, arg2, arg3) => CurriedTree::Branch {
node: arg1,
multiple_occurrences: false,
children: vec![CurriedTree::Branch {
node: arg2,
multiple_occurrences: false,
children: vec![CurriedTree::Leaf {
node: arg3,
multiple_occurrences: false,
scope: scope.clone(),
}],
scope: scope.clone(),
}],
scope: scope.clone(),
},
}
}
}
#[derive(PartialEq, Clone)]
pub enum CurriedTree {
@ -25,6 +66,131 @@ pub enum CurriedTree {
},
}
impl CurriedTree {
pub fn node(&self) -> &Term<Name> {
match self {
CurriedTree::Branch { node, .. } => node,
CurriedTree::Leaf { node, .. } => node,
}
}
pub fn node_mut(&mut self) -> &mut Term<Name> {
match self {
CurriedTree::Branch { node, .. } => node,
CurriedTree::Leaf { node, .. } => node,
}
}
pub fn merge_node_by_path(self, path: BuiltinArgs, scope: &Scope) -> CurriedTree {
match (self, path) {
(
CurriedTree::Branch {
node,
mut children,
scope: branch_scope,
..
},
BuiltinArgs::TwoArgs(_, arg2) | BuiltinArgs::TwoArgsAnyOrder(_, arg2),
) => {
if let Some(CurriedTree::Leaf {
multiple_occurrences,
scope: leaf_scope,
..
}) = children.iter_mut().find(|child| child.node() == &arg2)
{
// So here we mutate the found child of the branch to update the leaf
// Note this is a 2 arg builtin so the depth is 2
*multiple_occurrences = true;
*leaf_scope = leaf_scope.common_ancestor(scope);
CurriedTree::Branch {
node,
multiple_occurrences: true,
children,
scope: branch_scope.common_ancestor(scope),
}
} else {
children.push(CurriedTree::Leaf {
node: arg2,
multiple_occurrences: false,
scope: scope.clone(),
});
CurriedTree::Branch {
node,
multiple_occurrences: true,
children,
scope: branch_scope.common_ancestor(scope),
}
}
}
(
CurriedTree::Branch {
node,
mut children,
scope: branch_scope,
..
},
BuiltinArgs::ThreeArgs(_, arg2, arg3),
) => {
if let Some(CurriedTree::Branch {
multiple_occurrences: child_multiple_occurrences,
children: child_children,
scope: child_scope,
..
}) = children.iter_mut().find(|child| child.node() == &arg2)
{
if let Some(CurriedTree::Leaf {
multiple_occurrences: leaf_multiple_occurrences,
scope: leaf_scope,
..
}) = child_children
.iter_mut()
.find(|child| child.node() == &arg3)
{
// So here we mutate the found child of the branch to update the leaf
// Note this is a 3 arg builtin so the depth is 3
*leaf_multiple_occurrences = true;
*leaf_scope = leaf_scope.common_ancestor(scope);
*child_multiple_occurrences = true;
*child_scope = child_scope.common_ancestor(scope);
CurriedTree::Branch {
node,
multiple_occurrences: true,
children,
scope: branch_scope.common_ancestor(scope),
}
} else {
child_children.push(CurriedTree::Leaf {
node: arg3,
multiple_occurrences: false,
scope: scope.clone(),
});
*child_multiple_occurrences = true;
*child_scope = child_scope.common_ancestor(scope);
CurriedTree::Branch {
node,
multiple_occurrences: true,
children,
scope: branch_scope.common_ancestor(scope),
}
}
} else {
children.push(BuiltinArgs::TwoArgs(arg2, arg3).args_to_curried_tree(scope));
CurriedTree::Branch {
node,
multiple_occurrences: true,
children,
scope: branch_scope.common_ancestor(scope),
}
}
}
// Since all args are always added to the tree. The minimum depth of a tree is 2 and max is 3
// Therefore we can't have a leaf at the root level of the match
_ => unreachable!(),
}
}
}
#[derive(PartialEq, Clone)]
pub struct CurriedBuiltin {
pub func: DefaultFunction,
/// For use with subtract integer where we can flip the order of the arguments
@ -33,6 +199,68 @@ pub struct CurriedBuiltin {
pub children: Vec<CurriedTree>,
}
impl CurriedBuiltin {
pub fn merge_node_by_path(self, path: BuiltinArgs, scope: &Scope) -> CurriedBuiltin {
let mut children = self.children.clone();
match &path {
// First we just peak at the first arg to see if it was curried before
BuiltinArgs::TwoArgs(arg1, _) | BuiltinArgs::ThreeArgs(arg1, _, _) => {
if let Some(child) = children.iter_mut().find(|child| child.node() == arg1) {
// mutate child here so we don't have to recreate the whole tree
// We pass in the scope so we can get the common ancestor if the arg was curried before
*child = child.clone().merge_node_by_path(path, scope);
CurriedBuiltin {
func: self.func,
flipped: self.flipped,
children,
}
} else {
// We found a new arg so we add it to the list of children
children.push(path.args_to_curried_tree(scope));
CurriedBuiltin {
func: self.func,
flipped: self.flipped,
children,
}
}
}
BuiltinArgs::TwoArgsAnyOrder(arg1, arg2) => {
// This is the special case where we search by both args before adding a new child
if let Some(child) = children.iter_mut().find(|child| child.node() == arg1) {
*child = child.clone().merge_node_by_path(path, scope);
CurriedBuiltin {
func: self.func,
flipped: self.flipped,
children,
}
} else if let Some(child) = children.iter_mut().find(|child| child.node() == arg2) {
// If we found a curried argument using arg2 then we flip the order of the args
// before merging into the tree
*child = child.clone().merge_node_by_path(
BuiltinArgs::TwoArgsAnyOrder(arg2.clone(), arg1.clone()),
scope,
);
CurriedBuiltin {
func: self.func,
flipped: self.flipped,
children,
}
} else {
// We found a new arg so we add it to the list of children
children.push(path.args_to_curried_tree(scope));
CurriedBuiltin {
func: self.func,
flipped: self.flipped,
children,
}
}
}
}
}
}
#[derive(Eq, Hash, PartialEq, Clone)]
pub enum ScopePath {
FUNC,
@ -599,38 +827,36 @@ impl Program<Name> {
let mut curried_terms = vec![];
let mut curry_applied_ids: Vec<usize> = vec![];
self.traverse_uplc_with(&mut |_id, term, mut arg_stack, scope| match term {
self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term {
Term::Builtin(func) => {
if can_curry_builtin(*func) && arg_stack.len() == func.arity() {
let mut scope = scope.clone();
// Get upper scope of the function plus args
// So for example if the scope is [.., ARG, ARG, FUNC]
// we want to pop off the last 3 to get the scope right above the function applications
for _ in 0..func.arity() {
scope = scope.pop();
}
let is_order_agnostic = is_order_agnostic_builtin(*func);
if let Some(curried_builtin) = curried_terms
.iter_mut()
.find(|curried_term: &&mut CurriedBuiltin| curried_term.func == *func)
{
let mut current_children = &mut curried_builtin.children;
let ordered_args =
arg_stack
// In the case of order agnostic builtins we want to sort the args by constant first
// This gives us the opportunity to curry constants that often pop up in the code
let mut ordered_arg_stack = arg_stack
.into_iter()
.map(|(_, arg)| arg)
.sorted_by(|arg1, arg2| {
// sort by constant first if the builtin is order agnostic
if is_order_agnostic {
if matches!(arg1, Term::Constant(_))
&& matches!(arg2, Term::Constant(_))
{
std::cmp::Ordering::Equal
} else if matches!(arg1, Term::Constant(_)) {
std::cmp::Ordering::Greater
} else if matches!(arg2, Term::Constant(_)) {
std::cmp::Ordering::Less
} else if matches!(arg2, Term::Constant(_)) {
std::cmp::Ordering::Greater
} else {
std::cmp::Ordering::Equal
}
@ -639,54 +865,43 @@ impl Program<Name> {
}
});
todo!("Finish this")
let builtin_args = if ordered_arg_stack.len() == 2 && is_order_agnostic {
// This is the special case where the order of args is irrelevant to the builtin
// An example is addInteger or multiplyInteger
BuiltinArgs::TwoArgsAnyOrder(
ordered_arg_stack.next().unwrap(),
ordered_arg_stack.next().unwrap(),
)
} else if ordered_arg_stack.len() == 2 {
BuiltinArgs::TwoArgs(
ordered_arg_stack.next().unwrap(),
ordered_arg_stack.next().unwrap(),
)
} else {
let Some(curried_tree) = arg_stack
.into_iter()
.map(|(_, arg)| arg)
.sorted_by(|arg1, arg2| {
if is_order_agnostic {
if matches!(arg1, Term::Constant(_))
&& matches!(arg2, Term::Constant(_))
BuiltinArgs::ThreeArgs(
ordered_arg_stack.next().unwrap(),
ordered_arg_stack.next().unwrap(),
ordered_arg_stack.next().unwrap(),
)
};
// First we see if we have already curried this builtin before
if let Some(curried_builtin) = curried_terms
.iter_mut()
.find(|curried_term: &&mut CurriedBuiltin| curried_term.func == *func)
{
std::cmp::Ordering::Equal
} else if matches!(arg1, Term::Constant(_)) {
std::cmp::Ordering::Greater
} else if matches!(arg2, Term::Constant(_)) {
std::cmp::Ordering::Less
// We found it the builtin was curried before
// So now we merge the new args into the existing curried builtin
*curried_builtin = (*curried_builtin)
.clone()
.merge_node_by_path(builtin_args, &scope);
} else {
std::cmp::Ordering::Equal
}
} else {
std::cmp::Ordering::Equal
}
})
.fold(None, |acc, arg| match acc {
Some(curry) => Some(CurriedTree::Branch {
node: arg,
multiple_occurrences: false,
children: vec![curry],
scope: scope.clone(),
}),
None => Some(CurriedTree::Leaf {
node: arg,
multiple_occurrences: false,
scope: scope.clone(),
}),
})
else {
return;
};
let curried_builtin = CurriedBuiltin {
// Brand new buitlin so we add it to the list
curried_terms.push(CurriedBuiltin {
func: *func,
// TODO: handle the special case of subtract integer
flipped: false,
children: vec![curried_tree],
};
curried_terms.push(curried_builtin);
children: vec![builtin_args.args_to_curried_tree(&scope)],
});
}
}
}