New optimization to split independent lam function applications to enable case constr to optimize further

This commit is contained in:
microproofs 2025-01-11 17:04:26 +07:00
parent d559e384ec
commit 09ddec6b41
No known key found for this signature in database
GPG Key ID: 14F93C84DE6AFD17
5 changed files with 167 additions and 27 deletions

View File

@ -3712,7 +3712,7 @@ impl<'a> CodeGenerator<'a> {
interner.program(&mut program);
let eval_program: Program<NamedDeBruijn> =
program.clean_up(false).try_into().unwrap();
program.clean_up_no_inlines().try_into().unwrap();
Some(
eval_program
@ -3822,7 +3822,7 @@ impl<'a> CodeGenerator<'a> {
interner.program(&mut program);
let eval_program: Program<NamedDeBruijn> =
program.clean_up(false).try_into().unwrap();
program.clean_up_no_inlines().try_into().unwrap();
let evaluated_term: Term<NamedDeBruijn> = eval_program
.eval(ExBudget::default())
@ -4364,7 +4364,7 @@ impl<'a> CodeGenerator<'a> {
interner.program(&mut program);
let eval_program: Program<NamedDeBruijn> =
program.clean_up(false).try_into().unwrap();
program.clean_up_no_inlines().try_into().unwrap();
let evaluated_term: Term<NamedDeBruijn> = eval_program
.eval(ExBudget::default())
@ -4389,7 +4389,7 @@ impl<'a> CodeGenerator<'a> {
interner.program(&mut program);
let eval_program: Program<NamedDeBruijn> =
program.clean_up(false).try_into().unwrap();
program.clean_up_no_inlines().try_into().unwrap();
let evaluated_term: Term<NamedDeBruijn> = eval_program
.eval(ExBudget::default())
@ -4802,7 +4802,7 @@ impl<'a> CodeGenerator<'a> {
interner.program(&mut program);
let eval_program: Program<NamedDeBruijn> =
program.clean_up(false).try_into().unwrap();
program.clean_up_no_inlines().try_into().unwrap();
let evaluated_term: Term<NamedDeBruijn> = eval_program
.eval(ExBudget::default())

View File

@ -1,7 +1,6 @@
---
source: crates/aiken-project/src/export.rs
description: "Code:\n\npub fn add(a: Int, b: Int) -> Int {\n a + b\n}\n"
snapshot_kind: text
---
{
"name": "test_module.add",
@ -25,8 +24,8 @@ snapshot_kind: text
"$ref": "#/definitions/Int"
}
},
"compiledCode": "500101002322337000046eb4004dd68009",
"hash": "b8374597a772cef80d891b7f6a03588e10cc19b780251228ba4ce9c6",
"compiledCode": "500101002232337000026eb4008dd68011",
"hash": "e5951afb3263ef11acc0b4c88cd5f5b30b8621ce63fe024b3ea2bec8",
"definitions": {
"Int": {
"dataType": "integer"

View File

@ -24,8 +24,8 @@ description: "Code:\n\npub type Foo<a> {\n Empty\n Bar(a, Foo<a>)\n}\n\npub fn
"$ref": "#/definitions/Int"
}
},
"compiledCode": "5901870101009800aba2aba1aba0aab9eaab9dab9a488888888c8c8c8c8c966002600860126ea8006264b30013005300a375400314800226466e00dd698070009980226103d8798000300e300f001300b37540028048c030c034016264b30013370e900118051baa0018991919b80337006eb4c03c008dd6980780099802980798080011807980800098061baa002300b3754005132337006eb4c038004cc010c038c03c00530103d8798000300b37540048048c030c0340150081805802180080091119192cc004c018c02cdd5000c4c966002600e60186ea80062900044c8cdc01bad30100019800803d300103d879800098081808800a00e300d37540028058c038c03c00a264b30013370e900118061baa0018991919b80337006eb4c044008dd69808800cc00402260226024005301130120014020601c6ea8008c034dd500144c8cdc01bad30100019800803cc040c044006980103d8798000401c601a6ea800900b180718078012014300d0013300b0023300b0014bd701b8748000cc018008cc0180052f5c01",
"hash": "247535960781372d3b2097595ebd748bd61be7c8f2f264e460e095b3",
"compiledCode": "590186010100229800aba2aba1aba0aab9eaab9dab9a9b874800122222223322332259800980298039baa0018992cc004c018c020dd5000c5200089919b80375a60180026600898103d8798000300c300d001300937540028038c028c02c012264b30013370e900118041baa0018999119b80337006eb4c034008dd6980680099802980698070011806980700098049baa00230093754003132337006eb4c030004cc010c030c03400530103d8798000300937540048038c028c02c01100618008009804001198028049980280425eb80888c8c966002600c60106ea8006264b300130073009375400314800226466e00dd69806800cc00401e98103d879800098069807000a00e300a37540028040c02cc03000a264b30013370e900118049baa0018999119b80337006eb4c038008dd69807000cc004022601c601e005300e300f001402060146ea8008c028dd5000c4c8cdc01bad300d0019800803cc034c038006980103d8798000401c60146ea800900818059806001200e300a00133008002330080014bd701",
"hash": "dc9b9c2bbcfb1cb422534ed1c4d04f2e2b9b57a0a498175d055f83e8",
"definitions": {
"Int": {
"dataType": "integer"

View File

@ -38,5 +38,5 @@ pub fn aiken_optimize_and_intern(program: Program<Name>) -> Program<Name> {
}
}
prog.clean_up(true)
prog.clean_up_no_inlines().afterwards()
}

View File

@ -1409,6 +1409,135 @@ impl Term<Name> {
}
}
// The ultimate function when used in conjunction with case_constr_apply
// This splits [lam fun_name [lam fun_name2 rest ..] ..] into
// [[lam fun_name lam fun_name2 rest ..]..] thus
// allowing for some crazy gains from cast_constr_apply_reducer
fn split_body_lambda(&mut self) {
let mut arg_stack = vec![];
let mut current_term = &mut std::mem::replace(self, Term::Error.force());
let mut unsat_lams = vec![];
let mut function_groups: Vec<Vec<(Rc<Name>, Term<Name>)>> = vec![vec![]];
loop {
match current_term {
Term::Apply { function, argument } => {
current_term = Rc::make_mut(function);
let arg = Rc::make_mut(argument);
arg.split_body_lambda();
arg_stack.push(std::mem::replace(arg, Term::Error.force()));
}
Term::Lambda {
parameter_name,
body,
} => {
current_term = Rc::make_mut(body);
if let Some(arg) = arg_stack.pop() {
let names = arg.get_var_names();
let func = (parameter_name.clone(), arg);
if let Some((position, _)) =
function_groups.iter().enumerate().rfind(|named_functions| {
named_functions
.1
.iter()
.any(|(name, _)| names.contains(name))
})
{
let insert_position = position + 1;
if insert_position == function_groups.len() {
function_groups.push(vec![func]);
} else {
function_groups[insert_position].push(func);
}
} else {
function_groups[0].push(func);
}
} else {
unsat_lams.push(parameter_name.clone());
}
}
Term::Delay(term) | Term::Force(term) => {
Rc::make_mut(term).split_body_lambda();
break;
}
Term::Case { .. } => todo!(),
Term::Constr { .. } => todo!(),
_ => break,
}
}
let term_to_build_on = std::mem::replace(current_term, Term::Error.force());
// Replace args that weren't consumed
let term = arg_stack
.into_iter()
.rfold(term_to_build_on, |term, arg| term.apply(arg));
let term = function_groups.into_iter().rfold(term, |term, group| {
let term = group.iter().rfold(term, |term, (name, _)| Term::Lambda {
parameter_name: name.clone(),
body: term.into(),
});
group
.into_iter()
.fold(term, |term, (_, arg)| term.apply(arg))
});
let term = unsat_lams
.into_iter()
.rfold(term, |term, name| Term::Lambda {
parameter_name: name.clone(),
body: term.into(),
});
*self = term;
}
fn get_var_names(&self) -> Vec<Rc<Name>> {
let mut names = vec![];
let mut term = self;
loop {
match term {
Term::Apply { function, argument } => {
let arg_names = argument.get_var_names();
names.extend(arg_names);
term = function;
}
Term::Var(name) => {
names.push(name.clone());
break;
}
Term::Delay(t) => {
term = t;
}
Term::Lambda { body, .. } => {
term = body;
}
Term::Constant(_) | Term::Error | Term::Builtin(_) => {
break;
}
Term::Force(t) => {
term = t;
}
Term::Constr { .. } => todo!(),
Term::Case { .. } => todo!(),
}
}
names
}
// IMPORTANT: RUNS ONE TIME AND ONLY ON THE LAST PASS
fn case_constr_apply_reducer(
&mut self,
@ -2079,14 +2208,14 @@ impl Program<Name> {
}
// This runs the optimizations that are only done a single time
pub fn run_once_pass(self) -> Self {
let program = self
// First pass is necessary to ensure fst_pair and snd_pair are inlined before
// builtin_force_reducer is run
let (program, context) = self
.traverse_uplc_with(false, &mut |id, term, _arg_stack, scope, context| {
term.inline_constr_ops(id, vec![], scope, context);
})
.0;
let (program, context) =
program.traverse_uplc_with(false, &mut |id, term, arg_stack, scope, context| {
.0
.traverse_uplc_with(false, &mut |id, term, arg_stack, scope, context| {
term.bls381_compressor(id, vec![], scope, context);
term.builtin_force_reducer(id, arg_stack, scope, context);
term.remove_inlined_ids(id, vec![], scope, context);
@ -2193,20 +2322,26 @@ impl Program<Name> {
program
}
pub fn clean_up(self, case: bool) -> Self {
let (mut program, context) = self
.traverse_uplc_with(true, &mut |id, term, _arg_stack, scope, context| {
term.remove_no_inlines(id, vec![], scope, context);
})
.0
.traverse_uplc_with(true, &mut |id, term, arg_stack, scope, context| {
term.write_bits_convert_arg(id, arg_stack, scope, context);
pub fn clean_up_no_inlines(self) -> Self {
self.traverse_uplc_with(true, &mut |id, term, _arg_stack, scope, context| {
term.remove_no_inlines(id, vec![], scope, context);
})
.0
}
if case {
term.case_constr_apply_reducer(id, vec![], scope, context);
}
pub fn afterwards(self) -> Self {
let (mut program, context) =
self.traverse_uplc_with(true, &mut |id, term, arg_stack, scope, context| {
term.write_bits_convert_arg(id, arg_stack, scope, context);
});
program = program
.split_body_lambda_reducer()
.traverse_uplc_with(true, &mut |id, term, _arg_stack, scope, context| {
term.case_constr_apply_reducer(id, vec![], scope, context);
})
.0;
if context.write_bits_convert {
program.term = program.term.data_list_to_integer_list();
}
@ -2451,6 +2586,12 @@ impl Program<Name> {
step_b
}
pub fn split_body_lambda_reducer(mut self) -> Self {
self.term.split_body_lambda();
self
}
}
fn id_vec_function_to_var(func_name: &str, id_vec: &[usize]) -> String {