Add case constr for applies greater than 2 optimization
This commit is contained in:
@@ -30,6 +30,17 @@ where
|
||||
Term::Delay(self.into())
|
||||
}
|
||||
|
||||
pub fn constr(tag: usize, fields: Vec<Term<T>>) -> Self {
|
||||
Term::Constr { tag, fields }
|
||||
}
|
||||
|
||||
pub fn case(self, branches: Vec<Term<T>>) -> Self {
|
||||
Term::Case {
|
||||
constr: self.into(),
|
||||
branches,
|
||||
}
|
||||
}
|
||||
|
||||
// Primitives
|
||||
pub fn integer(i: num_bigint::BigInt) -> Self {
|
||||
Term::Constant(Constant::Integer(i).into())
|
||||
|
||||
@@ -1362,6 +1362,62 @@ impl Term<Name> {
|
||||
}
|
||||
}
|
||||
|
||||
// IMPORTANT: RUNS ONE TIME AND ONLY ON THE LAST PASS
|
||||
fn case_constr_apply_reducer(
|
||||
&mut self,
|
||||
_id: Option<usize>,
|
||||
_arg_stack: Vec<Args>,
|
||||
_scope: &Scope,
|
||||
_context: &mut Context,
|
||||
) {
|
||||
let mut term = &mut std::mem::replace(self, Term::Error.force());
|
||||
|
||||
let mut arg_vec = vec![];
|
||||
|
||||
while let Term::Apply { function, argument } = term {
|
||||
arg_vec.push(Rc::make_mut(argument));
|
||||
|
||||
term = Rc::make_mut(function);
|
||||
}
|
||||
|
||||
arg_vec.reverse();
|
||||
|
||||
match term {
|
||||
Term::Case { constr, branches }
|
||||
if branches.len() == 1 && matches!(constr.as_ref(), Term::Constr { .. }) =>
|
||||
{
|
||||
let Term::Constr { fields, .. } = Rc::make_mut(constr) else {
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
for arg in arg_vec {
|
||||
fields.push(std::mem::replace(arg, Term::Error.force()));
|
||||
}
|
||||
|
||||
*self = std::mem::replace(term, Term::Error.force());
|
||||
}
|
||||
_ => {
|
||||
if arg_vec.len() > 3 {
|
||||
let mut fields = vec![];
|
||||
|
||||
for arg in arg_vec {
|
||||
fields.push(std::mem::replace(arg, Term::Error.force()));
|
||||
}
|
||||
|
||||
*self = Term::constr(0, fields)
|
||||
.case(vec![std::mem::replace(term, Term::Error.force())]);
|
||||
} else {
|
||||
for arg in arg_vec {
|
||||
*term = (std::mem::replace(term, Term::Error.force()))
|
||||
.apply(std::mem::replace(arg, Term::Error.force()));
|
||||
}
|
||||
|
||||
*self = std::mem::replace(term, Term::Error.force());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn identity_reducer(
|
||||
&mut self,
|
||||
_id: Option<usize>,
|
||||
@@ -1754,7 +1810,7 @@ impl Term<Name> {
|
||||
acc.apply(arg.pierce_no_inlines().clone())
|
||||
});
|
||||
|
||||
// Check above for is error safe
|
||||
// The check above is to make sure the program is error safe
|
||||
let eval_term: Term<Name> = Program {
|
||||
version: (1, 0, 0),
|
||||
term: applied_term,
|
||||
@@ -1881,7 +1937,7 @@ impl Program<Name> {
|
||||
context,
|
||||
)
|
||||
}
|
||||
// This one runs the optimizations that are only done a single time
|
||||
// This runs the optimizations that are only done a single time
|
||||
pub fn run_once_pass(self) -> Self {
|
||||
let program = self
|
||||
.traverse_uplc_with(false, &mut |id, term, _arg_stack, scope, context| {
|
||||
@@ -1946,38 +2002,36 @@ impl Program<Name> {
|
||||
|
||||
pub fn multi_pass(self) -> (Self, Context) {
|
||||
self.traverse_uplc_with(true, &mut |id, term, arg_stack, scope, context| {
|
||||
let mut changed;
|
||||
let false = term.lambda_reducer(id, arg_stack.clone(), scope, context) else {
|
||||
term.remove_inlined_ids(id, vec![], scope, context);
|
||||
return;
|
||||
};
|
||||
|
||||
changed = term.lambda_reducer(id, arg_stack.clone(), scope, context);
|
||||
if changed {
|
||||
let false = term.identity_reducer(id, arg_stack.clone(), scope, context) else {
|
||||
term.remove_inlined_ids(id, vec![], scope, context);
|
||||
return;
|
||||
}
|
||||
changed = term.identity_reducer(id, arg_stack.clone(), scope, context);
|
||||
if changed {
|
||||
};
|
||||
|
||||
let false = term.inline_reducer(id, arg_stack.clone(), scope, context) else {
|
||||
term.remove_inlined_ids(id, vec![], scope, context);
|
||||
return;
|
||||
}
|
||||
changed = term.inline_reducer(id, arg_stack.clone(), scope, context);
|
||||
if changed {
|
||||
};
|
||||
|
||||
let false = term.force_delay_reducer(id, arg_stack.clone(), scope, context) else {
|
||||
term.remove_inlined_ids(id, vec![], scope, context);
|
||||
return;
|
||||
}
|
||||
changed = term.force_delay_reducer(id, arg_stack.clone(), scope, context);
|
||||
if changed {
|
||||
};
|
||||
|
||||
let false = term.cast_data_reducer(id, arg_stack.clone(), scope, context) else {
|
||||
term.remove_inlined_ids(id, vec![], scope, context);
|
||||
return;
|
||||
}
|
||||
changed = term.cast_data_reducer(id, arg_stack.clone(), scope, context);
|
||||
if changed {
|
||||
};
|
||||
|
||||
let false = term.builtin_eval_reducer(id, arg_stack.clone(), scope, context) else {
|
||||
term.remove_inlined_ids(id, vec![], scope, context);
|
||||
return;
|
||||
}
|
||||
changed = term.builtin_eval_reducer(id, arg_stack.clone(), scope, context);
|
||||
if changed {
|
||||
term.remove_inlined_ids(id, vec![], scope, context);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
term.convert_arithmetic_ops(id, arg_stack, scope, context);
|
||||
term.flip_constants(id, vec![], scope, context);
|
||||
term.remove_inlined_ids(id, vec![], scope, context);
|
||||
@@ -2000,6 +2054,7 @@ impl Program<Name> {
|
||||
pub fn clean_up(self) -> Self {
|
||||
self.traverse_uplc_with(true, &mut |id, term, _arg_stack, scope, context| {
|
||||
term.remove_no_inlines(id, vec![], scope, context);
|
||||
term.case_constr_apply_reducer(id, vec![], scope, context);
|
||||
})
|
||||
.0
|
||||
}
|
||||
@@ -2043,7 +2098,6 @@ impl Program<Name> {
|
||||
) {
|
||||
// We found it the builtin was curried before
|
||||
// So now we merge the new args into the existing curried builtin
|
||||
|
||||
let curried_builtin = curried_terms.swap_remove(index);
|
||||
|
||||
let curried_builtin =
|
||||
@@ -3403,4 +3457,63 @@ mod tests {
|
||||
|
||||
compare_optimization(expected, program, |p| p.builtin_curry_reducer());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn case_constr_apply_test_1() {
|
||||
let program: Program<Name> = Program {
|
||||
version: (1, 1, 0),
|
||||
term: Term::add_integer()
|
||||
.apply(Term::integer(0.into()))
|
||||
.apply(Term::integer(0.into()))
|
||||
.apply(Term::integer(0.into()))
|
||||
.apply(Term::integer(0.into()))
|
||||
.apply(Term::integer(0.into()))
|
||||
.apply(Term::integer(0.into())),
|
||||
};
|
||||
|
||||
let expected = Program {
|
||||
version: (1, 1, 0),
|
||||
term: Term::constr(
|
||||
0,
|
||||
vec![
|
||||
Term::integer(0.into()),
|
||||
Term::integer(0.into()),
|
||||
Term::integer(0.into()),
|
||||
Term::integer(0.into()),
|
||||
Term::integer(0.into()),
|
||||
Term::integer(0.into()),
|
||||
],
|
||||
)
|
||||
.case(vec![Term::add_integer()]),
|
||||
};
|
||||
|
||||
compare_optimization(expected, program, |p| {
|
||||
p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| {
|
||||
term.case_constr_apply_reducer(id, arg_stack, scope, context);
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn case_constr_apply_test_2() {
|
||||
let program: Program<Name> = Program {
|
||||
version: (1, 1, 0),
|
||||
term: Term::add_integer()
|
||||
.apply(Term::integer(0.into()))
|
||||
.apply(Term::integer(0.into())),
|
||||
};
|
||||
|
||||
let expected = Program {
|
||||
version: (1, 1, 0),
|
||||
term: Term::add_integer()
|
||||
.apply(Term::integer(0.into()))
|
||||
.apply(Term::integer(0.into())),
|
||||
};
|
||||
|
||||
compare_optimization(expected, program, |p| {
|
||||
p.run_one_opt(true, &mut |id, term, arg_stack, scope, context| {
|
||||
term.case_constr_apply_reducer(id, arg_stack, scope, context);
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user