Add case constr for applies greater than 2 optimization

This commit is contained in:
microproofs
2025-01-09 17:45:33 +07:00
parent c1ed0dcbb5
commit 33392f1532
3 changed files with 150 additions and 26 deletions

View File

@@ -30,6 +30,17 @@ where
Term::Delay(self.into())
}
pub fn constr(tag: usize, fields: Vec<Term<T>>) -> Self {
Term::Constr { tag, fields }
}
pub fn case(self, branches: Vec<Term<T>>) -> Self {
Term::Case {
constr: self.into(),
branches,
}
}
// Primitives
pub fn integer(i: num_bigint::BigInt) -> Self {
Term::Constant(Constant::Integer(i).into())

View File

@@ -1362,6 +1362,62 @@ impl Term<Name> {
}
}
// IMPORTANT: RUNS ONE TIME AND ONLY ON THE LAST PASS
fn case_constr_apply_reducer(
&mut self,
_id: Option<usize>,
_arg_stack: Vec<Args>,
_scope: &Scope,
_context: &mut Context,
) {
let mut term = &mut std::mem::replace(self, Term::Error.force());
let mut arg_vec = vec![];
while let Term::Apply { function, argument } = term {
arg_vec.push(Rc::make_mut(argument));
term = Rc::make_mut(function);
}
arg_vec.reverse();
match term {
Term::Case { constr, branches }
if branches.len() == 1 && matches!(constr.as_ref(), Term::Constr { .. }) =>
{
let Term::Constr { fields, .. } = Rc::make_mut(constr) else {
unreachable!();
};
for arg in arg_vec {
fields.push(std::mem::replace(arg, Term::Error.force()));
}
*self = std::mem::replace(term, Term::Error.force());
}
_ => {
if arg_vec.len() > 3 {
let mut fields = vec![];
for arg in arg_vec {
fields.push(std::mem::replace(arg, Term::Error.force()));
}
*self = Term::constr(0, fields)
.case(vec![std::mem::replace(term, Term::Error.force())]);
} else {
for arg in arg_vec {
*term = (std::mem::replace(term, Term::Error.force()))
.apply(std::mem::replace(arg, Term::Error.force()));
}
*self = std::mem::replace(term, Term::Error.force());
}
}
}
}
fn identity_reducer(
&mut self,
_id: Option<usize>,
@@ -1754,7 +1810,7 @@ impl Term<Name> {
acc.apply(arg.pierce_no_inlines().clone())
});
// Check above for is error safe
// The check above is to make sure the program is error safe
let eval_term: Term<Name> = Program {
version: (1, 0, 0),
term: applied_term,
@@ -1881,7 +1937,7 @@ impl Program<Name> {
context,
)
}
// This one runs the optimizations that are only done a single time
// This runs the optimizations that are only done a single time
pub fn run_once_pass(self) -> Self {
let program = self
.traverse_uplc_with(false, &mut |id, term, _arg_stack, scope, context| {
@@ -1946,38 +2002,36 @@ impl Program<Name> {
pub fn multi_pass(self) -> (Self, Context) {
self.traverse_uplc_with(true, &mut |id, term, arg_stack, scope, context| {
let mut changed;
let false = term.lambda_reducer(id, arg_stack.clone(), scope, context) else {
term.remove_inlined_ids(id, vec![], scope, context);
return;
};
changed = term.lambda_reducer(id, arg_stack.clone(), scope, context);
if changed {
let false = term.identity_reducer(id, arg_stack.clone(), scope, context) else {
term.remove_inlined_ids(id, vec![], scope, context);
return;
}
changed = term.identity_reducer(id, arg_stack.clone(), scope, context);
if changed {
};
let false = term.inline_reducer(id, arg_stack.clone(), scope, context) else {
term.remove_inlined_ids(id, vec![], scope, context);
return;
}
changed = term.inline_reducer(id, arg_stack.clone(), scope, context);
if changed {
};
let false = term.force_delay_reducer(id, arg_stack.clone(), scope, context) else {
term.remove_inlined_ids(id, vec![], scope, context);
return;
}
changed = term.force_delay_reducer(id, arg_stack.clone(), scope, context);
if changed {
};
let false = term.cast_data_reducer(id, arg_stack.clone(), scope, context) else {
term.remove_inlined_ids(id, vec![], scope, context);
return;
}
changed = term.cast_data_reducer(id, arg_stack.clone(), scope, context);
if changed {
};
let false = term.builtin_eval_reducer(id, arg_stack.clone(), scope, context) else {
term.remove_inlined_ids(id, vec![], scope, context);
return;
}
changed = term.builtin_eval_reducer(id, arg_stack.clone(), scope, context);
if changed {
term.remove_inlined_ids(id, vec![], scope, context);
return;
}
};
term.convert_arithmetic_ops(id, arg_stack, scope, context);
term.flip_constants(id, vec![], scope, context);
term.remove_inlined_ids(id, vec![], scope, context);
@@ -2000,6 +2054,7 @@ impl Program<Name> {
pub fn clean_up(self) -> Self {
self.traverse_uplc_with(true, &mut |id, term, _arg_stack, scope, context| {
term.remove_no_inlines(id, vec![], scope, context);
term.case_constr_apply_reducer(id, vec![], scope, context);
})
.0
}
@@ -2043,7 +2098,6 @@ impl Program<Name> {
) {
// We found it the builtin was curried before
// So now we merge the new args into the existing curried builtin
let curried_builtin = curried_terms.swap_remove(index);
let curried_builtin =
@@ -3403,4 +3457,63 @@ mod tests {
compare_optimization(expected, program, |p| p.builtin_curry_reducer());
}
#[test]
fn case_constr_apply_test_1() {
let program: Program<Name> = Program {
version: (1, 1, 0),
term: Term::add_integer()
.apply(Term::integer(0.into()))
.apply(Term::integer(0.into()))
.apply(Term::integer(0.into()))
.apply(Term::integer(0.into()))
.apply(Term::integer(0.into()))
.apply(Term::integer(0.into())),
};
let expected = Program {
version: (1, 1, 0),
term: Term::constr(
0,
vec![
Term::integer(0.into()),
Term::integer(0.into()),
Term::integer(0.into()),
Term::integer(0.into()),
Term::integer(0.into()),
Term::integer(0.into()),
],
)
.case(vec![Term::add_integer()]),
};
compare_optimization(expected, program, |p| {
p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| {
term.case_constr_apply_reducer(id, arg_stack, scope, context);
})
});
}
#[test]
fn case_constr_apply_test_2() {
let program: Program<Name> = Program {
version: (1, 1, 0),
term: Term::add_integer()
.apply(Term::integer(0.into()))
.apply(Term::integer(0.into())),
};
let expected = Program {
version: (1, 1, 0),
term: Term::add_integer()
.apply(Term::integer(0.into()))
.apply(Term::integer(0.into())),
};
compare_optimization(expected, program, |p| {
p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| {
term.case_constr_apply_reducer(id, arg_stack, scope, context);
})
});
}
}