Minor fix to optimization to actually detect vars that are just forced builtins

This commit is contained in:
microproofs 2025-01-17 11:34:41 +07:00
parent 91b6e6da31
commit 1075be1b71
No known key found for this signature in database
GPG Key ID: 14F93C84DE6AFD17
1 changed files with 23 additions and 13 deletions

View File

@ -10,6 +10,7 @@ use indexmap::IndexMap;
use itertools::{FoldWhile, Itertools};
use pallas_primitives::conway::{BigInt, PlutusData};
use std::{cmp::Ordering, iter, ops::Neg, rc::Rc};
use strum::IntoEnumIterator;
#[derive(Eq, Hash, PartialEq, Clone, Debug, PartialOrd)]
pub enum ScopePath {
@ -348,6 +349,12 @@ impl DefaultFunction {
format!("__{}_wrapped", self.aiken_name())
}
}
pub fn forceable_wrapped_names() -> Vec<String> {
DefaultFunction::iter()
.filter(|df| df.force_count() > 0)
.map(|df| df.wrapped_name())
.collect_vec()
}
#[derive(PartialEq, Clone, Debug)]
pub enum BuiltinArgs {
@ -1316,7 +1323,7 @@ impl Term<Name> {
// So it costs more size to have them hoisted
Term::Delay(e) if matches!(e.as_ref(), Term::Error) => true,
// If it wraps a builtin with consts or arguments passed in then inline
Term::Lambda { .. } => arg_term.is_a_builtin_wrapper(context),
Term::Lambda { .. } => arg_term.is_a_builtin_wrapper(),
// Inline smaller terms too
Term::Constant(_) | Term::Var(_) | Term::Builtin(_) => true,
@ -1430,7 +1437,7 @@ impl Term<Name> {
arg.split_body_lambda();
arg_stack.push(std::mem::replace(arg, Term::Error.force()));
arg_stack.push(Args::Apply(0, std::mem::replace(arg, Term::Error.force())));
}
Term::Lambda {
parameter_name,
@ -1438,7 +1445,7 @@ impl Term<Name> {
} => {
current_term = Rc::make_mut(body);
if let Some(arg) = arg_stack.pop() {
if let Some(Args::Apply(_, arg)) = arg_stack.pop() {
let names = arg.get_var_names();
let func = (parameter_name.clone(), arg);
@ -1467,7 +1474,12 @@ impl Term<Name> {
unsat_lams.push(parameter_name.clone());
}
}
Term::Delay(term) | Term::Force(term) => {
Term::Force(term) => {
current_term = Rc::make_mut(term);
arg_stack.push(Args::Force(0));
}
Term::Delay(term) => {
Rc::make_mut(term).split_body_lambda();
break;
}
@ -1522,7 +1534,10 @@ impl Term<Name> {
// Replace args that weren't consumed
let term = arg_stack
.into_iter()
.rfold(term_to_build_on, |term, arg| term.apply(arg));
.rfold(term_to_build_on, |term, arg| match arg {
Args::Force(_) => term.force(),
Args::Apply(_, arg) => term.apply(arg),
});
let term = function_groups.into_iter().rfold(term, |term, group| {
let term = group.iter().rfold(term, |term, (name, _)| Term::Lambda {
@ -2212,7 +2227,7 @@ impl Term<Name> {
std::mem::replace(term, Term::Error.force())
}
fn is_a_builtin_wrapper(&self, context: &Context) -> bool {
fn is_a_builtin_wrapper(&self) -> bool {
let (names, term) = self.pop_lambdas_and_get_names();
let mut arg_names = vec![];
@ -2233,12 +2248,7 @@ impl Term<Name> {
}
let func_is_builtin = match term {
Term::Var(name) => context
.builtins_map
.keys()
.map(|func| func.wrapped_name())
.any(|func| func == name.text),
Term::Var(name) => forceable_wrapped_names().contains(&name.text),
Term::Builtin(_) => true,
_ => false,
};
@ -2508,7 +2518,7 @@ impl Program<Name> {
id_vec
} else {
// Brand new buitlin so we add it to the list
// Brand new builtin so we add it to the list
let curried_builtin = builtin_args.clone().args_to_curried_args(*func);
let Some(id_vec) = curried_builtin.get_id_args(builtin_args) else {