diff --git a/crates/aiken-lang/src/gen_uplc.rs b/crates/aiken-lang/src/gen_uplc.rs index a8c061d7..5376ce16 100644 --- a/crates/aiken-lang/src/gen_uplc.rs +++ b/crates/aiken-lang/src/gen_uplc.rs @@ -2,7 +2,7 @@ pub mod air; pub mod builder; pub mod tree; -use std::{sync::Arc, collections::HashMap}; +use std::sync::Arc; use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; @@ -26,7 +26,7 @@ use crate::{ convert_opaque_type, erase_opaque_type_operations, find_and_replace_generics, get_arg_type_name, get_generic_id_and_type, get_variant_name, monomorphize, pattern_has_conditions, wrap_as_multi_validator, wrap_validator_condition, CodeGenFunction, - SpecificClause, identify_recursive_static_params, + SpecificClause, }, tipo::{ ModuleValueConstructor, PatternConstructor, Type, TypeInfo, ValueConstructor, @@ -2695,36 +2695,12 @@ impl<'a> CodeGenerator<'a> { let deps = (tree_path, func_deps.clone()); if !params_empty { - let mut potential_recursive_statics = vec![]; - if is_recursive { - potential_recursive_statics = func_params.clone(); - // identify which parameters are recursively nonstatic (i.e. get modified before the self-call) - // TODO: this would be a lot simpler if each `Var`, `Let`, function argument, etc. had a unique identifier - // rather than just a name; this would let us track if the Var passed to itself was the same value as the method argument - let mut shadowed_parameters: HashMap = HashMap::new(); - body.traverse_tree_with(&mut |air_tree: &mut AirTree, tree_path| { - identify_recursive_static_params(air_tree, tree_path, &func_params, key, variant, &mut shadowed_parameters, &mut potential_recursive_statics) - }); + let recursive_nonstatics = if is_recursive { + modify_self_calls(&mut body, key, variant, &func_params) + } else { + func_params.clone() + }; - // Find the index of any recursively static parameters, - // so we can remove them from the call-site of each recursive call - let recursive_static_indexes = func_params - .iter() - .enumerate() - .filter(|&(_, p)| potential_recursive_statics.contains(p)) - .map(|(idx, _)| idx) - .collect(); - - body.traverse_tree_with(&mut |air_tree: &mut AirTree, _| { - modify_self_calls(air_tree, key, variant, &recursive_static_indexes); - }); - - if recursive_static_indexes.len() > 0 { - println!("~~ {}: {:?}", key.function_name, recursive_static_indexes.iter().map(|i| func_params[*i].clone()).collect::>()); - } - } - - let recursive_nonstatics = func_params.iter().filter(|p| !potential_recursive_statics.contains(p)).cloned().collect(); body = AirTree::define_func( &key.function_name, &key.module_name, @@ -2854,19 +2830,21 @@ impl<'a> CodeGenerator<'a> { .iter() .any(|(key, variant)| &dep_key == key && &dep_variant == variant); - if is_dependent_recursive { - dep_air_tree.traverse_tree_with(&mut |air_tree: &mut AirTree, _| { - modify_self_calls(air_tree, &dep_key, &dep_variant, &vec![]); - }); - } + let recursive_nonstatics = if is_dependent_recursive { + modify_self_calls(&mut dep_air_tree, &dep_key, &dep_variant, &dependent_params) + } else { + dependent_params.clone() + }; + + dep_insertions.push(AirTree::define_func( &dep_key.function_name, &dep_key.module_name, &dep_variant, - dependent_params.clone(), - is_dependent_recursive, dependent_params, + is_dependent_recursive, + recursive_nonstatics, dep_air_tree, )); diff --git a/crates/aiken-lang/src/gen_uplc/builder.rs b/crates/aiken-lang/src/gen_uplc/builder.rs index 22f96e22..25c83ae5 100644 --- a/crates/aiken-lang/src/gen_uplc/builder.rs +++ b/crates/aiken-lang/src/gen_uplc/builder.rs @@ -623,11 +623,6 @@ pub fn identify_recursive_static_params( AirTree::Expression(AirExpression::Var { name, .. }) => { // "shadowed in an ancestor scope" means "the definition scope is a prefix of our scope" name != param || if let Some(p) = shadowed_parameters.get(param) { - println!("param: {:?}", param); - println!("arg: {:?}", arg); - println!("p: {:?}", *p); - println!("tree_path: {:?}", tree_path); - println!("common_ancestor: {:?}", p.common_ancestor(tree_path)); p.common_ancestor(tree_path) == *p } else { false @@ -648,34 +643,57 @@ pub fn identify_recursive_static_params( } } -pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, variant: &String, static_recursive_params: &Vec) { - if let AirTree::Expression(AirExpression::Call { func, args, .. }) = air_tree { - if let AirTree::Expression(AirExpression::Var { - constructor: - ValueConstructor { - variant: ValueConstructorVariant::ModuleFn { name, module, .. }, - .. - }, - variant_name, - .. - }) = func.as_ref() - { - if name == &func_key.function_name - && module == &func_key.module_name - && variant == variant_name +pub fn modify_self_calls(body: &mut AirTree, func_key: &FunctionAccessKey, variant: &String, func_params: &Vec) -> Vec { + let mut potential_recursive_statics = func_params.clone(); + // identify which parameters are recursively nonstatic (i.e. get modified before the self-call) + // TODO: this would be a lot simpler if each `Var`, `Let`, function argument, etc. had a unique identifier + // rather than just a name; this would let us track if the Var passed to itself was the same value as the method argument + let mut shadowed_parameters: HashMap = HashMap::new(); + body.traverse_tree_with(&mut |air_tree: &mut AirTree, tree_path| { + identify_recursive_static_params(air_tree, tree_path, &func_params, func_key, variant, &mut shadowed_parameters, &mut potential_recursive_statics); + }); + + // Find the index of any recursively static parameters, + // so we can remove them from the call-site of each recursive call + let recursive_static_indexes: Vec<_> = func_params + .iter() + .enumerate() + .filter(|&(_, p)| potential_recursive_statics.contains(p)) + .map(|(idx, _)| idx) + .collect(); + + // Modify any self calls to remove recursive static parameters and append `self` as a parameter for the recursion + body.traverse_tree_with(&mut |air_tree: &mut AirTree, _| { + if let AirTree::Expression(AirExpression::Call { func, args, .. }) = air_tree { + if let AirTree::Expression(AirExpression::Var { + constructor: + ValueConstructor { + variant: ValueConstructorVariant::ModuleFn { name, module, .. }, + .. + }, + variant_name, + .. + }) = func.as_ref() { - // Remove any static-recursive-parameters, because they'll be bound statically - // above the recursive part of the function - // note: assumes that static_recursive_params is sorted - for arg in static_recursive_params.iter().rev() { - args.remove(*arg); + if name == &func_key.function_name + && module == &func_key.module_name + && variant == variant_name + { + // Remove any static-recursive-parameters, because they'll be bound statically + // above the recursive part of the function + // note: assumes that static_recursive_params is sorted + for arg in recursive_static_indexes.iter().rev() { + args.remove(*arg); + } + let mut new_args = vec![func.as_ref().clone()]; + new_args.append(args); + *args = new_args; } - let mut new_args = vec![func.as_ref().clone()]; - new_args.append(args); - *args = new_args; } } - } + }); + let recursive_nonstatics = func_params.iter().filter(|p| !potential_recursive_statics.contains(p)).cloned().collect(); + recursive_nonstatics } pub fn pattern_has_conditions(pattern: &TypedPattern) -> bool {