Rudimentary implementation

Adds an identify_recursive_static_params; doesn't handle all shadowing cases yet
This commit is contained in:
Pi Lanningham 2023-07-28 21:55:50 -04:00 committed by Kasey
parent 09f889b121
commit c45caaefc8
3 changed files with 128 additions and 21 deletions

View File

@ -2,7 +2,7 @@ pub mod air;
pub mod builder; pub mod builder;
pub mod tree; pub mod tree;
use std::sync::Arc; use std::{sync::Arc, collections::HashMap};
use indexmap::{IndexMap, IndexSet}; use indexmap::{IndexMap, IndexSet};
use itertools::Itertools; use itertools::Itertools;
@ -26,7 +26,7 @@ use crate::{
convert_opaque_type, erase_opaque_type_operations, find_and_replace_generics, convert_opaque_type, erase_opaque_type_operations, find_and_replace_generics,
get_arg_type_name, get_generic_id_and_type, get_variant_name, monomorphize, get_arg_type_name, get_generic_id_and_type, get_variant_name, monomorphize,
pattern_has_conditions, wrap_as_multi_validator, wrap_validator_condition, CodeGenFunction, pattern_has_conditions, wrap_as_multi_validator, wrap_validator_condition, CodeGenFunction,
SpecificClause, SpecificClause, identify_recursive_static_params,
}, },
tipo::{ tipo::{
ModuleValueConstructor, PatternConstructor, Type, TypeInfo, ValueConstructor, ModuleValueConstructor, PatternConstructor, Type, TypeInfo, ValueConstructor,
@ -2690,30 +2690,41 @@ impl<'a> CodeGenerator<'a> {
// first grab dependencies // first grab dependencies
let func_params = params; let func_params = params;
// HACK: partition params into the "static recursives" and otherwise
// for now, we just do this based on the name, but it should be detected
// as an optimization pass
let recursive_static_indexes: Vec<usize> = func_params.iter().enumerate().filter(|(idx, p)| {
p.starts_with("pi_recursive_hack_")
}).map(|(idx, _)| idx).collect();
let recursive_nonstatics: Vec<String> = func_params.iter().cloned().filter(|p| {
!p.starts_with("pi_recursive_hack_")
}).collect();
println!("~~ recursive_nonstatics: {:?}", recursive_nonstatics);
println!("~~ func_params: {:?}", func_params);
let params_empty = func_params.is_empty(); let params_empty = func_params.is_empty();
let deps = (tree_path, func_deps.clone()); let deps = (tree_path, func_deps.clone());
if !params_empty { if !params_empty {
let mut potential_recursive_statics = vec![];
if is_recursive { if is_recursive {
potential_recursive_statics = func_params.clone();
// identify which parameters are recursively nonstatic (i.e. get modified before the self-call)
// TODO: this would be a lot simpler if each `Var`, `Let`, function argument, etc. had a unique identifier
// rather than just a name; this would let us track if the Var passed to itself was the same value as the method argument
let mut shadowed_parameters: HashMap<String, TreePath> = HashMap::new();
body.traverse_tree_with(&mut |air_tree: &mut AirTree, tree_path| {
identify_recursive_static_params(air_tree, tree_path, &func_params, key, variant, &mut shadowed_parameters, &mut potential_recursive_statics)
});
// Find the index of any recursively static parameters,
// so we can remove them from the call-site of each recursive call
let recursive_static_indexes = func_params
.iter()
.enumerate()
.filter(|&(_, p)| potential_recursive_statics.contains(p))
.map(|(idx, _)| idx)
.collect();
body.traverse_tree_with(&mut |air_tree: &mut AirTree, _| { body.traverse_tree_with(&mut |air_tree: &mut AirTree, _| {
modify_self_calls(air_tree, key, variant, &recursive_static_indexes); modify_self_calls(air_tree, key, variant, &recursive_static_indexes);
}); });
if recursive_static_indexes.len() > 0 {
println!("~~ {}: {:?}", key.function_name, recursive_static_indexes.iter().map(|i| func_params[*i].clone()).collect::<Vec<String>>());
}
} }
let recursive_nonstatics = func_params.iter().filter(|p| !potential_recursive_statics.contains(p)).cloned().collect();
body = AirTree::define_func( body = AirTree::define_func(
&key.function_name, &key.function_name,
&key.module_name, &key.module_name,
@ -3758,18 +3769,20 @@ impl<'a> CodeGenerator<'a> {
// If we have parameters that remain static in each recursive call, // If we have parameters that remain static in each recursive call,
// we can construct an *outer* function to take those in // we can construct an *outer* function to take those in
// and simplify the recursive part to only accept the non-static arguments // and simplify the recursive part to only accept the non-static arguments
let mut outer_func_body = Term::var(&func_name).apply(Term::var(&func_name)); let mut recursive_func_body = Term::var(&func_name).apply(Term::var(&func_name));
for param in recursive_nonstatic_params.iter() { for param in recursive_nonstatic_params.iter() {
outer_func_body = outer_func_body.apply(Term::var(param)); recursive_func_body = recursive_func_body.apply(Term::var(param));
} }
outer_func_body = outer_func_body.lambda(&func_name).apply(func_body); // Then construct an outer function with *all* parameters, not just the nonstatic ones.
let mut outer_func_body = recursive_func_body.lambda(&func_name).apply(func_body);
// Now, add *all* parameters, so that other call sites don't know the difference // Now, add *all* parameters, so that other call sites don't know the difference
for param in params.iter().rev() { for param in params.iter().rev() {
outer_func_body = outer_func_body.lambda(param); outer_func_body = outer_func_body.lambda(param);
} }
// And finally, fold that definition into the rest of our program
term = term.lambda(&func_name).apply(outer_func_body); term = term.lambda(&func_name).apply(outer_func_body);
} }
@ -4191,7 +4204,6 @@ impl<'a> CodeGenerator<'a> {
tipo, tipo,
} => { } => {
let mut arg_vec = vec![]; let mut arg_vec = vec![];
for _ in 0..count { for _ in 0..count {
arg_vec.push(arg_stack.pop().unwrap()); arg_vec.push(arg_stack.pop().unwrap());
} }

View File

@ -1,4 +1,4 @@
use std::{rc::Rc, sync::Arc}; use std::{rc::Rc, sync::Arc, collections::HashMap};
use indexmap::{IndexMap, IndexSet}; use indexmap::{IndexMap, IndexSet};
use itertools::Itertools; use itertools::Itertools;
@ -29,7 +29,7 @@ use crate::{
use super::{ use super::{
air::Air, air::Air,
tree::{AirExpression, AirStatement, AirTree}, tree::{AirExpression, AirStatement, AirTree, TreePath},
}; };
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -583,6 +583,71 @@ pub fn erase_opaque_type_operations(
} }
} }
pub fn identify_recursive_static_params(
air_tree: &mut AirTree,
tree_path: &TreePath,
func_params: &Vec<String>,
func_key: &FunctionAccessKey,
variant: &String,
shadowed_parameters: &mut HashMap<String, TreePath>,
potential_recursive_statics: &mut Vec<String>
) {
match air_tree {
AirTree::Statement { statement: AirStatement::Let { name, .. }, .. } => {
if potential_recursive_statics.contains(name) {
shadowed_parameters.insert(name.clone(), tree_path.clone());
}
},
AirTree::Expression(AirExpression::Call { func, args, .. }) => {
if let AirTree::Expression(AirExpression::Var {
constructor:
ValueConstructor {
variant: ValueConstructorVariant::ModuleFn { name, module, .. },
..
},
variant_name,
..
}) = func.as_ref() {
if name == &func_key.function_name
&& module == &func_key.module_name
&& variant == variant_name
{
for (param, arg) in func_params.iter().zip(args) {
if let Some((idx, _)) = potential_recursive_statics.iter().find_position(|&p| p == param) {
// Check if we pass something different in this recursive call site
// by different, we mean
// - a variable that is bound to a different name
// - a variable with the same name, but that was shadowed in an ancestor scope
// - any other type of expression
let param_is_different = match arg {
AirTree::Expression(AirExpression::Var { name, .. }) => {
// "shadowed in an ancestor scope" means "the definition scope is a prefix of our scope"
name != param || if let Some(p) = shadowed_parameters.get(param) {
println!("param: {:?}", param);
println!("arg: {:?}", arg);
println!("p: {:?}", *p);
println!("tree_path: {:?}", tree_path);
println!("common_ancestor: {:?}", p.common_ancestor(tree_path));
p.common_ancestor(tree_path) == *p
} else {
false
}
},
_ => true
};
// If so, then we disqualify this parameter from being a recursive static parameter
if param_is_different {
potential_recursive_statics.remove(idx);
}
}
}
}
}
},
_ => ()
}
}
pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, variant: &String, static_recursive_params: &Vec<usize>) { pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, variant: &String, static_recursive_params: &Vec<usize>) {
if let AirTree::Expression(AirExpression::Call { func, args, .. }) = air_tree { if let AirTree::Expression(AirExpression::Call { func, args, .. }) = air_tree {
if let AirTree::Expression(AirExpression::Var { if let AirTree::Expression(AirExpression::Var {
@ -601,6 +666,7 @@ pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, v
{ {
// Remove any static-recursive-parameters, because they'll be bound statically // Remove any static-recursive-parameters, because they'll be bound statically
// above the recursive part of the function // above the recursive part of the function
// note: assumes that static_recursive_params is sorted
for arg in static_recursive_params.iter().rev() { for arg in static_recursive_params.iter().rev() {
args.remove(*arg); args.remove(*arg);
} }

View File

@ -21,3 +21,32 @@ validator {
must_say_hello && must_be_signed must_say_hello && must_be_signed
} }
} }
type ABC {
a: ByteArray,
b: Int,
c: ByteArray,
}
type XYZ {
a: ByteArray,
b: ByteArray,
c: ByteArray,
d: Int,
e: ABC,
}
fn recursive(a: ByteArray, b: Int, c: XYZ, d: Int, e: Int) -> ByteArray {
if c.e.a == "a" {
"d"
} else if b == 0 {
a
} else {
recursive(a, b - 1, c, d, e)
}
}
test hah() {
expect "a" == recursive("a", 30, XYZ("", "", "", 1, ABC("", 1, "")), 2, 5)
True
}