Add codegen for recursive statics case
We also flip the recursive_statics fields to recursive_nonstatics; This makes the codegen a little easier. It also has a hacky way to hard code in some recursive statics for testing
This commit is contained in:
parent
586a2d7972
commit
09f889b121
|
@ -2690,6 +2690,19 @@ impl<'a> CodeGenerator<'a> {
|
||||||
// first grab dependencies
|
// first grab dependencies
|
||||||
let func_params = params;
|
let func_params = params;
|
||||||
|
|
||||||
|
// HACK: partition params into the "static recursives" and otherwise
|
||||||
|
// for now, we just do this based on the name, but it should be detected
|
||||||
|
// as an optimization pass
|
||||||
|
let recursive_static_indexes: Vec<usize> = func_params.iter().enumerate().filter(|(idx, p)| {
|
||||||
|
p.starts_with("pi_recursive_hack_")
|
||||||
|
}).map(|(idx, _)| idx).collect();
|
||||||
|
let recursive_nonstatics: Vec<String> = func_params.iter().cloned().filter(|p| {
|
||||||
|
!p.starts_with("pi_recursive_hack_")
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
println!("~~ recursive_nonstatics: {:?}", recursive_nonstatics);
|
||||||
|
println!("~~ func_params: {:?}", func_params);
|
||||||
|
|
||||||
let params_empty = func_params.is_empty();
|
let params_empty = func_params.is_empty();
|
||||||
|
|
||||||
let deps = (tree_path, func_deps.clone());
|
let deps = (tree_path, func_deps.clone());
|
||||||
|
@ -2697,7 +2710,7 @@ impl<'a> CodeGenerator<'a> {
|
||||||
if !params_empty {
|
if !params_empty {
|
||||||
if is_recursive {
|
if is_recursive {
|
||||||
body.traverse_tree_with(&mut |air_tree: &mut AirTree, _| {
|
body.traverse_tree_with(&mut |air_tree: &mut AirTree, _| {
|
||||||
modify_self_calls(air_tree, key, variant);
|
modify_self_calls(air_tree, key, variant, &recursive_static_indexes);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2707,7 +2720,7 @@ impl<'a> CodeGenerator<'a> {
|
||||||
variant,
|
variant,
|
||||||
func_params.clone(),
|
func_params.clone(),
|
||||||
is_recursive,
|
is_recursive,
|
||||||
vec![],
|
recursive_nonstatics,
|
||||||
body,
|
body,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -2832,7 +2845,7 @@ impl<'a> CodeGenerator<'a> {
|
||||||
|
|
||||||
if is_dependent_recursive {
|
if is_dependent_recursive {
|
||||||
dep_air_tree.traverse_tree_with(&mut |air_tree: &mut AirTree, _| {
|
dep_air_tree.traverse_tree_with(&mut |air_tree: &mut AirTree, _| {
|
||||||
modify_self_calls(air_tree, &dep_key, &dep_variant);
|
modify_self_calls(air_tree, &dep_key, &dep_variant, &vec![]);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2840,9 +2853,9 @@ impl<'a> CodeGenerator<'a> {
|
||||||
&dep_key.function_name,
|
&dep_key.function_name,
|
||||||
&dep_key.module_name,
|
&dep_key.module_name,
|
||||||
&dep_variant,
|
&dep_variant,
|
||||||
dependent_params,
|
dependent_params.clone(),
|
||||||
is_dependent_recursive,
|
is_dependent_recursive,
|
||||||
vec![],
|
dependent_params,
|
||||||
dep_air_tree,
|
dep_air_tree,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -3706,7 +3719,7 @@ impl<'a> CodeGenerator<'a> {
|
||||||
func_name,
|
func_name,
|
||||||
params,
|
params,
|
||||||
recursive,
|
recursive,
|
||||||
recursive_static_params,
|
recursive_nonstatic_params,
|
||||||
module_name,
|
module_name,
|
||||||
variant_name,
|
variant_name,
|
||||||
} => {
|
} => {
|
||||||
|
@ -3719,7 +3732,11 @@ impl<'a> CodeGenerator<'a> {
|
||||||
|
|
||||||
let mut term = arg_stack.pop().unwrap();
|
let mut term = arg_stack.pop().unwrap();
|
||||||
|
|
||||||
for param in params.iter().rev() {
|
// Introduce a parameter for each parameter
|
||||||
|
// NOTE: we use recursive_nonstatic_params here because
|
||||||
|
// if this is recursive, those are the ones that need to be passed
|
||||||
|
// each time
|
||||||
|
for param in recursive_nonstatic_params.iter().rev() {
|
||||||
func_body = func_body.lambda(param.clone());
|
func_body = func_body.lambda(param.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3730,13 +3747,31 @@ impl<'a> CodeGenerator<'a> {
|
||||||
} else {
|
} else {
|
||||||
func_body = func_body.lambda(func_name.clone());
|
func_body = func_body.lambda(func_name.clone());
|
||||||
|
|
||||||
|
if recursive_nonstatic_params == params {
|
||||||
|
// If we don't have any recursive-static params, we can just emit the function as is
|
||||||
term = term
|
term = term
|
||||||
.lambda(func_name.clone())
|
.lambda(func_name.clone())
|
||||||
.apply(Term::var(func_name.clone()).apply(Term::var(func_name.clone())))
|
.apply(Term::var(func_name.clone()).apply(Term::var(func_name.clone())))
|
||||||
.lambda(func_name)
|
.lambda(func_name)
|
||||||
.apply(func_body);
|
.apply(func_body);
|
||||||
|
} else {
|
||||||
|
// If we have parameters that remain static in each recursive call,
|
||||||
|
// we can construct an *outer* function to take those in
|
||||||
|
// and simplify the recursive part to only accept the non-static arguments
|
||||||
|
let mut outer_func_body = Term::var(&func_name).apply(Term::var(&func_name));
|
||||||
|
for param in recursive_nonstatic_params.iter() {
|
||||||
|
outer_func_body = outer_func_body.apply(Term::var(param));
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: use recursive_static_params here
|
outer_func_body = outer_func_body.lambda(&func_name).apply(func_body);
|
||||||
|
|
||||||
|
// Now, add *all* parameters, so that other call sites don't know the difference
|
||||||
|
for param in params.iter().rev() {
|
||||||
|
outer_func_body = outer_func_body.lambda(param);
|
||||||
|
}
|
||||||
|
|
||||||
|
term = term.lambda(&func_name).apply(outer_func_body);
|
||||||
|
}
|
||||||
|
|
||||||
arg_stack.push(term);
|
arg_stack.push(term);
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,7 +47,7 @@ pub enum Air {
|
||||||
module_name: String,
|
module_name: String,
|
||||||
params: Vec<String>,
|
params: Vec<String>,
|
||||||
recursive: bool,
|
recursive: bool,
|
||||||
recursive_static_params: Vec<String>,
|
recursive_nonstatic_params: Vec<String>,
|
||||||
variant_name: String,
|
variant_name: String,
|
||||||
},
|
},
|
||||||
Fn {
|
Fn {
|
||||||
|
|
|
@ -583,7 +583,7 @@ pub fn erase_opaque_type_operations(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, variant: &String) {
|
pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, variant: &String, static_recursive_params: &Vec<usize>) {
|
||||||
if let AirTree::Expression(AirExpression::Call { func, args, .. }) = air_tree {
|
if let AirTree::Expression(AirExpression::Call { func, args, .. }) = air_tree {
|
||||||
if let AirTree::Expression(AirExpression::Var {
|
if let AirTree::Expression(AirExpression::Var {
|
||||||
constructor:
|
constructor:
|
||||||
|
@ -599,6 +599,11 @@ pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, v
|
||||||
&& module == &func_key.module_name
|
&& module == &func_key.module_name
|
||||||
&& variant == variant_name
|
&& variant == variant_name
|
||||||
{
|
{
|
||||||
|
// Remove any static-recursive-parameters, because they'll be bound statically
|
||||||
|
// above the recursive part of the function
|
||||||
|
for arg in static_recursive_params.iter().rev() {
|
||||||
|
args.remove(*arg);
|
||||||
|
}
|
||||||
let mut new_args = vec![func.as_ref().clone()];
|
let mut new_args = vec![func.as_ref().clone()];
|
||||||
new_args.append(args);
|
new_args.append(args);
|
||||||
*args = new_args;
|
*args = new_args;
|
||||||
|
|
|
@ -129,7 +129,7 @@ pub enum AirStatement {
|
||||||
module_name: String,
|
module_name: String,
|
||||||
params: Vec<String>,
|
params: Vec<String>,
|
||||||
recursive: bool,
|
recursive: bool,
|
||||||
recursive_static_params: Vec<String>,
|
recursive_nonstatic_params: Vec<String>,
|
||||||
variant_name: String,
|
variant_name: String,
|
||||||
func_body: Box<AirTree>,
|
func_body: Box<AirTree>,
|
||||||
},
|
},
|
||||||
|
@ -424,7 +424,7 @@ impl AirTree {
|
||||||
variant_name: impl ToString,
|
variant_name: impl ToString,
|
||||||
params: Vec<String>,
|
params: Vec<String>,
|
||||||
recursive: bool,
|
recursive: bool,
|
||||||
recursive_static_params: Vec<String>,
|
recursive_nonstatic_params: Vec<String>,
|
||||||
func_body: AirTree,
|
func_body: AirTree,
|
||||||
) -> AirTree {
|
) -> AirTree {
|
||||||
AirTree::Statement {
|
AirTree::Statement {
|
||||||
|
@ -433,7 +433,7 @@ impl AirTree {
|
||||||
module_name: module_name.to_string(),
|
module_name: module_name.to_string(),
|
||||||
params,
|
params,
|
||||||
recursive,
|
recursive,
|
||||||
recursive_static_params,
|
recursive_nonstatic_params,
|
||||||
variant_name: variant_name.to_string(),
|
variant_name: variant_name.to_string(),
|
||||||
func_body: func_body.into(),
|
func_body: func_body.into(),
|
||||||
},
|
},
|
||||||
|
@ -878,7 +878,7 @@ impl AirTree {
|
||||||
module_name,
|
module_name,
|
||||||
params,
|
params,
|
||||||
recursive,
|
recursive,
|
||||||
recursive_static_params,
|
recursive_nonstatic_params,
|
||||||
variant_name,
|
variant_name,
|
||||||
func_body,
|
func_body,
|
||||||
} => {
|
} => {
|
||||||
|
@ -887,7 +887,7 @@ impl AirTree {
|
||||||
module_name: module_name.clone(),
|
module_name: module_name.clone(),
|
||||||
params: params.clone(),
|
params: params.clone(),
|
||||||
recursive: *recursive,
|
recursive: *recursive,
|
||||||
recursive_static_params: recursive_static_params.clone(),
|
recursive_nonstatic_params: recursive_nonstatic_params.clone(),
|
||||||
variant_name: variant_name.clone(),
|
variant_name: variant_name.clone(),
|
||||||
});
|
});
|
||||||
func_body.create_air_vec(air_vec);
|
func_body.create_air_vec(air_vec);
|
||||||
|
|
Loading…
Reference in New Issue