update inliner to handle no_inline functions

This commit is contained in:
microproofs 2024-03-02 13:39:21 -05:00 committed by Kasey
parent 4e928f39db
commit 06ca22c26a
3 changed files with 140 additions and 39 deletions

View File

@ -3933,7 +3933,8 @@ impl<'a> CodeGenerator<'a> {
interner.program(&mut program); interner.program(&mut program);
let eval_program: Program<NamedDeBruijn> = program.try_into().unwrap(); let eval_program: Program<NamedDeBruijn> =
program.remove_no_inlines().try_into().unwrap();
let evaluated_term: Term<NamedDeBruijn> = let evaluated_term: Term<NamedDeBruijn> =
eval_program.eval(ExBudget::default()).result().unwrap(); eval_program.eval(ExBudget::default()).result().unwrap();
@ -4132,9 +4133,9 @@ impl<'a> CodeGenerator<'a> {
} }
if params.is_empty() { if params.is_empty() {
Some(term.delay()) Some(term.lambda(NO_INLINE).delay())
} else { } else {
Some(term) Some(term.lambda(NO_INLINE))
} }
} }
Air::Call { count, .. } => { Air::Call { count, .. } => {
@ -4187,7 +4188,8 @@ impl<'a> CodeGenerator<'a> {
interner.program(&mut program); interner.program(&mut program);
let eval_program: Program<NamedDeBruijn> = program.try_into().unwrap(); let eval_program: Program<NamedDeBruijn> =
program.remove_no_inlines().try_into().unwrap();
let evaluated_term: Term<NamedDeBruijn> = let evaluated_term: Term<NamedDeBruijn> =
eval_program.eval(ExBudget::max()).result().unwrap(); eval_program.eval(ExBudget::max()).result().unwrap();
@ -4535,7 +4537,8 @@ impl<'a> CodeGenerator<'a> {
interner.program(&mut program); interner.program(&mut program);
let eval_program: Program<NamedDeBruijn> = program.try_into().unwrap(); let eval_program: Program<NamedDeBruijn> =
program.remove_no_inlines().try_into().unwrap();
let evaluated_term: Term<NamedDeBruijn> = let evaluated_term: Term<NamedDeBruijn> =
eval_program.eval(ExBudget::default()).result().unwrap(); eval_program.eval(ExBudget::default()).result().unwrap();
@ -4559,7 +4562,8 @@ impl<'a> CodeGenerator<'a> {
interner.program(&mut program); interner.program(&mut program);
let eval_program: Program<NamedDeBruijn> = program.try_into().unwrap(); let eval_program: Program<NamedDeBruijn> =
program.remove_no_inlines().try_into().unwrap();
let evaluated_term: Term<NamedDeBruijn> = let evaluated_term: Term<NamedDeBruijn> =
eval_program.eval(ExBudget::default()).result().unwrap(); eval_program.eval(ExBudget::default()).result().unwrap();
@ -4964,7 +4968,8 @@ impl<'a> CodeGenerator<'a> {
interner.program(&mut program); interner.program(&mut program);
let eval_program: Program<NamedDeBruijn> = program.try_into().unwrap(); let eval_program: Program<NamedDeBruijn> =
program.remove_no_inlines().try_into().unwrap();
let evaluated_term: Term<NamedDeBruijn> = let evaluated_term: Term<NamedDeBruijn> =
eval_program.eval(ExBudget::default()).result().unwrap(); eval_program.eval(ExBudget::default()).result().unwrap();

View File

@ -8,6 +8,7 @@ pub fn aiken_optimize_and_intern(program: Program<Name>) -> Program<Name> {
.builtin_force_reducer() .builtin_force_reducer()
.lambda_reducer() .lambda_reducer()
.inline_reducer() .inline_reducer()
.identity_reducer()
.lambda_reducer() .lambda_reducer()
.inline_reducer() .inline_reducer()
.force_delay_reducer() .force_delay_reducer()
@ -16,7 +17,9 @@ pub fn aiken_optimize_and_intern(program: Program<Name>) -> Program<Name> {
.builtin_curry_reducer() .builtin_curry_reducer()
.lambda_reducer() .lambda_reducer()
.inline_reducer() .inline_reducer()
.identity_reducer()
.builtin_curry_reducer() .builtin_curry_reducer()
.lambda_reducer() .lambda_reducer()
.inline_reducer() .inline_reducer()
.remove_no_inlines()
} }

View File

@ -93,6 +93,7 @@ pub struct VarLookup {
found: bool, found: bool,
occurrences: isize, occurrences: isize,
delays: isize, delays: isize,
no_inline: bool,
} }
impl VarLookup { impl VarLookup {
@ -101,6 +102,7 @@ impl VarLookup {
found: false, found: false,
occurrences: 0, occurrences: 0,
delays: 0, delays: 0,
no_inline: false,
} }
} }
@ -109,6 +111,7 @@ impl VarLookup {
found: true, found: true,
occurrences: 1, occurrences: 1,
delays: 0, delays: 0,
no_inline: false,
} }
} }
@ -117,6 +120,7 @@ impl VarLookup {
found: self.found || other.found, found: self.found || other.found,
occurrences: self.occurrences + other.occurrences, occurrences: self.occurrences + other.occurrences,
delays: self.delays + other.delays, delays: self.delays + other.delays,
no_inline: self.no_inline || other.no_inline,
} }
} }
@ -126,6 +130,20 @@ impl VarLookup {
found: self.found, found: self.found,
occurrences: self.occurrences, occurrences: self.occurrences,
delays: self.delays + delay_amount, delays: self.delays + delay_amount,
no_inline: self.no_inline,
}
} else {
self
}
}
pub fn no_inline_if_found(self) -> Self {
if self.found {
Self {
found: self.found,
occurrences: self.occurrences,
delays: self.delays,
no_inline: true,
} }
} else { } else {
self self
@ -697,11 +715,11 @@ impl Term<Name> {
match self { match self {
Term::Apply { function, argument } => { Term::Apply { function, argument } => {
let arg = Rc::make_mut(argument); let arg = Rc::make_mut(argument);
let argument_arg_stack = vec![];
Self::traverse_uplc_with_helper( Self::traverse_uplc_with_helper(
arg, arg,
&scope.push(ScopePath::ARG), &scope.push(ScopePath::ARG),
argument_arg_stack, vec![],
id_gen, id_gen,
with, with,
); );
@ -729,10 +747,19 @@ impl Term<Name> {
Self::traverse_uplc_with_helper(d, scope, arg_stack, id_gen, with); Self::traverse_uplc_with_helper(d, scope, arg_stack, id_gen, with);
with(None, self, vec![], scope); with(None, self, vec![], scope);
} }
Term::Lambda { body, .. } => { Term::Lambda {
body,
parameter_name,
} => {
let body = Rc::make_mut(body); let body = Rc::make_mut(body);
// Lambda pops one item off the arg stack. If there is no item then it is a unsaturated lambda // Lambda pops one item off the arg stack. If there is no item then it is a unsaturated lambda
let args = arg_stack.pop().map(|arg| vec![arg]).unwrap_or_default(); // We also skip NO_INLINE lambdas since those are placeholder lambdas created by codegen
let args = if parameter_name.text == NO_INLINE {
vec![]
} else {
arg_stack.pop().map(|arg| vec![arg]).unwrap_or_default()
};
// Pass in either one or zero args. // Pass in either one or zero args.
Self::traverse_uplc_with_helper(body, scope, arg_stack, id_gen, with); Self::traverse_uplc_with_helper(body, scope, arg_stack, id_gen, with);
@ -883,10 +910,8 @@ impl Program<Name> {
Program::<Name>::try_from(program).unwrap() Program::<Name>::try_from(program).unwrap()
} }
pub fn inline_reducer(self) -> Self { pub fn identity_reducer(self) -> Self {
let mut lambda_applied_ids = vec![];
let mut identity_applied_ids = vec![]; let mut identity_applied_ids = vec![];
// TODO: Remove extra traversals
self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| { self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| {
match term { match term {
Term::Apply { function, .. } => { Term::Apply { function, .. } => {
@ -919,7 +944,14 @@ impl Program<Name> {
replace_identity_usage(body.as_ref(), parameter_name.clone()); replace_identity_usage(body.as_ref(), parameter_name.clone());
// Have to check if the body still has any occurrences of the parameter // Have to check if the body still has any occurrences of the parameter
// After attempting replacement // After attempting replacement
if var_occurrences(body.as_ref(), parameter_name.clone()).found { if var_occurrences(
body.as_ref(),
parameter_name.clone(),
vec![],
vec![],
)
.found
{
let body = Rc::make_mut(body); let body = Rc::make_mut(body);
*body = temp_term; *body = temp_term;
} else { } else {
@ -936,7 +968,12 @@ impl Program<Name> {
_ => {} _ => {}
} }
}) })
.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| match term { }
pub fn inline_reducer(self) -> Self {
let mut lambda_applied_ids = vec![];
self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| match term {
Term::Apply { function, .. } => { Term::Apply { function, .. } => {
// We are applying some arg so now we unwrap the id of the applied arg // We are applying some arg so now we unwrap the id of the applied arg
let id = id.unwrap(); let id = id.unwrap();
@ -952,11 +989,25 @@ impl Program<Name> {
} => { } => {
// pops stack here no matter what // pops stack here no matter what
if let Some((arg_id, arg_term)) = arg_stack.pop() { if let Some((arg_id, arg_term)) = arg_stack.pop() {
let body = Rc::make_mut(body); let arg_term = match &arg_term {
let var_lookup = var_occurrences(body, parameter_name.clone()); Term::Lambda {
parameter_name,
body,
} if parameter_name.text == NO_INLINE => body.as_ref().clone(),
_ => arg_term,
};
if var_lookup.occurrences == 1 let body = Rc::make_mut(body);
&& (var_lookup.delays == 0 let var_lookup = var_occurrences(body, parameter_name.clone(), vec![], vec![]);
assert!(
var_lookup.delays >= 0,
"HOW {} AND {:#?}",
parameter_name.text,
var_lookup
);
let substitute_condition = (var_lookup.delays == 0 && !var_lookup.no_inline)
|| matches!( || matches!(
&arg_term, &arg_term,
Term::Var(_) Term::Var(_)
@ -964,8 +1015,9 @@ impl Program<Name> {
| Term::Delay(_) | Term::Delay(_)
| Term::Lambda { .. } | Term::Lambda { .. }
| Term::Builtin(_), | Term::Builtin(_),
)) );
{
if var_lookup.occurrences == 1 && substitute_condition {
*body = substitute_var(body, parameter_name.clone(), &arg_term); *body = substitute_var(body, parameter_name.clone(), &arg_term);
lambda_applied_ids.push(arg_id); lambda_applied_ids.push(arg_id);
@ -1005,6 +1057,16 @@ impl Program<Name> {
}) })
} }
pub fn remove_no_inlines(self) -> Self {
self.traverse_uplc_with(&mut |_, term, _, _| match term {
Term::Lambda {
parameter_name,
body,
} if parameter_name.text == NO_INLINE => *term = body.as_ref().clone(),
_ => {}
})
}
pub fn cast_data_reducer(self) -> Self { pub fn cast_data_reducer(self) -> Self {
let mut applied_ids = vec![]; let mut applied_ids = vec![];
@ -1347,7 +1409,8 @@ impl Program<Name> {
for (key, val) in insert_list.into_iter().rev() { for (key, val) in insert_list.into_iter().rev() {
let name = id_vec_function_to_var(&key.func_name, &key.id_vec); let name = id_vec_function_to_var(&key.func_name, &key.id_vec);
if var_occurrences(term, Name::text(&name).into()).found { if var_occurrences(term, Name::text(&name).into(), vec![], vec![]).found
{
*term = term.clone().lambda(name).apply(val); *term = term.clone().lambda(name).apply(val);
} }
} }
@ -1360,7 +1423,8 @@ impl Program<Name> {
for (key, val) in insert_list.into_iter().rev() { for (key, val) in insert_list.into_iter().rev() {
let name = id_vec_function_to_var(&key.func_name, &key.id_vec); let name = id_vec_function_to_var(&key.func_name, &key.id_vec);
if var_occurrences(term, Name::text(&name).into()).found { if var_occurrences(term, Name::text(&name).into(), vec![], vec![]).found
{
*term = term.clone().lambda(name).apply(val); *term = term.clone().lambda(name).apply(val);
} }
} }
@ -1388,7 +1452,12 @@ fn id_vec_function_to_var(func_name: &str, id_vec: &[usize]) -> String {
) )
} }
fn var_occurrences(term: &Term<Name>, search_for: Rc<Name>) -> VarLookup { fn var_occurrences(
term: &Term<Name>,
search_for: Rc<Name>,
mut arg_stack: Vec<()>,
mut force_stack: Vec<()>,
) -> VarLookup {
match term { match term {
Term::Var(name) => { Term::Var(name) => {
if name.text == search_for.text && name.unique == search_for.unique { if name.text == search_for.text && name.unique == search_for.unique {
@ -1397,24 +1466,48 @@ fn var_occurrences(term: &Term<Name>, search_for: Rc<Name>) -> VarLookup {
VarLookup::new() VarLookup::new()
} }
} }
Term::Delay(body) => var_occurrences(body.as_ref(), search_for).delay_if_found(1), Term::Delay(body) => {
let not_forced: isize = isize::from(force_stack.pop().is_none());
var_occurrences(body, search_for, arg_stack, force_stack).delay_if_found(not_forced)
}
Term::Lambda { Term::Lambda {
parameter_name, parameter_name,
body, body,
} => { } => {
if parameter_name.text != search_for.text || parameter_name.unique != search_for.unique if parameter_name.text == NO_INLINE {
var_occurrences(body.as_ref(), search_for, arg_stack, force_stack)
.no_inline_if_found()
} else if parameter_name.text != search_for.text
|| parameter_name.unique != search_for.unique
{ {
var_occurrences(body.as_ref(), search_for).delay_if_found(1) let not_applied: isize = isize::from(arg_stack.pop().is_none());
var_occurrences(body.as_ref(), search_for, arg_stack, force_stack)
.delay_if_found(not_applied)
} else { } else {
VarLookup::new() VarLookup::new()
} }
} }
Term::Apply { function, argument } => { Term::Apply { function, argument } => {
var_occurrences(function.as_ref(), search_for.clone()) arg_stack.push(());
.delay_if_found(-1)
.combine(var_occurrences(argument.as_ref(), search_for)) var_occurrences(
function.as_ref(),
search_for.clone(),
arg_stack,
force_stack,
)
.combine(var_occurrences(
argument.as_ref(),
search_for,
vec![],
vec![],
))
}
Term::Force(x) => {
force_stack.push(());
var_occurrences(x.as_ref(), search_for, arg_stack, force_stack)
} }
Term::Force(x) => var_occurrences(x.as_ref(), search_for).delay_if_found(-1),
Term::Case { .. } => todo!(), Term::Case { .. } => todo!(),
Term::Constr { .. } => todo!(), Term::Constr { .. } => todo!(),
_ => VarLookup::new(), _ => VarLookup::new(),