From c45caaefc8c4f96fb77fa100c1daf45027faa1b6 Mon Sep 17 00:00:00 2001 From: Pi Lanningham Date: Fri, 28 Jul 2023 21:55:50 -0400 Subject: [PATCH] Rudimentary implementation Adds an identify_recursive_static_params; doesn't handle all shadowing cases yet --- crates/aiken-lang/src/gen_uplc.rs | 50 ++++++++----- crates/aiken-lang/src/gen_uplc/builder.rs | 70 ++++++++++++++++++- .../hello_world/validators/hello_world.ak | 29 ++++++++ 3 files changed, 128 insertions(+), 21 deletions(-) diff --git a/crates/aiken-lang/src/gen_uplc.rs b/crates/aiken-lang/src/gen_uplc.rs index b072ce2f..a8c061d7 100644 --- a/crates/aiken-lang/src/gen_uplc.rs +++ b/crates/aiken-lang/src/gen_uplc.rs @@ -2,7 +2,7 @@ pub mod air; pub mod builder; pub mod tree; -use std::sync::Arc; +use std::{sync::Arc, collections::HashMap}; use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; @@ -26,7 +26,7 @@ use crate::{ convert_opaque_type, erase_opaque_type_operations, find_and_replace_generics, get_arg_type_name, get_generic_id_and_type, get_variant_name, monomorphize, pattern_has_conditions, wrap_as_multi_validator, wrap_validator_condition, CodeGenFunction, - SpecificClause, + SpecificClause, identify_recursive_static_params, }, tipo::{ ModuleValueConstructor, PatternConstructor, Type, TypeInfo, ValueConstructor, @@ -2690,30 +2690,41 @@ impl<'a> CodeGenerator<'a> { // first grab dependencies let func_params = params; - // HACK: partition params into the "static recursives" and otherwise - // for now, we just do this based on the name, but it should be detected - // as an optimization pass - let recursive_static_indexes: Vec = func_params.iter().enumerate().filter(|(idx, p)| { - p.starts_with("pi_recursive_hack_") - }).map(|(idx, _)| idx).collect(); - let recursive_nonstatics: Vec = func_params.iter().cloned().filter(|p| { - !p.starts_with("pi_recursive_hack_") - }).collect(); - - println!("~~ recursive_nonstatics: {:?}", recursive_nonstatics); - println!("~~ func_params: {:?}", func_params); - let params_empty = func_params.is_empty(); let deps = (tree_path, func_deps.clone()); if !params_empty { + let mut potential_recursive_statics = vec![]; if is_recursive { + potential_recursive_statics = func_params.clone(); + // identify which parameters are recursively nonstatic (i.e. get modified before the self-call) + // TODO: this would be a lot simpler if each `Var`, `Let`, function argument, etc. had a unique identifier + // rather than just a name; this would let us track if the Var passed to itself was the same value as the method argument + let mut shadowed_parameters: HashMap = HashMap::new(); + body.traverse_tree_with(&mut |air_tree: &mut AirTree, tree_path| { + identify_recursive_static_params(air_tree, tree_path, &func_params, key, variant, &mut shadowed_parameters, &mut potential_recursive_statics) + }); + + // Find the index of any recursively static parameters, + // so we can remove them from the call-site of each recursive call + let recursive_static_indexes = func_params + .iter() + .enumerate() + .filter(|&(_, p)| potential_recursive_statics.contains(p)) + .map(|(idx, _)| idx) + .collect(); + body.traverse_tree_with(&mut |air_tree: &mut AirTree, _| { modify_self_calls(air_tree, key, variant, &recursive_static_indexes); }); + + if recursive_static_indexes.len() > 0 { + println!("~~ {}: {:?}", key.function_name, recursive_static_indexes.iter().map(|i| func_params[*i].clone()).collect::>()); + } } + let recursive_nonstatics = func_params.iter().filter(|p| !potential_recursive_statics.contains(p)).cloned().collect(); body = AirTree::define_func( &key.function_name, &key.module_name, @@ -3758,18 +3769,20 @@ impl<'a> CodeGenerator<'a> { // If we have parameters that remain static in each recursive call, // we can construct an *outer* function to take those in // and simplify the recursive part to only accept the non-static arguments - let mut outer_func_body = Term::var(&func_name).apply(Term::var(&func_name)); + let mut recursive_func_body = Term::var(&func_name).apply(Term::var(&func_name)); for param in recursive_nonstatic_params.iter() { - outer_func_body = outer_func_body.apply(Term::var(param)); + recursive_func_body = recursive_func_body.apply(Term::var(param)); } - outer_func_body = outer_func_body.lambda(&func_name).apply(func_body); + // Then construct an outer function with *all* parameters, not just the nonstatic ones. + let mut outer_func_body = recursive_func_body.lambda(&func_name).apply(func_body); // Now, add *all* parameters, so that other call sites don't know the difference for param in params.iter().rev() { outer_func_body = outer_func_body.lambda(param); } + // And finally, fold that definition into the rest of our program term = term.lambda(&func_name).apply(outer_func_body); } @@ -4191,7 +4204,6 @@ impl<'a> CodeGenerator<'a> { tipo, } => { let mut arg_vec = vec![]; - for _ in 0..count { arg_vec.push(arg_stack.pop().unwrap()); } diff --git a/crates/aiken-lang/src/gen_uplc/builder.rs b/crates/aiken-lang/src/gen_uplc/builder.rs index 3d9f11d3..22f96e22 100644 --- a/crates/aiken-lang/src/gen_uplc/builder.rs +++ b/crates/aiken-lang/src/gen_uplc/builder.rs @@ -1,4 +1,4 @@ -use std::{rc::Rc, sync::Arc}; +use std::{rc::Rc, sync::Arc, collections::HashMap}; use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; @@ -29,7 +29,7 @@ use crate::{ use super::{ air::Air, - tree::{AirExpression, AirStatement, AirTree}, + tree::{AirExpression, AirStatement, AirTree, TreePath}, }; #[derive(Clone, Debug)] @@ -583,6 +583,71 @@ pub fn erase_opaque_type_operations( } } +pub fn identify_recursive_static_params( + air_tree: &mut AirTree, + tree_path: &TreePath, + func_params: &Vec, + func_key: &FunctionAccessKey, + variant: &String, + shadowed_parameters: &mut HashMap, + potential_recursive_statics: &mut Vec +) { + match air_tree { + AirTree::Statement { statement: AirStatement::Let { name, .. }, .. } => { + if potential_recursive_statics.contains(name) { + shadowed_parameters.insert(name.clone(), tree_path.clone()); + } + }, + AirTree::Expression(AirExpression::Call { func, args, .. }) => { + if let AirTree::Expression(AirExpression::Var { + constructor: + ValueConstructor { + variant: ValueConstructorVariant::ModuleFn { name, module, .. }, + .. + }, + variant_name, + .. + }) = func.as_ref() { + if name == &func_key.function_name + && module == &func_key.module_name + && variant == variant_name + { + for (param, arg) in func_params.iter().zip(args) { + if let Some((idx, _)) = potential_recursive_statics.iter().find_position(|&p| p == param) { + // Check if we pass something different in this recursive call site + // by different, we mean + // - a variable that is bound to a different name + // - a variable with the same name, but that was shadowed in an ancestor scope + // - any other type of expression + let param_is_different = match arg { + AirTree::Expression(AirExpression::Var { name, .. }) => { + // "shadowed in an ancestor scope" means "the definition scope is a prefix of our scope" + name != param || if let Some(p) = shadowed_parameters.get(param) { + println!("param: {:?}", param); + println!("arg: {:?}", arg); + println!("p: {:?}", *p); + println!("tree_path: {:?}", tree_path); + println!("common_ancestor: {:?}", p.common_ancestor(tree_path)); + p.common_ancestor(tree_path) == *p + } else { + false + } + }, + _ => true + }; + // If so, then we disqualify this parameter from being a recursive static parameter + if param_is_different { + potential_recursive_statics.remove(idx); + } + } + } + } + } + }, + _ => () + } +} + pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, variant: &String, static_recursive_params: &Vec) { if let AirTree::Expression(AirExpression::Call { func, args, .. }) = air_tree { if let AirTree::Expression(AirExpression::Var { @@ -601,6 +666,7 @@ pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, v { // Remove any static-recursive-parameters, because they'll be bound statically // above the recursive part of the function + // note: assumes that static_recursive_params is sorted for arg in static_recursive_params.iter().rev() { args.remove(*arg); } diff --git a/examples/hello_world/validators/hello_world.ak b/examples/hello_world/validators/hello_world.ak index 334aa194..5ab93a40 100644 --- a/examples/hello_world/validators/hello_world.ak +++ b/examples/hello_world/validators/hello_world.ak @@ -21,3 +21,32 @@ validator { must_say_hello && must_be_signed } } + +type ABC { + a: ByteArray, + b: Int, + c: ByteArray, +} + +type XYZ { + a: ByteArray, + b: ByteArray, + c: ByteArray, + d: Int, + e: ABC, +} + +fn recursive(a: ByteArray, b: Int, c: XYZ, d: Int, e: Int) -> ByteArray { + if c.e.a == "a" { + "d" + } else if b == 0 { + a + } else { + recursive(a, b - 1, c, d, e) + } +} + +test hah() { + expect "a" == recursive("a", 30, XYZ("", "", "", 1, ABC("", 1, "")), 2, 5) + True +}