Rudimentary implementation
Adds an identify_recursive_static_params; doesn't handle all shadowing cases yet
This commit is contained in:
parent
09f889b121
commit
c45caaefc8
|
@ -2,7 +2,7 @@ pub mod air;
|
||||||
pub mod builder;
|
pub mod builder;
|
||||||
pub mod tree;
|
pub mod tree;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::{sync::Arc, collections::HashMap};
|
||||||
|
|
||||||
use indexmap::{IndexMap, IndexSet};
|
use indexmap::{IndexMap, IndexSet};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
@ -26,7 +26,7 @@ use crate::{
|
||||||
convert_opaque_type, erase_opaque_type_operations, find_and_replace_generics,
|
convert_opaque_type, erase_opaque_type_operations, find_and_replace_generics,
|
||||||
get_arg_type_name, get_generic_id_and_type, get_variant_name, monomorphize,
|
get_arg_type_name, get_generic_id_and_type, get_variant_name, monomorphize,
|
||||||
pattern_has_conditions, wrap_as_multi_validator, wrap_validator_condition, CodeGenFunction,
|
pattern_has_conditions, wrap_as_multi_validator, wrap_validator_condition, CodeGenFunction,
|
||||||
SpecificClause,
|
SpecificClause, identify_recursive_static_params,
|
||||||
},
|
},
|
||||||
tipo::{
|
tipo::{
|
||||||
ModuleValueConstructor, PatternConstructor, Type, TypeInfo, ValueConstructor,
|
ModuleValueConstructor, PatternConstructor, Type, TypeInfo, ValueConstructor,
|
||||||
|
@ -2690,30 +2690,41 @@ impl<'a> CodeGenerator<'a> {
|
||||||
// first grab dependencies
|
// first grab dependencies
|
||||||
let func_params = params;
|
let func_params = params;
|
||||||
|
|
||||||
// HACK: partition params into the "static recursives" and otherwise
|
|
||||||
// for now, we just do this based on the name, but it should be detected
|
|
||||||
// as an optimization pass
|
|
||||||
let recursive_static_indexes: Vec<usize> = func_params.iter().enumerate().filter(|(idx, p)| {
|
|
||||||
p.starts_with("pi_recursive_hack_")
|
|
||||||
}).map(|(idx, _)| idx).collect();
|
|
||||||
let recursive_nonstatics: Vec<String> = func_params.iter().cloned().filter(|p| {
|
|
||||||
!p.starts_with("pi_recursive_hack_")
|
|
||||||
}).collect();
|
|
||||||
|
|
||||||
println!("~~ recursive_nonstatics: {:?}", recursive_nonstatics);
|
|
||||||
println!("~~ func_params: {:?}", func_params);
|
|
||||||
|
|
||||||
let params_empty = func_params.is_empty();
|
let params_empty = func_params.is_empty();
|
||||||
|
|
||||||
let deps = (tree_path, func_deps.clone());
|
let deps = (tree_path, func_deps.clone());
|
||||||
|
|
||||||
if !params_empty {
|
if !params_empty {
|
||||||
|
let mut potential_recursive_statics = vec![];
|
||||||
if is_recursive {
|
if is_recursive {
|
||||||
|
potential_recursive_statics = func_params.clone();
|
||||||
|
// identify which parameters are recursively nonstatic (i.e. get modified before the self-call)
|
||||||
|
// TODO: this would be a lot simpler if each `Var`, `Let`, function argument, etc. had a unique identifier
|
||||||
|
// rather than just a name; this would let us track if the Var passed to itself was the same value as the method argument
|
||||||
|
let mut shadowed_parameters: HashMap<String, TreePath> = HashMap::new();
|
||||||
|
body.traverse_tree_with(&mut |air_tree: &mut AirTree, tree_path| {
|
||||||
|
identify_recursive_static_params(air_tree, tree_path, &func_params, key, variant, &mut shadowed_parameters, &mut potential_recursive_statics)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Find the index of any recursively static parameters,
|
||||||
|
// so we can remove them from the call-site of each recursive call
|
||||||
|
let recursive_static_indexes = func_params
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter(|&(_, p)| potential_recursive_statics.contains(p))
|
||||||
|
.map(|(idx, _)| idx)
|
||||||
|
.collect();
|
||||||
|
|
||||||
body.traverse_tree_with(&mut |air_tree: &mut AirTree, _| {
|
body.traverse_tree_with(&mut |air_tree: &mut AirTree, _| {
|
||||||
modify_self_calls(air_tree, key, variant, &recursive_static_indexes);
|
modify_self_calls(air_tree, key, variant, &recursive_static_indexes);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if recursive_static_indexes.len() > 0 {
|
||||||
|
println!("~~ {}: {:?}", key.function_name, recursive_static_indexes.iter().map(|i| func_params[*i].clone()).collect::<Vec<String>>());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let recursive_nonstatics = func_params.iter().filter(|p| !potential_recursive_statics.contains(p)).cloned().collect();
|
||||||
body = AirTree::define_func(
|
body = AirTree::define_func(
|
||||||
&key.function_name,
|
&key.function_name,
|
||||||
&key.module_name,
|
&key.module_name,
|
||||||
|
@ -3758,18 +3769,20 @@ impl<'a> CodeGenerator<'a> {
|
||||||
// If we have parameters that remain static in each recursive call,
|
// If we have parameters that remain static in each recursive call,
|
||||||
// we can construct an *outer* function to take those in
|
// we can construct an *outer* function to take those in
|
||||||
// and simplify the recursive part to only accept the non-static arguments
|
// and simplify the recursive part to only accept the non-static arguments
|
||||||
let mut outer_func_body = Term::var(&func_name).apply(Term::var(&func_name));
|
let mut recursive_func_body = Term::var(&func_name).apply(Term::var(&func_name));
|
||||||
for param in recursive_nonstatic_params.iter() {
|
for param in recursive_nonstatic_params.iter() {
|
||||||
outer_func_body = outer_func_body.apply(Term::var(param));
|
recursive_func_body = recursive_func_body.apply(Term::var(param));
|
||||||
}
|
}
|
||||||
|
|
||||||
outer_func_body = outer_func_body.lambda(&func_name).apply(func_body);
|
// Then construct an outer function with *all* parameters, not just the nonstatic ones.
|
||||||
|
let mut outer_func_body = recursive_func_body.lambda(&func_name).apply(func_body);
|
||||||
|
|
||||||
// Now, add *all* parameters, so that other call sites don't know the difference
|
// Now, add *all* parameters, so that other call sites don't know the difference
|
||||||
for param in params.iter().rev() {
|
for param in params.iter().rev() {
|
||||||
outer_func_body = outer_func_body.lambda(param);
|
outer_func_body = outer_func_body.lambda(param);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// And finally, fold that definition into the rest of our program
|
||||||
term = term.lambda(&func_name).apply(outer_func_body);
|
term = term.lambda(&func_name).apply(outer_func_body);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4191,7 +4204,6 @@ impl<'a> CodeGenerator<'a> {
|
||||||
tipo,
|
tipo,
|
||||||
} => {
|
} => {
|
||||||
let mut arg_vec = vec![];
|
let mut arg_vec = vec![];
|
||||||
|
|
||||||
for _ in 0..count {
|
for _ in 0..count {
|
||||||
arg_vec.push(arg_stack.pop().unwrap());
|
arg_vec.push(arg_stack.pop().unwrap());
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use std::{rc::Rc, sync::Arc};
|
use std::{rc::Rc, sync::Arc, collections::HashMap};
|
||||||
|
|
||||||
use indexmap::{IndexMap, IndexSet};
|
use indexmap::{IndexMap, IndexSet};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
@ -29,7 +29,7 @@ use crate::{
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
air::Air,
|
air::Air,
|
||||||
tree::{AirExpression, AirStatement, AirTree},
|
tree::{AirExpression, AirStatement, AirTree, TreePath},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
@ -583,6 +583,71 @@ pub fn erase_opaque_type_operations(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn identify_recursive_static_params(
|
||||||
|
air_tree: &mut AirTree,
|
||||||
|
tree_path: &TreePath,
|
||||||
|
func_params: &Vec<String>,
|
||||||
|
func_key: &FunctionAccessKey,
|
||||||
|
variant: &String,
|
||||||
|
shadowed_parameters: &mut HashMap<String, TreePath>,
|
||||||
|
potential_recursive_statics: &mut Vec<String>
|
||||||
|
) {
|
||||||
|
match air_tree {
|
||||||
|
AirTree::Statement { statement: AirStatement::Let { name, .. }, .. } => {
|
||||||
|
if potential_recursive_statics.contains(name) {
|
||||||
|
shadowed_parameters.insert(name.clone(), tree_path.clone());
|
||||||
|
}
|
||||||
|
},
|
||||||
|
AirTree::Expression(AirExpression::Call { func, args, .. }) => {
|
||||||
|
if let AirTree::Expression(AirExpression::Var {
|
||||||
|
constructor:
|
||||||
|
ValueConstructor {
|
||||||
|
variant: ValueConstructorVariant::ModuleFn { name, module, .. },
|
||||||
|
..
|
||||||
|
},
|
||||||
|
variant_name,
|
||||||
|
..
|
||||||
|
}) = func.as_ref() {
|
||||||
|
if name == &func_key.function_name
|
||||||
|
&& module == &func_key.module_name
|
||||||
|
&& variant == variant_name
|
||||||
|
{
|
||||||
|
for (param, arg) in func_params.iter().zip(args) {
|
||||||
|
if let Some((idx, _)) = potential_recursive_statics.iter().find_position(|&p| p == param) {
|
||||||
|
// Check if we pass something different in this recursive call site
|
||||||
|
// by different, we mean
|
||||||
|
// - a variable that is bound to a different name
|
||||||
|
// - a variable with the same name, but that was shadowed in an ancestor scope
|
||||||
|
// - any other type of expression
|
||||||
|
let param_is_different = match arg {
|
||||||
|
AirTree::Expression(AirExpression::Var { name, .. }) => {
|
||||||
|
// "shadowed in an ancestor scope" means "the definition scope is a prefix of our scope"
|
||||||
|
name != param || if let Some(p) = shadowed_parameters.get(param) {
|
||||||
|
println!("param: {:?}", param);
|
||||||
|
println!("arg: {:?}", arg);
|
||||||
|
println!("p: {:?}", *p);
|
||||||
|
println!("tree_path: {:?}", tree_path);
|
||||||
|
println!("common_ancestor: {:?}", p.common_ancestor(tree_path));
|
||||||
|
p.common_ancestor(tree_path) == *p
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => true
|
||||||
|
};
|
||||||
|
// If so, then we disqualify this parameter from being a recursive static parameter
|
||||||
|
if param_is_different {
|
||||||
|
potential_recursive_statics.remove(idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => ()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, variant: &String, static_recursive_params: &Vec<usize>) {
|
pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, variant: &String, static_recursive_params: &Vec<usize>) {
|
||||||
if let AirTree::Expression(AirExpression::Call { func, args, .. }) = air_tree {
|
if let AirTree::Expression(AirExpression::Call { func, args, .. }) = air_tree {
|
||||||
if let AirTree::Expression(AirExpression::Var {
|
if let AirTree::Expression(AirExpression::Var {
|
||||||
|
@ -601,6 +666,7 @@ pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, v
|
||||||
{
|
{
|
||||||
// Remove any static-recursive-parameters, because they'll be bound statically
|
// Remove any static-recursive-parameters, because they'll be bound statically
|
||||||
// above the recursive part of the function
|
// above the recursive part of the function
|
||||||
|
// note: assumes that static_recursive_params is sorted
|
||||||
for arg in static_recursive_params.iter().rev() {
|
for arg in static_recursive_params.iter().rev() {
|
||||||
args.remove(*arg);
|
args.remove(*arg);
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,3 +21,32 @@ validator {
|
||||||
must_say_hello && must_be_signed
|
must_say_hello && must_be_signed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ABC {
|
||||||
|
a: ByteArray,
|
||||||
|
b: Int,
|
||||||
|
c: ByteArray,
|
||||||
|
}
|
||||||
|
|
||||||
|
type XYZ {
|
||||||
|
a: ByteArray,
|
||||||
|
b: ByteArray,
|
||||||
|
c: ByteArray,
|
||||||
|
d: Int,
|
||||||
|
e: ABC,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn recursive(a: ByteArray, b: Int, c: XYZ, d: Int, e: Int) -> ByteArray {
|
||||||
|
if c.e.a == "a" {
|
||||||
|
"d"
|
||||||
|
} else if b == 0 {
|
||||||
|
a
|
||||||
|
} else {
|
||||||
|
recursive(a, b - 1, c, d, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test hah() {
|
||||||
|
expect "a" == recursive("a", 30, XYZ("", "", "", 1, ABC("", 1, "")), 2, 5)
|
||||||
|
True
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue