hoist functions to super constants too

This commit is contained in:
microproofs 2024-08-14 18:43:20 -04:00 committed by KtorZ
parent cd0a9440e8
commit f674f9ee97
No known key found for this signature in database
GPG Key ID: 33173CB6F77F4277
1 changed files with 41 additions and 114 deletions

View File

@ -68,7 +68,6 @@ pub struct CodeGenerator<'a> {
defined_functions: IndexMap<FunctionAccessKey, ()>, defined_functions: IndexMap<FunctionAccessKey, ()>,
special_functions: CodeGenSpecialFuncs, special_functions: CodeGenSpecialFuncs,
code_gen_functions: IndexMap<String, CodeGenFunction>, code_gen_functions: IndexMap<String, CodeGenFunction>,
zero_arg_functions: IndexMap<(FunctionAccessKey, Variant), Vec<Air>>,
cyclic_functions: cyclic_functions:
IndexMap<(FunctionAccessKey, Variant), (CycleFunctionNames, usize, FunctionAccessKey)>, IndexMap<(FunctionAccessKey, Variant), (CycleFunctionNames, usize, FunctionAccessKey)>,
/// mutable and reset as well /// mutable and reset as well
@ -100,7 +99,6 @@ impl<'a> CodeGenerator<'a> {
defined_functions: IndexMap::new(), defined_functions: IndexMap::new(),
special_functions: CodeGenSpecialFuncs::new(), special_functions: CodeGenSpecialFuncs::new(),
code_gen_functions: IndexMap::new(), code_gen_functions: IndexMap::new(),
zero_arg_functions: IndexMap::new(),
cyclic_functions: IndexMap::new(), cyclic_functions: IndexMap::new(),
id_gen: IdGenerator::new(), id_gen: IdGenerator::new(),
} }
@ -108,7 +106,6 @@ impl<'a> CodeGenerator<'a> {
pub fn reset(&mut self, reset_special_functions: bool) { pub fn reset(&mut self, reset_special_functions: bool) {
self.code_gen_functions = IndexMap::new(); self.code_gen_functions = IndexMap::new();
self.zero_arg_functions = IndexMap::new();
self.defined_functions = IndexMap::new(); self.defined_functions = IndexMap::new();
self.cyclic_functions = IndexMap::new(); self.cyclic_functions = IndexMap::new();
self.id_gen = IdGenerator::new(); self.id_gen = IdGenerator::new();
@ -3686,11 +3683,8 @@ impl<'a> CodeGenerator<'a> {
// first grab dependencies // first grab dependencies
let func_params = params; let func_params = params;
let params_empty = func_params.is_empty();
let deps = (tree_path, func_deps.clone()); let deps = (tree_path, func_deps.clone());
if !params_empty {
let recursive_nonstatics = if is_recursive { let recursive_nonstatics = if is_recursive {
modify_self_calls(&mut body, key, variant, func_params) modify_self_calls(&mut body, key, variant, func_params)
} else { } else {
@ -3712,7 +3706,6 @@ impl<'a> CodeGenerator<'a> {
let defined_dependencies = self.hoist_dependent_functions( let defined_dependencies = self.hoist_dependent_functions(
deps, deps,
params_empty,
(key, variant), (key, variant),
hoisted_functions, hoisted_functions,
functions_to_hoist, functions_to_hoist,
@ -3723,19 +3716,6 @@ impl<'a> CodeGenerator<'a> {
*node_to_edit = defined_dependencies; *node_to_edit = defined_dependencies;
hoisted_functions.push((key.clone(), variant.clone())); hoisted_functions.push((key.clone(), variant.clone()));
} else {
let defined_func = self.hoist_dependent_functions(
deps,
params_empty,
(key, variant),
hoisted_functions,
functions_to_hoist,
body,
);
self.zero_arg_functions
.insert((key.clone(), variant.clone()), defined_func.to_vec());
}
} }
HoistableFunction::CyclicFunction { HoistableFunction::CyclicFunction {
functions, functions,
@ -3763,8 +3743,6 @@ impl<'a> CodeGenerator<'a> {
let defined_dependencies = self.hoist_dependent_functions( let defined_dependencies = self.hoist_dependent_functions(
deps, deps,
// cyclic functions always have params
false,
(key, variant), (key, variant),
hoisted_functions, hoisted_functions,
functions_to_hoist, functions_to_hoist,
@ -3788,7 +3766,6 @@ impl<'a> CodeGenerator<'a> {
fn hoist_dependent_functions( fn hoist_dependent_functions(
&mut self, &mut self,
deps: (&TreePath, Vec<(FunctionAccessKey, String)>), deps: (&TreePath, Vec<(FunctionAccessKey, String)>),
params_empty: bool,
func_key_variant: (&FunctionAccessKey, &Variant), func_key_variant: (&FunctionAccessKey, &Variant),
hoisted_functions: &mut Vec<(FunctionAccessKey, String)>, hoisted_functions: &mut Vec<(FunctionAccessKey, String)>,
functions_to_hoist: &IndexMap< functions_to_hoist: &IndexMap<
@ -3855,12 +3832,12 @@ impl<'a> CodeGenerator<'a> {
sorted_dep_vec sorted_dep_vec
.into_iter() .into_iter()
.fold(air_tree, |then, (dep_key, dep_variant)| { .fold(air_tree, |then, (dep_key, dep_variant)| {
if (!params_empty if
// if the dependency is the same as the function we're hoisting // if the dependency is the same as the function we're hoisting
// or we hoisted it, then skip it // or we hoisted it, then skip it
&& hoisted_functions.iter().any(|(generic, variant)| { hoisted_functions
generic == &dep_key && variant == &dep_variant .iter()
})) .any(|(generic, variant)| generic == &dep_key && variant == &dep_variant)
|| (&dep_key == key && &dep_variant == variant) || (&dep_key == key && &dep_variant == variant)
{ {
return then; return then;
@ -3877,7 +3854,7 @@ impl<'a> CodeGenerator<'a> {
// In the case of zero args, we need to hoist the dependency function to the top of the zero arg function // In the case of zero args, we need to hoist the dependency function to the top of the zero arg function
// The dependency we are hoisting should have an equal path to the function we hoisted // The dependency we are hoisting should have an equal path to the function we hoisted
// if we are going to hoist it // if we are going to hoist it
if &dep_path.common_ancestor(func_path) == func_path || params_empty { if &dep_path.common_ancestor(func_path) == func_path {
match dep_function.clone() { match dep_function.clone() {
HoistableFunction::Function { HoistableFunction::Function {
body: mut dep_air_tree, body: mut dep_air_tree,
@ -3904,9 +3881,7 @@ impl<'a> CodeGenerator<'a> {
dependent_params.clone() dependent_params.clone()
}; };
if !params_empty {
hoisted_functions.push((dep_key.clone(), dep_variant.clone())); hoisted_functions.push((dep_key.clone(), dep_variant.clone()));
}
AirTree::define_func( AirTree::define_func(
&dep_key.function_name, &dep_key.function_name,
@ -3926,9 +3901,7 @@ impl<'a> CodeGenerator<'a> {
modify_cyclic_calls(body, &dep_key, &self.cyclic_functions); modify_cyclic_calls(body, &dep_key, &self.cyclic_functions);
} }
if !params_empty {
hoisted_functions.push((dep_key.clone(), dep_variant.clone())); hoisted_functions.push((dep_key.clone(), dep_variant.clone()));
}
AirTree::define_cyclic_func( AirTree::define_cyclic_func(
&dep_key.function_name, &dep_key.function_name,
@ -4272,6 +4245,8 @@ impl<'a> CodeGenerator<'a> {
true, true,
); );
value = self.hoist_functions_to_validator(value);
let term = self let term = self
.uplc_code_gen(value.to_vec()) .uplc_code_gen(value.to_vec())
.constr_fields_exposer() .constr_fields_exposer()
@ -4638,55 +4613,7 @@ impl<'a> CodeGenerator<'a> {
// How we handle zero arg anon functions has changed // How we handle zero arg anon functions has changed
// We now delay zero arg anon functions and force them on a call operation // We now delay zero arg anon functions and force them on a call operation
match term.pierce_no_inlines() { match term.pierce_no_inlines() {
Term::Var(name) => { Term::Var(_) => Some(term.force()),
let zero_arg_functions = self.zero_arg_functions.clone();
let text = &name.text;
if let Some((_, air_vec)) = zero_arg_functions.iter().find(
|(
(
FunctionAccessKey {
module_name,
function_name,
},
variant,
),
_,
)| {
let name_module =
format!("{module_name}_{function_name}{variant}");
let name = format!("{function_name}{variant}");
text == &name || text == &name_module
},
) {
let mut term = self.uplc_code_gen(air_vec.clone());
term = term.constr_fields_exposer().constr_index_exposer();
let mut program: Program<Name> = Program {
version: (1, 0, 0),
term: self.special_functions.apply_used_functions(term),
};
let mut interner = CodeGenInterner::new();
interner.program(&mut program);
let eval_program: Program<NamedDeBruijn> =
program.remove_no_inlines().try_into().unwrap();
let result = eval_program.eval(ExBudget::max()).result();
let evaluated_term: Term<NamedDeBruijn> = result.unwrap_or_else(|e| {
panic!("Evaluated a zero argument function and received this error: {e:#?}")
});
Some(evaluated_term.try_into().unwrap())
} else {
Some(term.force())
}
}
Term::Delay(inner_term) => Some(inner_term.as_ref().clone()), Term::Delay(inner_term) => Some(inner_term.as_ref().clone()),
Term::Apply { .. } => Some(term.force()), Term::Apply { .. } => Some(term.force()),
_ => unreachable!( _ => unreachable!(