prevent curried function hoisting if occurrences is 0

This commit is contained in:
microproofs 2024-02-16 15:55:52 -05:00 committed by Kasey
parent 258b5abf23
commit 7d8fdc0f22
1 changed files with 87 additions and 81 deletions

View File

@ -35,17 +35,17 @@ impl Scope {
pub fn push(&self, path: ScopePath) -> Self { pub fn push(&self, path: ScopePath) -> Self {
let mut new_scope = self.scope.clone(); let mut new_scope = self.scope.clone();
new_scope.push(path); new_scope.push(path);
Scope { scope: new_scope } Self { scope: new_scope }
} }
pub fn pop(&self) -> Self { pub fn pop(&self) -> Self {
let mut new_scope = self.scope.clone(); let mut new_scope = self.scope.clone();
new_scope.pop(); new_scope.pop();
Scope { scope: new_scope } Self { scope: new_scope }
} }
pub fn common_ancestor(&self, other: &Scope) -> Self { pub fn common_ancestor(&self, other: &Scope) -> Self {
Scope { Self {
scope: self scope: self
.scope .scope
.iter() .iter()
@ -90,6 +90,51 @@ impl Default for IdGen {
} }
} }
#[derive(PartialEq, PartialOrd, Default, Debug, Clone)]
pub struct VarLookup {
found: bool,
occurrences: isize,
delays: isize,
}
impl VarLookup {
pub fn new() -> Self {
Self {
found: false,
occurrences: 0,
delays: 0,
}
}
pub fn new_found() -> Self {
Self {
found: true,
occurrences: 1,
delays: 0,
}
}
pub fn combine(self, other: Self) -> Self {
Self {
found: self.found || other.found,
occurrences: self.occurrences + other.occurrences,
delays: self.delays + other.delays,
}
}
pub fn delay_if_found(self, delay_amount: isize) -> Self {
if self.found {
Self {
found: self.found,
occurrences: self.occurrences,
delays: self.delays + delay_amount,
}
} else {
self
}
}
}
impl DefaultFunction { impl DefaultFunction {
pub fn is_order_agnostic_builtin(self) -> bool { pub fn is_order_agnostic_builtin(self) -> bool {
matches!( matches!(
@ -606,7 +651,7 @@ pub struct CurriedBuiltin {
impl CurriedBuiltin { impl CurriedBuiltin {
pub fn merge_node_by_path(self, path: BuiltinArgs) -> Self { pub fn merge_node_by_path(self, path: BuiltinArgs) -> Self {
CurriedBuiltin { Self {
func: self.func, func: self.func,
args: self.args.merge_node_by_path(path), args: self.args.merge_node_by_path(path),
} }
@ -808,25 +853,7 @@ impl Program<Name> {
let mut lambda_applied_ids = vec![]; let mut lambda_applied_ids = vec![];
let mut identity_applied_ids = vec![]; let mut identity_applied_ids = vec![];
// TODO: Remove extra traversals // TODO: Remove extra traversals
self.traverse_uplc_with(&mut |_id, term, _arg_stack, _scope| { self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| {
// Since this one just inlines single occurrences. It's probably not needed
if let Term::Apply { function, argument } = term {
let func = Rc::make_mut(function);
if let Term::Lambda {
parameter_name,
body,
} = func
{
if let Term::Var(name) = body.as_ref() {
if name.as_ref() == parameter_name.as_ref() {
*term = argument.as_ref().clone();
}
}
}
}
})
.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| {
match term { match term {
Term::Apply { function, .. } => { Term::Apply { function, .. } => {
// We are applying some arg so now we unwrap the id of the applied arg // We are applying some arg so now we unwrap the id of the applied arg
@ -858,7 +885,7 @@ impl Program<Name> {
replace_identity_usage(body.as_ref(), parameter_name.clone()); replace_identity_usage(body.as_ref(), parameter_name.clone());
// Have to check if the body still has any occurrences of the parameter // Have to check if the body still has any occurrences of the parameter
// After attempting replacement // After attempting replacement
if var_occurrences(body.as_ref(), parameter_name.clone()) > 0 { if var_occurrences(body.as_ref(), parameter_name.clone()).found {
let body = Rc::make_mut(body); let body = Rc::make_mut(body);
*body = temp_term; *body = temp_term;
} else { } else {
@ -881,9 +908,8 @@ impl Program<Name> {
let id = id.unwrap(); let id = id.unwrap();
if lambda_applied_ids.contains(&id) { if lambda_applied_ids.contains(&id) {
let func = Rc::make_mut(function);
// we inlined the arg so now remove the apply and arg from the program // we inlined the arg so now remove the apply and arg from the program
*term = func.clone(); *term = function.as_ref().clone();
} }
} }
Term::Lambda { Term::Lambda {
@ -893,11 +919,10 @@ impl Program<Name> {
// pops stack here no matter what // pops stack here no matter what
if let Some((arg_id, arg_term)) = arg_stack.pop() { if let Some((arg_id, arg_term)) = arg_stack.pop() {
let body = Rc::make_mut(body); let body = Rc::make_mut(body);
let occurrences = var_occurrences(body, parameter_name.clone()); let var_lookup = var_occurrences(body, parameter_name.clone());
let delays = delayed_execution(body);
if occurrences == 1 if var_lookup.occurrences == 1
&& (delays == 0 && (var_lookup.delays == 0
|| matches!( || matches!(
&arg_term, &arg_term,
Term::Var(_) Term::Var(_)
@ -913,7 +938,7 @@ impl Program<Name> {
*term = body.clone(); *term = body.clone();
// This will strip out unused terms that can't throw an error by themselves // This will strip out unused terms that can't throw an error by themselves
} else if occurrences == 0 } else if !var_lookup.found
&& matches!( && matches!(
arg_term, arg_term,
Term::Var(_) Term::Var(_)
@ -1082,7 +1107,7 @@ impl Program<Name> {
IndexMap::new(); IndexMap::new();
let mut final_ids: IndexMap<Vec<usize>, ()> = IndexMap::new(); let mut final_ids: IndexMap<Vec<usize>, ()> = IndexMap::new();
let a = self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term { let step_a = self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term {
Term::Builtin(func) => { Term::Builtin(func) => {
if func.can_curry_builtin() && arg_stack.len() == func.arity() { if func.can_curry_builtin() && arg_stack.len() == func.arity() {
let is_order_agnostic = func.is_order_agnostic_builtin(); let is_order_agnostic = func.is_order_agnostic_builtin();
@ -1191,7 +1216,7 @@ impl Program<Name> {
} }
}); });
let mut b = a.traverse_uplc_with(&mut |id, term, arg_stack, scope| match term { let mut step_b = step_a.traverse_uplc_with(&mut |id, term, arg_stack, scope| match term {
Term::Builtin(func) => { Term::Builtin(func) => {
if func.can_curry_builtin() { if func.can_curry_builtin() {
let Some(curried_builtin) = let Some(curried_builtin) =
@ -1255,10 +1280,12 @@ impl Program<Name> {
.join("_") .join("_")
); );
if var_occurrences(term, Name::text(&name).into()).found {
*term = term.clone().lambda(name).apply(val); *term = term.clone().lambda(name).apply(val);
} }
} }
} }
}
Term::Constr { .. } => todo!(), Term::Constr { .. } => todo!(),
Term::Case { .. } => todo!(), Term::Case { .. } => todo!(),
_ => { _ => {
@ -1273,76 +1300,59 @@ impl Program<Name> {
.join("_") .join("_")
); );
if var_occurrences(term, Name::text(&name).into()).found {
*term = term.clone().lambda(name).apply(val); *term = term.clone().lambda(name).apply(val);
} }
} }
} }
}
}); });
let mut interner = Interner::new(); let mut interner = Interner::new();
interner.program(&mut b); interner.program(&mut step_b);
b step_b
} }
} }
fn var_occurrences(term: &Term<Name>, search_for: Rc<Name>) -> usize { fn var_occurrences(term: &Term<Name>, search_for: Rc<Name>) -> VarLookup {
match term { match term {
Term::Var(name) => { Term::Var(name) => {
if name.as_ref() == search_for.as_ref() { if name.text == search_for.text && name.unique == search_for.unique {
1 VarLookup::new_found()
} else { } else {
0 VarLookup::new()
} }
} }
Term::Delay(body) => var_occurrences(body.as_ref(), search_for), Term::Delay(body) => var_occurrences(body.as_ref(), search_for).delay_if_found(1),
Term::Lambda { Term::Lambda {
parameter_name, parameter_name,
body, body,
} => { } => {
if parameter_name.clone() != search_for { if parameter_name.text != search_for.text || parameter_name.unique != search_for.unique
var_occurrences(body.as_ref(), search_for) {
var_occurrences(body.as_ref(), search_for).delay_if_found(1)
} else { } else {
0 VarLookup::new()
} }
} }
Term::Apply { function, argument } => { Term::Apply { function, argument } => {
var_occurrences(function.as_ref(), search_for.clone()) var_occurrences(function.as_ref(), search_for.clone())
+ var_occurrences(argument.as_ref(), search_for) .delay_if_found(-1)
.combine(var_occurrences(argument.as_ref(), search_for))
} }
Term::Force(x) => var_occurrences(x.as_ref(), search_for), Term::Force(x) => var_occurrences(x.as_ref(), search_for).delay_if_found(-1),
Term::Case { .. } => todo!(), Term::Case { .. } => todo!(),
Term::Constr { .. } => todo!(), Term::Constr { .. } => todo!(),
_ => 0, _ => VarLookup::new(),
}
}
fn delayed_execution(term: &Term<Name>) -> usize {
match term {
Term::Delay(body) => 1 + delayed_execution(body.as_ref()),
Term::Lambda { body, .. } => 1 + delayed_execution(body.as_ref()),
Term::Apply { function, argument } => {
delayed_execution(function.as_ref()) + delayed_execution(argument.as_ref())
}
Term::Force(x) => delayed_execution(x.as_ref()),
Term::Case { constr, branches } => {
1 + delayed_execution(constr.as_ref())
+ branches
.iter()
.fold(0, |acc, branch| acc + delayed_execution(branch))
}
Term::Constr { fields, .. } => fields
.iter()
.fold(0, |acc, field| acc + delayed_execution(field)),
_ => 0,
} }
} }
fn substitute_var(term: &Term<Name>, original: Rc<Name>, replace_with: &Term<Name>) -> Term<Name> { fn substitute_var(term: &Term<Name>, original: Rc<Name>, replace_with: &Term<Name>) -> Term<Name> {
match term { match term {
Term::Var(name) => { Term::Var(name) => {
if name.as_ref() == original.as_ref() { if name.text == original.text && name.unique == original.unique {
replace_with.clone() replace_with.clone()
} else { } else {
Term::Var(name.clone()) Term::Var(name.clone())
@ -1355,10 +1365,10 @@ fn substitute_var(term: &Term<Name>, original: Rc<Name>, replace_with: &Term<Nam
parameter_name, parameter_name,
body, body,
} => { } => {
if parameter_name.as_ref() != original.as_ref() { if parameter_name.text != original.text || parameter_name.unique != original.unique {
Term::Lambda { Term::Lambda {
parameter_name: parameter_name.clone(), parameter_name: parameter_name.clone(),
body: Rc::new(substitute_var(body.as_ref(), original, replace_with)), body: substitute_var(body.as_ref(), original, replace_with).into(),
} }
} else { } else {
Term::Lambda { Term::Lambda {
@ -1368,14 +1378,10 @@ fn substitute_var(term: &Term<Name>, original: Rc<Name>, replace_with: &Term<Nam
} }
} }
Term::Apply { function, argument } => Term::Apply { Term::Apply { function, argument } => Term::Apply {
function: Rc::new(substitute_var( function: substitute_var(function.as_ref(), original.clone(), replace_with).into(),
function.as_ref(), argument: substitute_var(argument.as_ref(), original, replace_with).into(),
original.clone(),
replace_with,
)),
argument: Rc::new(substitute_var(argument.as_ref(), original, replace_with)),
}, },
Term::Force(f) => Term::Force(Rc::new(substitute_var(f.as_ref(), original, replace_with))), Term::Force(f) => Term::Force(substitute_var(f.as_ref(), original, replace_with).into()),
Term::Case { .. } => todo!(), Term::Case { .. } => todo!(),
Term::Constr { .. } => todo!(), Term::Constr { .. } => todo!(),
x => x.clone(), x => x.clone(),
@ -1389,7 +1395,7 @@ fn replace_identity_usage(term: &Term<Name>, original: Rc<Name>) -> Term<Name> {
parameter_name, parameter_name,
body, body,
} => { } => {
if parameter_name.as_ref() != original.as_ref() { if parameter_name.text != original.text || parameter_name.unique != original.unique {
Term::Lambda { Term::Lambda {
parameter_name: parameter_name.clone(), parameter_name: parameter_name.clone(),
body: Rc::new(replace_identity_usage(body.as_ref(), original)), body: Rc::new(replace_identity_usage(body.as_ref(), original)),
@ -1412,7 +1418,7 @@ fn replace_identity_usage(term: &Term<Name>, original: Rc<Name>) -> Term<Name> {
}; };
}; };
if name.as_ref() == original.as_ref() { if name.text == original.text && name.unique == original.unique {
arg arg
} else { } else {
Term::Apply { Term::Apply {