feat: implement curried tree pruning

This commit is contained in:
microproofs 2023-12-11 15:25:09 -05:00 committed by Kasey
parent 51079b8590
commit 058a190294
1 changed files with 25 additions and 14 deletions

View File

@ -202,15 +202,25 @@ impl CurriedTree {
}
}
pub fn prune_single_occurrences(self) -> Self {
match self {
CurriedTree::Branch { node, children, .. } => {
// Implement your logic here for the Branch variant
// For example, you might want to prune children here
todo!()
pub fn prune_single_occurrences(mut self) -> Self {
match &mut self {
CurriedTree::Branch { children, .. } => {
*children = children
.clone()
.into_iter()
.filter(|child| child.multiple_occurrences())
.map(|child| {
if matches!(child, CurriedTree::Branch { .. }) {
child.prune_single_occurrences()
} else {
child
}
})
.collect_vec();
}
_ => unreachable!(),
}
self
}
}
@ -219,7 +229,6 @@ pub struct CurriedBuiltin {
pub func: DefaultFunction,
/// For use with subtract integer where we can flip the order of the arguments
/// if the second argument is a constant
pub flipped: bool,
pub children: Vec<CurriedTree>,
}
@ -236,7 +245,6 @@ impl CurriedBuiltin {
*child = child.clone().merge_node_by_path(path, scope);
CurriedBuiltin {
func: self.func,
flipped: self.flipped,
children,
}
} else {
@ -244,7 +252,6 @@ impl CurriedBuiltin {
children.push(path.args_to_curried_tree(scope));
CurriedBuiltin {
func: self.func,
flipped: self.flipped,
children,
}
}
@ -255,7 +262,6 @@ impl CurriedBuiltin {
*child = child.clone().merge_node_by_path(path, scope);
CurriedBuiltin {
func: self.func,
flipped: self.flipped,
children,
}
} else if let Some(child) = children.iter_mut().find(|child| child.node() == arg2) {
@ -268,7 +274,6 @@ impl CurriedBuiltin {
CurriedBuiltin {
func: self.func,
flipped: self.flipped,
children,
}
} else {
@ -276,7 +281,6 @@ impl CurriedBuiltin {
children.push(path.args_to_curried_tree(scope));
CurriedBuiltin {
func: self.func,
flipped: self.flipped,
children,
}
}
@ -396,6 +400,8 @@ pub fn can_curry_builtin(func: DefaultFunction) -> bool {
| DefaultFunction::ConsByteString
| DefaultFunction::SliceByteString
| DefaultFunction::IndexByteString
| DefaultFunction::LessThanEqualsByteString
| DefaultFunction::LessThanByteString
| DefaultFunction::Bls12_381_G1_Add
| DefaultFunction::Bls12_381_G2_Add
)
@ -933,7 +939,6 @@ impl Program<Name> {
// Brand new buitlin so we add it to the list
curried_terms.push(CurriedBuiltin {
func: *func,
flipped: false,
children: vec![builtin_args.args_to_curried_tree(&scope)],
});
}
@ -945,9 +950,15 @@ impl Program<Name> {
_ => {}
});
curried_terms = curried_terms
.into_iter()
.map(|func| func.prune_single_occurrences())
.filter(|func| !func.children.is_empty())
.collect_vec();
println!("CURRIED ARGS");
for (index, curried_term) in curried_terms.into_iter().enumerate() {
println!("index is {:#?}, term is {:#?}", index, curried_term,);
println!("index is {:#?}, term is {:#?}", index, curried_term);
}
a