feat: implement curried tree pruning

This commit is contained in:
microproofs 2023-12-11 15:25:09 -05:00 committed by Kasey
parent 51079b8590
commit 058a190294
1 changed files with 25 additions and 14 deletions

View File

@ -202,15 +202,25 @@ impl CurriedTree {
} }
} }
pub fn prune_single_occurrences(self) -> Self { pub fn prune_single_occurrences(mut self) -> Self {
match self { match &mut self {
CurriedTree::Branch { node, children, .. } => { CurriedTree::Branch { children, .. } => {
// Implement your logic here for the Branch variant *children = children
// For example, you might want to prune children here .clone()
todo!() .into_iter()
.filter(|child| child.multiple_occurrences())
.map(|child| {
if matches!(child, CurriedTree::Branch { .. }) {
child.prune_single_occurrences()
} else {
child
}
})
.collect_vec();
} }
_ => unreachable!(), _ => unreachable!(),
} }
self
} }
} }
@ -219,7 +229,6 @@ pub struct CurriedBuiltin {
pub func: DefaultFunction, pub func: DefaultFunction,
/// For use with subtract integer where we can flip the order of the arguments /// For use with subtract integer where we can flip the order of the arguments
/// if the second argument is a constant /// if the second argument is a constant
pub flipped: bool,
pub children: Vec<CurriedTree>, pub children: Vec<CurriedTree>,
} }
@ -236,7 +245,6 @@ impl CurriedBuiltin {
*child = child.clone().merge_node_by_path(path, scope); *child = child.clone().merge_node_by_path(path, scope);
CurriedBuiltin { CurriedBuiltin {
func: self.func, func: self.func,
flipped: self.flipped,
children, children,
} }
} else { } else {
@ -244,7 +252,6 @@ impl CurriedBuiltin {
children.push(path.args_to_curried_tree(scope)); children.push(path.args_to_curried_tree(scope));
CurriedBuiltin { CurriedBuiltin {
func: self.func, func: self.func,
flipped: self.flipped,
children, children,
} }
} }
@ -255,7 +262,6 @@ impl CurriedBuiltin {
*child = child.clone().merge_node_by_path(path, scope); *child = child.clone().merge_node_by_path(path, scope);
CurriedBuiltin { CurriedBuiltin {
func: self.func, func: self.func,
flipped: self.flipped,
children, children,
} }
} else if let Some(child) = children.iter_mut().find(|child| child.node() == arg2) { } else if let Some(child) = children.iter_mut().find(|child| child.node() == arg2) {
@ -268,7 +274,6 @@ impl CurriedBuiltin {
CurriedBuiltin { CurriedBuiltin {
func: self.func, func: self.func,
flipped: self.flipped,
children, children,
} }
} else { } else {
@ -276,7 +281,6 @@ impl CurriedBuiltin {
children.push(path.args_to_curried_tree(scope)); children.push(path.args_to_curried_tree(scope));
CurriedBuiltin { CurriedBuiltin {
func: self.func, func: self.func,
flipped: self.flipped,
children, children,
} }
} }
@ -396,6 +400,8 @@ pub fn can_curry_builtin(func: DefaultFunction) -> bool {
| DefaultFunction::ConsByteString | DefaultFunction::ConsByteString
| DefaultFunction::SliceByteString | DefaultFunction::SliceByteString
| DefaultFunction::IndexByteString | DefaultFunction::IndexByteString
| DefaultFunction::LessThanEqualsByteString
| DefaultFunction::LessThanByteString
| DefaultFunction::Bls12_381_G1_Add | DefaultFunction::Bls12_381_G1_Add
| DefaultFunction::Bls12_381_G2_Add | DefaultFunction::Bls12_381_G2_Add
) )
@ -933,7 +939,6 @@ impl Program<Name> {
// Brand new buitlin so we add it to the list // Brand new buitlin so we add it to the list
curried_terms.push(CurriedBuiltin { curried_terms.push(CurriedBuiltin {
func: *func, func: *func,
flipped: false,
children: vec![builtin_args.args_to_curried_tree(&scope)], children: vec![builtin_args.args_to_curried_tree(&scope)],
}); });
} }
@ -945,9 +950,15 @@ impl Program<Name> {
_ => {} _ => {}
}); });
curried_terms = curried_terms
.into_iter()
.map(|func| func.prune_single_occurrences())
.filter(|func| !func.children.is_empty())
.collect_vec();
println!("CURRIED ARGS"); println!("CURRIED ARGS");
for (index, curried_term) in curried_terms.into_iter().enumerate() { for (index, curried_term) in curried_terms.into_iter().enumerate() {
println!("index is {:#?}, term is {:#?}", index, curried_term,); println!("index is {:#?}, term is {:#?}", index, curried_term);
} }
a a