chore: Add back curry code removed in a previous commit

This commit is contained in:
microproofs 2023-12-17 19:47:06 -05:00 committed by Kasey
parent 7c2bae0904
commit 2f72510102
1 changed files with 472 additions and 0 deletions

View File

@ -77,6 +77,382 @@ impl Default for IdGen {
} }
} }
#[derive(PartialEq, Clone, Debug)]
pub enum BuiltinArgs {
TwoArgs(Term<Name>, Term<Name>),
ThreeArgs(Term<Name>, Term<Name>, Term<Name>),
TwoArgsAnyOrder(Term<Name>, Term<Name>),
}
impl BuiltinArgs {
fn args_from_arg_stack(stack: Vec<(usize, Term<Name>)>, is_order_agnostic: bool) -> Self {
let mut ordered_arg_stack =
stack
.into_iter()
.rev()
.map(|(_, arg)| arg)
.sorted_by(|arg1, arg2| {
// sort by constant first if the builtin is order agnostic
if is_order_agnostic {
if matches!(arg1, Term::Constant(_)) && matches!(arg2, Term::Constant(_)) {
std::cmp::Ordering::Equal
} else if matches!(arg1, Term::Constant(_)) {
std::cmp::Ordering::Less
} else if matches!(arg2, Term::Constant(_)) {
std::cmp::Ordering::Greater
} else {
std::cmp::Ordering::Equal
}
} else {
std::cmp::Ordering::Equal
}
});
if ordered_arg_stack.len() == 2 && is_order_agnostic {
// This is the special case where the order of args is irrelevant to the builtin
// An example is addInteger or multiplyInteger
BuiltinArgs::TwoArgsAnyOrder(
ordered_arg_stack.next().unwrap(),
ordered_arg_stack.next().unwrap(),
)
} else if ordered_arg_stack.len() == 2 {
BuiltinArgs::TwoArgs(
ordered_arg_stack.next().unwrap(),
ordered_arg_stack.next().unwrap(),
)
} else {
// println!("ARG STACK FOR FUNC {:#?}, {:#?}", ordered_arg_stack, func);
BuiltinArgs::ThreeArgs(
ordered_arg_stack.next().unwrap(),
ordered_arg_stack.next().unwrap(),
ordered_arg_stack.next().unwrap(),
)
}
}
fn args_to_curried_tree(self, scope: &Scope) -> CurriedTree {
match self {
BuiltinArgs::TwoArgs(arg1, arg2) | BuiltinArgs::TwoArgsAnyOrder(arg1, arg2) => {
CurriedTree::Branch {
node: arg1,
multiple_occurrences: false,
children: vec![CurriedTree::Leaf {
node: arg2,
multiple_occurrences: false,
scope: scope.clone(),
}],
scope: scope.clone(),
}
}
BuiltinArgs::ThreeArgs(arg1, arg2, arg3) => CurriedTree::Branch {
node: arg1,
multiple_occurrences: false,
children: vec![CurriedTree::Branch {
node: arg2,
multiple_occurrences: false,
children: vec![CurriedTree::Leaf {
node: arg3,
multiple_occurrences: false,
scope: scope.clone(),
}],
scope: scope.clone(),
}],
scope: scope.clone(),
},
}
}
}
#[derive(PartialEq, Clone, Debug)]
pub enum CurriedTree {
Branch {
node: Term<Name>,
multiple_occurrences: bool,
children: Vec<CurriedTree>,
scope: Scope,
},
Leaf {
node: Term<Name>,
multiple_occurrences: bool,
scope: Scope,
},
}
impl CurriedTree {
pub fn node(&self) -> &Term<Name> {
match self {
CurriedTree::Branch { node, .. } => node,
CurriedTree::Leaf { node, .. } => node,
}
}
pub fn node_mut(&mut self) -> &mut Term<Name> {
match self {
CurriedTree::Branch { node, .. } => node,
CurriedTree::Leaf { node, .. } => node,
}
}
pub fn multiple_occurrences(&self) -> bool {
match self {
CurriedTree::Branch {
multiple_occurrences,
..
} => *multiple_occurrences,
CurriedTree::Leaf {
multiple_occurrences,
..
} => *multiple_occurrences,
}
}
pub fn merge_node_by_path(self, path: BuiltinArgs, scope: &Scope) -> CurriedTree {
match (self, path) {
(
CurriedTree::Branch {
node,
mut children,
scope: branch_scope,
..
},
BuiltinArgs::TwoArgs(_, arg2) | BuiltinArgs::TwoArgsAnyOrder(_, arg2),
) => {
if let Some(CurriedTree::Leaf {
multiple_occurrences,
scope: leaf_scope,
..
}) = children.iter_mut().find(|child| child.node() == &arg2)
{
// So here we mutate the found child of the branch to update the leaf
// Note this is a 2 arg builtin so the depth is 2
*multiple_occurrences = true;
*leaf_scope = leaf_scope.common_ancestor(scope);
CurriedTree::Branch {
node,
multiple_occurrences: true,
children,
scope: branch_scope.common_ancestor(scope),
}
} else {
children.push(CurriedTree::Leaf {
node: arg2,
multiple_occurrences: false,
scope: scope.clone(),
});
CurriedTree::Branch {
node,
multiple_occurrences: true,
children,
scope: branch_scope.common_ancestor(scope),
}
}
}
(
CurriedTree::Branch {
node,
mut children,
scope: branch_scope,
..
},
BuiltinArgs::ThreeArgs(_, arg2, arg3),
) => {
if let Some(CurriedTree::Branch {
multiple_occurrences: child_multiple_occurrences,
children: child_children,
scope: child_scope,
..
}) = children.iter_mut().find(|child| child.node() == &arg2)
{
if let Some(CurriedTree::Leaf {
multiple_occurrences: leaf_multiple_occurrences,
scope: leaf_scope,
..
}) = child_children
.iter_mut()
.find(|child| child.node() == &arg3)
{
// So here we mutate the found child of the branch to update the leaf
// Note this is a 3 arg builtin so the depth is 3
*leaf_multiple_occurrences = true;
*leaf_scope = leaf_scope.common_ancestor(scope);
*child_multiple_occurrences = true;
*child_scope = child_scope.common_ancestor(scope);
CurriedTree::Branch {
node,
multiple_occurrences: true,
children,
scope: branch_scope.common_ancestor(scope),
}
} else {
child_children.push(CurriedTree::Leaf {
node: arg3,
multiple_occurrences: false,
scope: scope.clone(),
});
*child_multiple_occurrences = true;
*child_scope = child_scope.common_ancestor(scope);
CurriedTree::Branch {
node,
multiple_occurrences: true,
children,
scope: branch_scope.common_ancestor(scope),
}
}
} else {
children.push(BuiltinArgs::TwoArgs(arg2, arg3).args_to_curried_tree(scope));
CurriedTree::Branch {
node,
multiple_occurrences: true,
children,
scope: branch_scope.common_ancestor(scope),
}
}
}
// Since all args are always added to the tree. The minimum depth of a tree is 2 and max is 3
// Therefore we can't have a leaf at the root level of the match
_ => unreachable!(),
}
}
pub fn prune_single_occurrences(mut self) -> Self {
match &mut self {
CurriedTree::Branch { children, .. } => {
*children = children
.clone()
.into_iter()
.filter(|child| child.multiple_occurrences())
.map(|child| {
if matches!(child, CurriedTree::Branch { .. }) {
child.prune_single_occurrences()
} else {
child
}
})
.collect_vec();
}
_ => unreachable!(),
}
self
}
}
#[derive(PartialEq, Clone, Debug)]
pub struct CurriedBuiltin {
pub func: DefaultFunction,
/// For use with subtract integer where we can flip the order of the arguments
/// if the second argument is a constant
pub children: Vec<CurriedTree>,
}
impl CurriedBuiltin {
pub fn merge_node_by_path(self, path: BuiltinArgs, scope: &Scope) -> CurriedBuiltin {
let mut children = self.children.clone();
match &path {
// First we just peak at the first arg to see if it was curried before
BuiltinArgs::TwoArgs(arg1, _) | BuiltinArgs::ThreeArgs(arg1, _, _) => {
if let Some(child) = children.iter_mut().find(|child| child.node() == arg1) {
// mutate child here so we don't have to recreate the whole tree
// We pass in the scope so we can get the common ancestor if the arg was curried before
*child = child.clone().merge_node_by_path(path, scope);
CurriedBuiltin {
func: self.func,
children,
}
} else {
// We found a new arg so we add it to the list of children
children.push(path.args_to_curried_tree(scope));
CurriedBuiltin {
func: self.func,
children,
}
}
}
BuiltinArgs::TwoArgsAnyOrder(arg1, arg2) => {
// This is the special case where we search by both args before adding a new child
if let Some(child) = children.iter_mut().find(|child| child.node() == arg1) {
*child = child.clone().merge_node_by_path(path, scope);
CurriedBuiltin {
func: self.func,
children,
}
} else if let Some(child) = children.iter_mut().find(|child| child.node() == arg2) {
// If we found a curried argument using arg2 then we flip the order of the args
// before merging into the tree
*child = child.clone().merge_node_by_path(
BuiltinArgs::TwoArgsAnyOrder(arg2.clone(), arg1.clone()),
scope,
);
CurriedBuiltin {
func: self.func,
children,
}
} else {
// We found a new arg so we add it to the list of children
children.push(path.args_to_curried_tree(scope));
CurriedBuiltin {
func: self.func,
children,
}
}
}
}
}
pub fn prune_single_occurrences(mut self) -> Self {
self.children = self
.children
.into_iter()
.filter(|child| child.multiple_occurrences())
.map(|child| child.prune_single_occurrences())
.collect_vec();
self
}
}
pub fn is_order_agnostic_builtin(func: DefaultFunction) -> bool {
matches!(
func,
DefaultFunction::AddInteger
| DefaultFunction::MultiplyInteger
| DefaultFunction::EqualsInteger
| DefaultFunction::EqualsByteString
| DefaultFunction::EqualsString
| DefaultFunction::EqualsData
| DefaultFunction::Bls12_381_G1_Equal
| DefaultFunction::Bls12_381_G2_Equal
| DefaultFunction::Bls12_381_G1_Add
| DefaultFunction::Bls12_381_G2_Add
)
}
/// For now all of the curry builtins are not forceable
pub fn can_curry_builtin(func: DefaultFunction) -> bool {
matches!(
func,
DefaultFunction::AddInteger
| DefaultFunction::SubtractInteger
| DefaultFunction::MultiplyInteger
| DefaultFunction::EqualsInteger
| DefaultFunction::EqualsByteString
| DefaultFunction::EqualsString
| DefaultFunction::EqualsData
| DefaultFunction::Bls12_381_G1_Equal
| DefaultFunction::Bls12_381_G2_Equal
| DefaultFunction::LessThanInteger
| DefaultFunction::LessThanEqualsInteger
| DefaultFunction::AppendByteString
| DefaultFunction::ConsByteString
| DefaultFunction::SliceByteString
| DefaultFunction::IndexByteString
| DefaultFunction::LessThanEqualsByteString
| DefaultFunction::LessThanByteString
| DefaultFunction::Bls12_381_G1_Add
| DefaultFunction::Bls12_381_G2_Add
)
}
impl Program<Name> { impl Program<Name> {
fn traverse_uplc_with( fn traverse_uplc_with(
self, self,
@ -530,6 +906,102 @@ impl Program<Name> {
} }
}) })
} }
// WIP
pub fn builtin_curry_reducer(self) -> Program<Name> {
let mut curried_terms = vec![];
let mut curry_applied_ids: Vec<usize> = vec![];
let a = self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term {
Term::Builtin(func) => {
if can_curry_builtin(*func) && arg_stack.len() == func.arity() {
let mut scope = scope.clone();
// Get upper scope of the function plus args
// So for example if the scope is [.., ARG, ARG, FUNC]
// we want to pop off the last 3 to get the scope right above the function applications
for _ in 0..func.arity() {
scope = scope.pop();
}
let is_order_agnostic = is_order_agnostic_builtin(*func);
// In the case of order agnostic builtins we want to sort the args by constant first
// This gives us the opportunity to curry constants that often pop up in the code
let builtin_args =
BuiltinArgs::args_from_arg_stack(arg_stack, is_order_agnostic);
// First we see if we have already curried this builtin before
if let Some(curried_builtin) = curried_terms
.iter_mut()
.find(|curried_term: &&mut CurriedBuiltin| curried_term.func == *func)
{
// We found it the builtin was curried before
// So now we merge the new args into the existing curried builtin
*curried_builtin = (*curried_builtin)
.clone()
.merge_node_by_path(builtin_args, &scope);
} else {
// Brand new buitlin so we add it to the list
curried_terms.push(CurriedBuiltin {
func: *func,
children: vec![builtin_args.args_to_curried_tree(&scope)],
});
}
}
}
Term::Constr { .. } => todo!(),
Term::Case { .. } => todo!(),
_ => {}
});
curried_terms = curried_terms
.into_iter()
.map(|func| func.prune_single_occurrences())
.filter(|func| !func.children.is_empty())
.collect_vec();
println!("CURRIED ARGS");
for (index, curried_term) in curried_terms.iter().enumerate() {
println!("index is {:#?}, term is {:#?}", index, curried_term);
}
// TODO: add function to generate names for curried_terms for generating vars to insert
a.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term {
Term::Builtin(func) => {
if can_curry_builtin(*func) {
let Some(curried_builtin) =
curried_terms.iter().find(|curry| curry.func == *func)
else {
return;
};
let arg_stack_ids = arg_stack.iter().map(|(id, _)| *id).collect_vec();
let builtin_args = BuiltinArgs::args_from_arg_stack(
arg_stack,
is_order_agnostic_builtin(*func),
);
if let Some(_) = curried_builtin.children.iter().find(|child| {
let x = (*child)
.clone()
.merge_node_by_path(builtin_args.clone(), scope);
*child == &x
}) {
curry_applied_ids.extend(arg_stack_ids);
} else {
}
}
}
Term::Apply { function, argument } => todo!(),
Term::Constr { .. } => todo!(),
Term::Case { .. } => todo!(),
_ => {}
})
}
} }
fn var_occurrences(term: &Term<Name>, search_for: Rc<Name>) -> usize { fn var_occurrences(term: &Term<Name>, search_for: Rc<Name>) -> usize {