From 06ca22c26a803dc4af8e0b8e29edd06052c356ca Mon Sep 17 00:00:00 2001 From: microproofs Date: Sat, 2 Mar 2024 13:39:21 -0500 Subject: [PATCH] update inliner to handle no_inline functions --- crates/aiken-lang/src/gen_uplc.rs | 19 ++-- crates/uplc/src/optimize.rs | 3 + crates/uplc/src/optimize/shrinker.rs | 157 +++++++++++++++++++++------ 3 files changed, 140 insertions(+), 39 deletions(-) diff --git a/crates/aiken-lang/src/gen_uplc.rs b/crates/aiken-lang/src/gen_uplc.rs index 5a3f71e3..2032bd1b 100644 --- a/crates/aiken-lang/src/gen_uplc.rs +++ b/crates/aiken-lang/src/gen_uplc.rs @@ -3933,7 +3933,8 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); - let eval_program: Program = program.try_into().unwrap(); + let eval_program: Program = + program.remove_no_inlines().try_into().unwrap(); let evaluated_term: Term = eval_program.eval(ExBudget::default()).result().unwrap(); @@ -4132,9 +4133,9 @@ impl<'a> CodeGenerator<'a> { } if params.is_empty() { - Some(term.delay()) + Some(term.lambda(NO_INLINE).delay()) } else { - Some(term) + Some(term.lambda(NO_INLINE)) } } Air::Call { count, .. } => { @@ -4187,7 +4188,8 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); - let eval_program: Program = program.try_into().unwrap(); + let eval_program: Program = + program.remove_no_inlines().try_into().unwrap(); let evaluated_term: Term = eval_program.eval(ExBudget::max()).result().unwrap(); @@ -4535,7 +4537,8 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); - let eval_program: Program = program.try_into().unwrap(); + let eval_program: Program = + program.remove_no_inlines().try_into().unwrap(); let evaluated_term: Term = eval_program.eval(ExBudget::default()).result().unwrap(); @@ -4559,7 +4562,8 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); - let eval_program: Program = program.try_into().unwrap(); + let eval_program: Program = + program.remove_no_inlines().try_into().unwrap(); let evaluated_term: Term = eval_program.eval(ExBudget::default()).result().unwrap(); @@ -4964,7 +4968,8 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); - let eval_program: Program = program.try_into().unwrap(); + let eval_program: Program = + program.remove_no_inlines().try_into().unwrap(); let evaluated_term: Term = eval_program.eval(ExBudget::default()).result().unwrap(); diff --git a/crates/uplc/src/optimize.rs b/crates/uplc/src/optimize.rs index 7868a2ff..e8cc47fa 100644 --- a/crates/uplc/src/optimize.rs +++ b/crates/uplc/src/optimize.rs @@ -8,6 +8,7 @@ pub fn aiken_optimize_and_intern(program: Program) -> Program { .builtin_force_reducer() .lambda_reducer() .inline_reducer() + .identity_reducer() .lambda_reducer() .inline_reducer() .force_delay_reducer() @@ -16,7 +17,9 @@ pub fn aiken_optimize_and_intern(program: Program) -> Program { .builtin_curry_reducer() .lambda_reducer() .inline_reducer() + .identity_reducer() .builtin_curry_reducer() .lambda_reducer() .inline_reducer() + .remove_no_inlines() } diff --git a/crates/uplc/src/optimize/shrinker.rs b/crates/uplc/src/optimize/shrinker.rs index dd60f148..d328c50f 100644 --- a/crates/uplc/src/optimize/shrinker.rs +++ b/crates/uplc/src/optimize/shrinker.rs @@ -93,6 +93,7 @@ pub struct VarLookup { found: bool, occurrences: isize, delays: isize, + no_inline: bool, } impl VarLookup { @@ -101,6 +102,7 @@ impl VarLookup { found: false, occurrences: 0, delays: 0, + no_inline: false, } } @@ -109,6 +111,7 @@ impl VarLookup { found: true, occurrences: 1, delays: 0, + no_inline: false, } } @@ -117,6 +120,7 @@ impl VarLookup { found: self.found || other.found, occurrences: self.occurrences + other.occurrences, delays: self.delays + other.delays, + no_inline: self.no_inline || other.no_inline, } } @@ -126,6 +130,20 @@ impl VarLookup { found: self.found, occurrences: self.occurrences, delays: self.delays + delay_amount, + no_inline: self.no_inline, + } + } else { + self + } + } + + pub fn no_inline_if_found(self) -> Self { + if self.found { + Self { + found: self.found, + occurrences: self.occurrences, + delays: self.delays, + no_inline: true, } } else { self @@ -697,11 +715,11 @@ impl Term { match self { Term::Apply { function, argument } => { let arg = Rc::make_mut(argument); - let argument_arg_stack = vec![]; + Self::traverse_uplc_with_helper( arg, &scope.push(ScopePath::ARG), - argument_arg_stack, + vec![], id_gen, with, ); @@ -729,10 +747,19 @@ impl Term { Self::traverse_uplc_with_helper(d, scope, arg_stack, id_gen, with); with(None, self, vec![], scope); } - Term::Lambda { body, .. } => { + Term::Lambda { + body, + parameter_name, + } => { let body = Rc::make_mut(body); // Lambda pops one item off the arg stack. If there is no item then it is a unsaturated lambda - let args = arg_stack.pop().map(|arg| vec![arg]).unwrap_or_default(); + // We also skip NO_INLINE lambdas since those are placeholder lambdas created by codegen + + let args = if parameter_name.text == NO_INLINE { + vec![] + } else { + arg_stack.pop().map(|arg| vec![arg]).unwrap_or_default() + }; // Pass in either one or zero args. Self::traverse_uplc_with_helper(body, scope, arg_stack, id_gen, with); @@ -883,10 +910,8 @@ impl Program { Program::::try_from(program).unwrap() } - pub fn inline_reducer(self) -> Self { - let mut lambda_applied_ids = vec![]; + pub fn identity_reducer(self) -> Self { let mut identity_applied_ids = vec![]; - // TODO: Remove extra traversals self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| { match term { Term::Apply { function, .. } => { @@ -919,7 +944,14 @@ impl Program { replace_identity_usage(body.as_ref(), parameter_name.clone()); // Have to check if the body still has any occurrences of the parameter // After attempting replacement - if var_occurrences(body.as_ref(), parameter_name.clone()).found { + if var_occurrences( + body.as_ref(), + parameter_name.clone(), + vec![], + vec![], + ) + .found + { let body = Rc::make_mut(body); *body = temp_term; } else { @@ -936,7 +968,12 @@ impl Program { _ => {} } }) - .traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| match term { + } + + pub fn inline_reducer(self) -> Self { + let mut lambda_applied_ids = vec![]; + + self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| match term { Term::Apply { function, .. } => { // We are applying some arg so now we unwrap the id of the applied arg let id = id.unwrap(); @@ -952,20 +989,35 @@ impl Program { } => { // pops stack here no matter what if let Some((arg_id, arg_term)) = arg_stack.pop() { - let body = Rc::make_mut(body); - let var_lookup = var_occurrences(body, parameter_name.clone()); + let arg_term = match &arg_term { + Term::Lambda { + parameter_name, + body, + } if parameter_name.text == NO_INLINE => body.as_ref().clone(), + _ => arg_term, + }; - if var_lookup.occurrences == 1 - && (var_lookup.delays == 0 - || matches!( - &arg_term, - Term::Var(_) - | Term::Constant(_) - | Term::Delay(_) - | Term::Lambda { .. } - | Term::Builtin(_), - )) - { + let body = Rc::make_mut(body); + let var_lookup = var_occurrences(body, parameter_name.clone(), vec![], vec![]); + + assert!( + var_lookup.delays >= 0, + "HOW {} AND {:#?}", + parameter_name.text, + var_lookup + ); + + let substitute_condition = (var_lookup.delays == 0 && !var_lookup.no_inline) + || matches!( + &arg_term, + Term::Var(_) + | Term::Constant(_) + | Term::Delay(_) + | Term::Lambda { .. } + | Term::Builtin(_), + ); + + if var_lookup.occurrences == 1 && substitute_condition { *body = substitute_var(body, parameter_name.clone(), &arg_term); lambda_applied_ids.push(arg_id); @@ -1005,6 +1057,16 @@ impl Program { }) } + pub fn remove_no_inlines(self) -> Self { + self.traverse_uplc_with(&mut |_, term, _, _| match term { + Term::Lambda { + parameter_name, + body, + } if parameter_name.text == NO_INLINE => *term = body.as_ref().clone(), + _ => {} + }) + } + pub fn cast_data_reducer(self) -> Self { let mut applied_ids = vec![]; @@ -1347,7 +1409,8 @@ impl Program { for (key, val) in insert_list.into_iter().rev() { let name = id_vec_function_to_var(&key.func_name, &key.id_vec); - if var_occurrences(term, Name::text(&name).into()).found { + if var_occurrences(term, Name::text(&name).into(), vec![], vec![]).found + { *term = term.clone().lambda(name).apply(val); } } @@ -1360,7 +1423,8 @@ impl Program { for (key, val) in insert_list.into_iter().rev() { let name = id_vec_function_to_var(&key.func_name, &key.id_vec); - if var_occurrences(term, Name::text(&name).into()).found { + if var_occurrences(term, Name::text(&name).into(), vec![], vec![]).found + { *term = term.clone().lambda(name).apply(val); } } @@ -1388,7 +1452,12 @@ fn id_vec_function_to_var(func_name: &str, id_vec: &[usize]) -> String { ) } -fn var_occurrences(term: &Term, search_for: Rc) -> VarLookup { +fn var_occurrences( + term: &Term, + search_for: Rc, + mut arg_stack: Vec<()>, + mut force_stack: Vec<()>, +) -> VarLookup { match term { Term::Var(name) => { if name.text == search_for.text && name.unique == search_for.unique { @@ -1397,24 +1466,48 @@ fn var_occurrences(term: &Term, search_for: Rc) -> VarLookup { VarLookup::new() } } - Term::Delay(body) => var_occurrences(body.as_ref(), search_for).delay_if_found(1), + Term::Delay(body) => { + let not_forced: isize = isize::from(force_stack.pop().is_none()); + + var_occurrences(body, search_for, arg_stack, force_stack).delay_if_found(not_forced) + } Term::Lambda { parameter_name, body, } => { - if parameter_name.text != search_for.text || parameter_name.unique != search_for.unique + if parameter_name.text == NO_INLINE { + var_occurrences(body.as_ref(), search_for, arg_stack, force_stack) + .no_inline_if_found() + } else if parameter_name.text != search_for.text + || parameter_name.unique != search_for.unique { - var_occurrences(body.as_ref(), search_for).delay_if_found(1) + let not_applied: isize = isize::from(arg_stack.pop().is_none()); + var_occurrences(body.as_ref(), search_for, arg_stack, force_stack) + .delay_if_found(not_applied) } else { VarLookup::new() } } Term::Apply { function, argument } => { - var_occurrences(function.as_ref(), search_for.clone()) - .delay_if_found(-1) - .combine(var_occurrences(argument.as_ref(), search_for)) + arg_stack.push(()); + + var_occurrences( + function.as_ref(), + search_for.clone(), + arg_stack, + force_stack, + ) + .combine(var_occurrences( + argument.as_ref(), + search_for, + vec![], + vec![], + )) + } + Term::Force(x) => { + force_stack.push(()); + var_occurrences(x.as_ref(), search_for, arg_stack, force_stack) } - Term::Force(x) => var_occurrences(x.as_ref(), search_for).delay_if_found(-1), Term::Case { .. } => todo!(), Term::Constr { .. } => todo!(), _ => VarLookup::new(),