hoist functions to super constants too

This commit is contained in:
microproofs 2024-08-14 18:43:20 -04:00 committed by KtorZ
parent cd0a9440e8
commit f674f9ee97
No known key found for this signature in database
GPG Key ID: 33173CB6F77F4277
1 changed files with 41 additions and 114 deletions

View File

@ -68,7 +68,6 @@ pub struct CodeGenerator<'a> {
defined_functions: IndexMap<FunctionAccessKey, ()>,
special_functions: CodeGenSpecialFuncs,
code_gen_functions: IndexMap<String, CodeGenFunction>,
zero_arg_functions: IndexMap<(FunctionAccessKey, Variant), Vec<Air>>,
cyclic_functions:
IndexMap<(FunctionAccessKey, Variant), (CycleFunctionNames, usize, FunctionAccessKey)>,
/// mutable and reset as well
@ -100,7 +99,6 @@ impl<'a> CodeGenerator<'a> {
defined_functions: IndexMap::new(),
special_functions: CodeGenSpecialFuncs::new(),
code_gen_functions: IndexMap::new(),
zero_arg_functions: IndexMap::new(),
cyclic_functions: IndexMap::new(),
id_gen: IdGenerator::new(),
}
@ -108,7 +106,6 @@ impl<'a> CodeGenerator<'a> {
pub fn reset(&mut self, reset_special_functions: bool) {
self.code_gen_functions = IndexMap::new();
self.zero_arg_functions = IndexMap::new();
self.defined_functions = IndexMap::new();
self.cyclic_functions = IndexMap::new();
self.id_gen = IdGenerator::new();
@ -3686,56 +3683,39 @@ impl<'a> CodeGenerator<'a> {
// first grab dependencies
let func_params = params;
let params_empty = func_params.is_empty();
let deps = (tree_path, func_deps.clone());
if !params_empty {
let recursive_nonstatics = if is_recursive {
modify_self_calls(&mut body, key, variant, func_params)
} else {
func_params.clone()
};
let node_to_edit = air_tree.find_air_tree_node(tree_path);
let defined_function = AirTree::define_func(
&key.function_name,
&key.module_name,
variant,
func_params.clone(),
is_recursive,
recursive_nonstatics,
body,
node_to_edit.clone(),
);
let defined_dependencies = self.hoist_dependent_functions(
deps,
params_empty,
(key, variant),
hoisted_functions,
functions_to_hoist,
defined_function,
);
// now hoist full function onto validator tree
*node_to_edit = defined_dependencies;
hoisted_functions.push((key.clone(), variant.clone()));
let recursive_nonstatics = if is_recursive {
modify_self_calls(&mut body, key, variant, func_params)
} else {
let defined_func = self.hoist_dependent_functions(
deps,
params_empty,
(key, variant),
hoisted_functions,
functions_to_hoist,
body,
);
func_params.clone()
};
self.zero_arg_functions
.insert((key.clone(), variant.clone()), defined_func.to_vec());
}
let node_to_edit = air_tree.find_air_tree_node(tree_path);
let defined_function = AirTree::define_func(
&key.function_name,
&key.module_name,
variant,
func_params.clone(),
is_recursive,
recursive_nonstatics,
body,
node_to_edit.clone(),
);
let defined_dependencies = self.hoist_dependent_functions(
deps,
(key, variant),
hoisted_functions,
functions_to_hoist,
defined_function,
);
// now hoist full function onto validator tree
*node_to_edit = defined_dependencies;
hoisted_functions.push((key.clone(), variant.clone()));
}
HoistableFunction::CyclicFunction {
functions,
@ -3763,8 +3743,6 @@ impl<'a> CodeGenerator<'a> {
let defined_dependencies = self.hoist_dependent_functions(
deps,
// cyclic functions always have params
false,
(key, variant),
hoisted_functions,
functions_to_hoist,
@ -3788,7 +3766,6 @@ impl<'a> CodeGenerator<'a> {
fn hoist_dependent_functions(
&mut self,
deps: (&TreePath, Vec<(FunctionAccessKey, String)>),
params_empty: bool,
func_key_variant: (&FunctionAccessKey, &Variant),
hoisted_functions: &mut Vec<(FunctionAccessKey, String)>,
functions_to_hoist: &IndexMap<
@ -3855,12 +3832,12 @@ impl<'a> CodeGenerator<'a> {
sorted_dep_vec
.into_iter()
.fold(air_tree, |then, (dep_key, dep_variant)| {
if (!params_empty
// if the dependency is the same as the function we're hoisting
// or we hoisted it, then skip it
&& hoisted_functions.iter().any(|(generic, variant)| {
generic == &dep_key && variant == &dep_variant
}))
if
// if the dependency is the same as the function we're hoisting
// or we hoisted it, then skip it
hoisted_functions
.iter()
.any(|(generic, variant)| generic == &dep_key && variant == &dep_variant)
|| (&dep_key == key && &dep_variant == variant)
{
return then;
@ -3877,7 +3854,7 @@ impl<'a> CodeGenerator<'a> {
// In the case of zero args, we need to hoist the dependency function to the top of the zero arg function
// The dependency we are hoisting should have an equal path to the function we hoisted
// if we are going to hoist it
if &dep_path.common_ancestor(func_path) == func_path || params_empty {
if &dep_path.common_ancestor(func_path) == func_path {
match dep_function.clone() {
HoistableFunction::Function {
body: mut dep_air_tree,
@ -3904,9 +3881,7 @@ impl<'a> CodeGenerator<'a> {
dependent_params.clone()
};
if !params_empty {
hoisted_functions.push((dep_key.clone(), dep_variant.clone()));
}
hoisted_functions.push((dep_key.clone(), dep_variant.clone()));
AirTree::define_func(
&dep_key.function_name,
@ -3926,9 +3901,7 @@ impl<'a> CodeGenerator<'a> {
modify_cyclic_calls(body, &dep_key, &self.cyclic_functions);
}
if !params_empty {
hoisted_functions.push((dep_key.clone(), dep_variant.clone()));
}
hoisted_functions.push((dep_key.clone(), dep_variant.clone()));
AirTree::define_cyclic_func(
&dep_key.function_name,
@ -4272,6 +4245,8 @@ impl<'a> CodeGenerator<'a> {
true,
);
value = self.hoist_functions_to_validator(value);
let term = self
.uplc_code_gen(value.to_vec())
.constr_fields_exposer()
@ -4638,55 +4613,7 @@ impl<'a> CodeGenerator<'a> {
// How we handle zero arg anon functions has changed
// We now delay zero arg anon functions and force them on a call operation
match term.pierce_no_inlines() {
Term::Var(name) => {
let zero_arg_functions = self.zero_arg_functions.clone();
let text = &name.text;
if let Some((_, air_vec)) = zero_arg_functions.iter().find(
|(
(
FunctionAccessKey {
module_name,
function_name,
},
variant,
),
_,
)| {
let name_module =
format!("{module_name}_{function_name}{variant}");
let name = format!("{function_name}{variant}");
text == &name || text == &name_module
},
) {
let mut term = self.uplc_code_gen(air_vec.clone());
term = term.constr_fields_exposer().constr_index_exposer();
let mut program: Program<Name> = Program {
version: (1, 0, 0),
term: self.special_functions.apply_used_functions(term),
};
let mut interner = CodeGenInterner::new();
interner.program(&mut program);
let eval_program: Program<NamedDeBruijn> =
program.remove_no_inlines().try_into().unwrap();
let result = eval_program.eval(ExBudget::max()).result();
let evaluated_term: Term<NamedDeBruijn> = result.unwrap_or_else(|e| {
panic!("Evaluated a zero argument function and received this error: {e:#?}")
});
Some(evaluated_term.try_into().unwrap())
} else {
Some(term.force())
}
}
Term::Var(_) => Some(term.force()),
Term::Delay(inner_term) => Some(inner_term.as_ref().clone()),
Term::Apply { .. } => Some(term.force()),
_ => unreachable!(