Add codegen for recursive statics case

We also flip the recursive_statics fields to recursive_nonstatics; This makes the codegen a little easier.  It also has a hacky way to hard code in some recursive statics for testing
This commit is contained in:
Pi Lanningham 2023-07-28 21:41:37 -04:00 committed by Kasey
parent 586a2d7972
commit 09f889b121
4 changed files with 60 additions and 20 deletions

View File

@ -2690,6 +2690,19 @@ impl<'a> CodeGenerator<'a> {
// first grab dependencies // first grab dependencies
let func_params = params; let func_params = params;
// HACK: partition params into the "static recursives" and otherwise
// for now, we just do this based on the name, but it should be detected
// as an optimization pass
let recursive_static_indexes: Vec<usize> = func_params.iter().enumerate().filter(|(idx, p)| {
p.starts_with("pi_recursive_hack_")
}).map(|(idx, _)| idx).collect();
let recursive_nonstatics: Vec<String> = func_params.iter().cloned().filter(|p| {
!p.starts_with("pi_recursive_hack_")
}).collect();
println!("~~ recursive_nonstatics: {:?}", recursive_nonstatics);
println!("~~ func_params: {:?}", func_params);
let params_empty = func_params.is_empty(); let params_empty = func_params.is_empty();
let deps = (tree_path, func_deps.clone()); let deps = (tree_path, func_deps.clone());
@ -2697,7 +2710,7 @@ impl<'a> CodeGenerator<'a> {
if !params_empty { if !params_empty {
if is_recursive { if is_recursive {
body.traverse_tree_with(&mut |air_tree: &mut AirTree, _| { body.traverse_tree_with(&mut |air_tree: &mut AirTree, _| {
modify_self_calls(air_tree, key, variant); modify_self_calls(air_tree, key, variant, &recursive_static_indexes);
}); });
} }
@ -2707,7 +2720,7 @@ impl<'a> CodeGenerator<'a> {
variant, variant,
func_params.clone(), func_params.clone(),
is_recursive, is_recursive,
vec![], recursive_nonstatics,
body, body,
); );
@ -2832,7 +2845,7 @@ impl<'a> CodeGenerator<'a> {
if is_dependent_recursive { if is_dependent_recursive {
dep_air_tree.traverse_tree_with(&mut |air_tree: &mut AirTree, _| { dep_air_tree.traverse_tree_with(&mut |air_tree: &mut AirTree, _| {
modify_self_calls(air_tree, &dep_key, &dep_variant); modify_self_calls(air_tree, &dep_key, &dep_variant, &vec![]);
}); });
} }
@ -2840,9 +2853,9 @@ impl<'a> CodeGenerator<'a> {
&dep_key.function_name, &dep_key.function_name,
&dep_key.module_name, &dep_key.module_name,
&dep_variant, &dep_variant,
dependent_params, dependent_params.clone(),
is_dependent_recursive, is_dependent_recursive,
vec![], dependent_params,
dep_air_tree, dep_air_tree,
)); ));
@ -3706,7 +3719,7 @@ impl<'a> CodeGenerator<'a> {
func_name, func_name,
params, params,
recursive, recursive,
recursive_static_params, recursive_nonstatic_params,
module_name, module_name,
variant_name, variant_name,
} => { } => {
@ -3719,7 +3732,11 @@ impl<'a> CodeGenerator<'a> {
let mut term = arg_stack.pop().unwrap(); let mut term = arg_stack.pop().unwrap();
for param in params.iter().rev() { // Introduce a parameter for each parameter
// NOTE: we use recursive_nonstatic_params here because
// if this is recursive, those are the ones that need to be passed
// each time
for param in recursive_nonstatic_params.iter().rev() {
func_body = func_body.lambda(param.clone()); func_body = func_body.lambda(param.clone());
} }
@ -3730,13 +3747,31 @@ impl<'a> CodeGenerator<'a> {
} else { } else {
func_body = func_body.lambda(func_name.clone()); func_body = func_body.lambda(func_name.clone());
if recursive_nonstatic_params == params {
// If we don't have any recursive-static params, we can just emit the function as is
term = term term = term
.lambda(func_name.clone()) .lambda(func_name.clone())
.apply(Term::var(func_name.clone()).apply(Term::var(func_name.clone()))) .apply(Term::var(func_name.clone()).apply(Term::var(func_name.clone())))
.lambda(func_name) .lambda(func_name)
.apply(func_body); .apply(func_body);
} else {
// If we have parameters that remain static in each recursive call,
// we can construct an *outer* function to take those in
// and simplify the recursive part to only accept the non-static arguments
let mut outer_func_body = Term::var(&func_name).apply(Term::var(&func_name));
for param in recursive_nonstatic_params.iter() {
outer_func_body = outer_func_body.apply(Term::var(param));
}
// TODO: use recursive_static_params here outer_func_body = outer_func_body.lambda(&func_name).apply(func_body);
// Now, add *all* parameters, so that other call sites don't know the difference
for param in params.iter().rev() {
outer_func_body = outer_func_body.lambda(param);
}
term = term.lambda(&func_name).apply(outer_func_body);
}
arg_stack.push(term); arg_stack.push(term);
} }

View File

@ -47,7 +47,7 @@ pub enum Air {
module_name: String, module_name: String,
params: Vec<String>, params: Vec<String>,
recursive: bool, recursive: bool,
recursive_static_params: Vec<String>, recursive_nonstatic_params: Vec<String>,
variant_name: String, variant_name: String,
}, },
Fn { Fn {

View File

@ -583,7 +583,7 @@ pub fn erase_opaque_type_operations(
} }
} }
pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, variant: &String) { pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, variant: &String, static_recursive_params: &Vec<usize>) {
if let AirTree::Expression(AirExpression::Call { func, args, .. }) = air_tree { if let AirTree::Expression(AirExpression::Call { func, args, .. }) = air_tree {
if let AirTree::Expression(AirExpression::Var { if let AirTree::Expression(AirExpression::Var {
constructor: constructor:
@ -599,6 +599,11 @@ pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, v
&& module == &func_key.module_name && module == &func_key.module_name
&& variant == variant_name && variant == variant_name
{ {
// Remove any static-recursive-parameters, because they'll be bound statically
// above the recursive part of the function
for arg in static_recursive_params.iter().rev() {
args.remove(*arg);
}
let mut new_args = vec![func.as_ref().clone()]; let mut new_args = vec![func.as_ref().clone()];
new_args.append(args); new_args.append(args);
*args = new_args; *args = new_args;

View File

@ -129,7 +129,7 @@ pub enum AirStatement {
module_name: String, module_name: String,
params: Vec<String>, params: Vec<String>,
recursive: bool, recursive: bool,
recursive_static_params: Vec<String>, recursive_nonstatic_params: Vec<String>,
variant_name: String, variant_name: String,
func_body: Box<AirTree>, func_body: Box<AirTree>,
}, },
@ -424,7 +424,7 @@ impl AirTree {
variant_name: impl ToString, variant_name: impl ToString,
params: Vec<String>, params: Vec<String>,
recursive: bool, recursive: bool,
recursive_static_params: Vec<String>, recursive_nonstatic_params: Vec<String>,
func_body: AirTree, func_body: AirTree,
) -> AirTree { ) -> AirTree {
AirTree::Statement { AirTree::Statement {
@ -433,7 +433,7 @@ impl AirTree {
module_name: module_name.to_string(), module_name: module_name.to_string(),
params, params,
recursive, recursive,
recursive_static_params, recursive_nonstatic_params,
variant_name: variant_name.to_string(), variant_name: variant_name.to_string(),
func_body: func_body.into(), func_body: func_body.into(),
}, },
@ -878,7 +878,7 @@ impl AirTree {
module_name, module_name,
params, params,
recursive, recursive,
recursive_static_params, recursive_nonstatic_params,
variant_name, variant_name,
func_body, func_body,
} => { } => {
@ -887,7 +887,7 @@ impl AirTree {
module_name: module_name.clone(), module_name: module_name.clone(),
params: params.clone(), params: params.clone(),
recursive: *recursive, recursive: *recursive,
recursive_static_params: recursive_static_params.clone(), recursive_nonstatic_params: recursive_nonstatic_params.clone(),
variant_name: variant_name.clone(), variant_name: variant_name.clone(),
}); });
func_body.create_air_vec(air_vec); func_body.create_air_vec(air_vec);