start testing the first stage of currying builtins

This commit is contained in:
microproofs 2023-12-08 16:05:48 -05:00 committed by Kasey
parent 249581e1bc
commit 4015550f55
2 changed files with 50 additions and 8 deletions

View File

@ -26,6 +26,7 @@ pub fn aiken_optimize_and_intern(program: Program<Name>) -> Program<Name> {
.cast_data_reducer() .cast_data_reducer()
.lambda_reducer() .lambda_reducer()
.inline_reducer() .inline_reducer()
.builtin_curry_reducer()
.lambda_reducer() .lambda_reducer()
.inline_reducer() .inline_reducer()
} }

View File

@ -9,7 +9,7 @@ use crate::{
ast::{Constant, Data, Name, Program, Term, Type}, ast::{Constant, Data, Name, Program, Term, Type},
builtins::DefaultFunction, builtins::DefaultFunction,
}; };
#[derive(PartialEq, Clone)] #[derive(PartialEq, Clone, Debug)]
pub enum BuiltinArgs { pub enum BuiltinArgs {
TwoArgs(Term<Name>, Term<Name>), TwoArgs(Term<Name>, Term<Name>),
ThreeArgs(Term<Name>, Term<Name>, Term<Name>), ThreeArgs(Term<Name>, Term<Name>, Term<Name>),
@ -51,7 +51,7 @@ impl BuiltinArgs {
} }
} }
#[derive(PartialEq, Clone)] #[derive(PartialEq, Clone, Debug)]
pub enum CurriedTree { pub enum CurriedTree {
Branch { Branch {
node: Term<Name>, node: Term<Name>,
@ -81,6 +81,19 @@ impl CurriedTree {
} }
} }
pub fn multiple_occurrences(&self) -> bool {
match self {
CurriedTree::Branch {
multiple_occurrences,
..
} => *multiple_occurrences,
CurriedTree::Leaf {
multiple_occurrences,
..
} => *multiple_occurrences,
}
}
pub fn merge_node_by_path(self, path: BuiltinArgs, scope: &Scope) -> CurriedTree { pub fn merge_node_by_path(self, path: BuiltinArgs, scope: &Scope) -> CurriedTree {
match (self, path) { match (self, path) {
( (
@ -188,9 +201,20 @@ impl CurriedTree {
_ => unreachable!(), _ => unreachable!(),
} }
} }
pub fn prune_single_occurrences(self) -> Self {
match self {
CurriedTree::Branch { node, children, .. } => {
// Implement your logic here for the Branch variant
// For example, you might want to prune children here
todo!()
}
_ => unreachable!(),
}
}
} }
#[derive(PartialEq, Clone)] #[derive(PartialEq, Clone, Debug)]
pub struct CurriedBuiltin { pub struct CurriedBuiltin {
pub func: DefaultFunction, pub func: DefaultFunction,
/// For use with subtract integer where we can flip the order of the arguments /// For use with subtract integer where we can flip the order of the arguments
@ -259,15 +283,25 @@ impl CurriedBuiltin {
} }
} }
} }
pub fn prune_single_occurrences(mut self) -> Self {
self.children = self
.children
.into_iter()
.filter(|child| child.multiple_occurrences())
.map(|child| child.prune_single_occurrences())
.collect_vec();
self
}
} }
#[derive(Eq, Hash, PartialEq, Clone)] #[derive(Eq, Hash, PartialEq, Clone, Debug)]
pub enum ScopePath { pub enum ScopePath {
FUNC, FUNC,
ARG, ARG,
} }
#[derive(Eq, Hash, PartialEq, Clone)] #[derive(Eq, Hash, PartialEq, Clone, Debug)]
pub struct Scope { pub struct Scope {
scope: Vec<ScopePath>, scope: Vec<ScopePath>,
} }
@ -361,7 +395,6 @@ pub fn can_curry_builtin(func: DefaultFunction) -> bool {
| DefaultFunction::AppendByteString | DefaultFunction::AppendByteString
| DefaultFunction::ConsByteString | DefaultFunction::ConsByteString
| DefaultFunction::SliceByteString | DefaultFunction::SliceByteString
| DefaultFunction::LengthOfByteString
| DefaultFunction::IndexByteString | DefaultFunction::IndexByteString
| DefaultFunction::Bls12_381_G1_Add | DefaultFunction::Bls12_381_G1_Add
| DefaultFunction::Bls12_381_G2_Add | DefaultFunction::Bls12_381_G2_Add
@ -827,7 +860,7 @@ impl Program<Name> {
let mut curried_terms = vec![]; let mut curried_terms = vec![];
let mut curry_applied_ids: Vec<usize> = vec![]; let mut curry_applied_ids: Vec<usize> = vec![];
self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term { let a = self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term {
Term::Builtin(func) => { Term::Builtin(func) => {
if can_curry_builtin(*func) && arg_stack.len() == func.arity() { if can_curry_builtin(*func) && arg_stack.len() == func.arity() {
let mut scope = scope.clone(); let mut scope = scope.clone();
@ -878,6 +911,7 @@ impl Program<Name> {
ordered_arg_stack.next().unwrap(), ordered_arg_stack.next().unwrap(),
) )
} else { } else {
// println!("ARG STACK FOR FUNC {:#?}, {:#?}", ordered_arg_stack, func);
BuiltinArgs::ThreeArgs( BuiltinArgs::ThreeArgs(
ordered_arg_stack.next().unwrap(), ordered_arg_stack.next().unwrap(),
ordered_arg_stack.next().unwrap(), ordered_arg_stack.next().unwrap(),
@ -909,7 +943,14 @@ impl Program<Name> {
Term::Constr { .. } => todo!(), Term::Constr { .. } => todo!(),
Term::Case { .. } => todo!(), Term::Case { .. } => todo!(),
_ => {} _ => {}
}) });
println!("CURRIED ARGS");
for (index, curried_term) in curried_terms.into_iter().enumerate() {
println!("index is {:#?}, term is {:#?}", index, curried_term,);
}
a
} }
} }