prevent curried function hoisting if occurrences is 0

This commit is contained in:
microproofs 2024-02-16 15:55:52 -05:00 committed by Kasey
parent 258b5abf23
commit 7d8fdc0f22
1 changed files with 87 additions and 81 deletions

View File

@ -35,17 +35,17 @@ impl Scope {
pub fn push(&self, path: ScopePath) -> Self {
let mut new_scope = self.scope.clone();
new_scope.push(path);
Scope { scope: new_scope }
Self { scope: new_scope }
}
pub fn pop(&self) -> Self {
let mut new_scope = self.scope.clone();
new_scope.pop();
Scope { scope: new_scope }
Self { scope: new_scope }
}
pub fn common_ancestor(&self, other: &Scope) -> Self {
Scope {
Self {
scope: self
.scope
.iter()
@ -90,6 +90,51 @@ impl Default for IdGen {
}
}
#[derive(PartialEq, PartialOrd, Default, Debug, Clone)]
pub struct VarLookup {
found: bool,
occurrences: isize,
delays: isize,
}
impl VarLookup {
pub fn new() -> Self {
Self {
found: false,
occurrences: 0,
delays: 0,
}
}
pub fn new_found() -> Self {
Self {
found: true,
occurrences: 1,
delays: 0,
}
}
pub fn combine(self, other: Self) -> Self {
Self {
found: self.found || other.found,
occurrences: self.occurrences + other.occurrences,
delays: self.delays + other.delays,
}
}
pub fn delay_if_found(self, delay_amount: isize) -> Self {
if self.found {
Self {
found: self.found,
occurrences: self.occurrences,
delays: self.delays + delay_amount,
}
} else {
self
}
}
}
impl DefaultFunction {
pub fn is_order_agnostic_builtin(self) -> bool {
matches!(
@ -606,7 +651,7 @@ pub struct CurriedBuiltin {
impl CurriedBuiltin {
pub fn merge_node_by_path(self, path: BuiltinArgs) -> Self {
CurriedBuiltin {
Self {
func: self.func,
args: self.args.merge_node_by_path(path),
}
@ -808,25 +853,7 @@ impl Program<Name> {
let mut lambda_applied_ids = vec![];
let mut identity_applied_ids = vec![];
// TODO: Remove extra traversals
self.traverse_uplc_with(&mut |_id, term, _arg_stack, _scope| {
// Since this one just inlines single occurrences. It's probably not needed
if let Term::Apply { function, argument } = term {
let func = Rc::make_mut(function);
if let Term::Lambda {
parameter_name,
body,
} = func
{
if let Term::Var(name) = body.as_ref() {
if name.as_ref() == parameter_name.as_ref() {
*term = argument.as_ref().clone();
}
}
}
}
})
.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| {
self.traverse_uplc_with(&mut |id, term, mut arg_stack, _scope| {
match term {
Term::Apply { function, .. } => {
// We are applying some arg so now we unwrap the id of the applied arg
@ -858,7 +885,7 @@ impl Program<Name> {
replace_identity_usage(body.as_ref(), parameter_name.clone());
// Have to check if the body still has any occurrences of the parameter
// After attempting replacement
if var_occurrences(body.as_ref(), parameter_name.clone()) > 0 {
if var_occurrences(body.as_ref(), parameter_name.clone()).found {
let body = Rc::make_mut(body);
*body = temp_term;
} else {
@ -881,9 +908,8 @@ impl Program<Name> {
let id = id.unwrap();
if lambda_applied_ids.contains(&id) {
let func = Rc::make_mut(function);
// we inlined the arg so now remove the apply and arg from the program
*term = func.clone();
*term = function.as_ref().clone();
}
}
Term::Lambda {
@ -893,11 +919,10 @@ impl Program<Name> {
// pops stack here no matter what
if let Some((arg_id, arg_term)) = arg_stack.pop() {
let body = Rc::make_mut(body);
let occurrences = var_occurrences(body, parameter_name.clone());
let delays = delayed_execution(body);
let var_lookup = var_occurrences(body, parameter_name.clone());
if occurrences == 1
&& (delays == 0
if var_lookup.occurrences == 1
&& (var_lookup.delays == 0
|| matches!(
&arg_term,
Term::Var(_)
@ -913,7 +938,7 @@ impl Program<Name> {
*term = body.clone();
// This will strip out unused terms that can't throw an error by themselves
} else if occurrences == 0
} else if !var_lookup.found
&& matches!(
arg_term,
Term::Var(_)
@ -1082,7 +1107,7 @@ impl Program<Name> {
IndexMap::new();
let mut final_ids: IndexMap<Vec<usize>, ()> = IndexMap::new();
let a = self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term {
let step_a = self.traverse_uplc_with(&mut |_id, term, arg_stack, scope| match term {
Term::Builtin(func) => {
if func.can_curry_builtin() && arg_stack.len() == func.arity() {
let is_order_agnostic = func.is_order_agnostic_builtin();
@ -1191,7 +1216,7 @@ impl Program<Name> {
}
});
let mut b = a.traverse_uplc_with(&mut |id, term, arg_stack, scope| match term {
let mut step_b = step_a.traverse_uplc_with(&mut |id, term, arg_stack, scope| match term {
Term::Builtin(func) => {
if func.can_curry_builtin() {
let Some(curried_builtin) =
@ -1255,10 +1280,12 @@ impl Program<Name> {
.join("_")
);
if var_occurrences(term, Name::text(&name).into()).found {
*term = term.clone().lambda(name).apply(val);
}
}
}
}
Term::Constr { .. } => todo!(),
Term::Case { .. } => todo!(),
_ => {
@ -1273,76 +1300,59 @@ impl Program<Name> {
.join("_")
);
if var_occurrences(term, Name::text(&name).into()).found {
*term = term.clone().lambda(name).apply(val);
}
}
}
}
});
let mut interner = Interner::new();
interner.program(&mut b);
interner.program(&mut step_b);
b
step_b
}
}
fn var_occurrences(term: &Term<Name>, search_for: Rc<Name>) -> usize {
fn var_occurrences(term: &Term<Name>, search_for: Rc<Name>) -> VarLookup {
match term {
Term::Var(name) => {
if name.as_ref() == search_for.as_ref() {
1
if name.text == search_for.text && name.unique == search_for.unique {
VarLookup::new_found()
} else {
0
VarLookup::new()
}
}
Term::Delay(body) => var_occurrences(body.as_ref(), search_for),
Term::Delay(body) => var_occurrences(body.as_ref(), search_for).delay_if_found(1),
Term::Lambda {
parameter_name,
body,
} => {
if parameter_name.clone() != search_for {
var_occurrences(body.as_ref(), search_for)
if parameter_name.text != search_for.text || parameter_name.unique != search_for.unique
{
var_occurrences(body.as_ref(), search_for).delay_if_found(1)
} else {
0
VarLookup::new()
}
}
Term::Apply { function, argument } => {
var_occurrences(function.as_ref(), search_for.clone())
+ var_occurrences(argument.as_ref(), search_for)
.delay_if_found(-1)
.combine(var_occurrences(argument.as_ref(), search_for))
}
Term::Force(x) => var_occurrences(x.as_ref(), search_for),
Term::Force(x) => var_occurrences(x.as_ref(), search_for).delay_if_found(-1),
Term::Case { .. } => todo!(),
Term::Constr { .. } => todo!(),
_ => 0,
}
}
fn delayed_execution(term: &Term<Name>) -> usize {
match term {
Term::Delay(body) => 1 + delayed_execution(body.as_ref()),
Term::Lambda { body, .. } => 1 + delayed_execution(body.as_ref()),
Term::Apply { function, argument } => {
delayed_execution(function.as_ref()) + delayed_execution(argument.as_ref())
}
Term::Force(x) => delayed_execution(x.as_ref()),
Term::Case { constr, branches } => {
1 + delayed_execution(constr.as_ref())
+ branches
.iter()
.fold(0, |acc, branch| acc + delayed_execution(branch))
}
Term::Constr { fields, .. } => fields
.iter()
.fold(0, |acc, field| acc + delayed_execution(field)),
_ => 0,
_ => VarLookup::new(),
}
}
fn substitute_var(term: &Term<Name>, original: Rc<Name>, replace_with: &Term<Name>) -> Term<Name> {
match term {
Term::Var(name) => {
if name.as_ref() == original.as_ref() {
if name.text == original.text && name.unique == original.unique {
replace_with.clone()
} else {
Term::Var(name.clone())
@ -1355,10 +1365,10 @@ fn substitute_var(term: &Term<Name>, original: Rc<Name>, replace_with: &Term<Nam
parameter_name,
body,
} => {
if parameter_name.as_ref() != original.as_ref() {
if parameter_name.text != original.text || parameter_name.unique != original.unique {
Term::Lambda {
parameter_name: parameter_name.clone(),
body: Rc::new(substitute_var(body.as_ref(), original, replace_with)),
body: substitute_var(body.as_ref(), original, replace_with).into(),
}
} else {
Term::Lambda {
@ -1368,14 +1378,10 @@ fn substitute_var(term: &Term<Name>, original: Rc<Name>, replace_with: &Term<Nam
}
}
Term::Apply { function, argument } => Term::Apply {
function: Rc::new(substitute_var(
function.as_ref(),
original.clone(),
replace_with,
)),
argument: Rc::new(substitute_var(argument.as_ref(), original, replace_with)),
function: substitute_var(function.as_ref(), original.clone(), replace_with).into(),
argument: substitute_var(argument.as_ref(), original, replace_with).into(),
},
Term::Force(f) => Term::Force(Rc::new(substitute_var(f.as_ref(), original, replace_with))),
Term::Force(f) => Term::Force(substitute_var(f.as_ref(), original, replace_with).into()),
Term::Case { .. } => todo!(),
Term::Constr { .. } => todo!(),
x => x.clone(),
@ -1389,7 +1395,7 @@ fn replace_identity_usage(term: &Term<Name>, original: Rc<Name>) -> Term<Name> {
parameter_name,
body,
} => {
if parameter_name.as_ref() != original.as_ref() {
if parameter_name.text != original.text || parameter_name.unique != original.unique {
Term::Lambda {
parameter_name: parameter_name.clone(),
body: Rc::new(replace_identity_usage(body.as_ref(), original)),
@ -1412,7 +1418,7 @@ fn replace_identity_usage(term: &Term<Name>, original: Rc<Name>) -> Term<Name> {
};
};
if name.as_ref() == original.as_ref() {
if name.text == original.text && name.unique == original.unique {
arg
} else {
Term::Apply {