diff --git a/crates/aiken-lang/src/builder.rs b/crates/aiken-lang/src/builder.rs index 694a9403..1766165b 100644 --- a/crates/aiken-lang/src/builder.rs +++ b/crates/aiken-lang/src/builder.rs @@ -1409,9 +1409,26 @@ pub fn handle_func_deps_ir( to_be_defined: &mut HashMap, ) { let mut funt_comp = funt_comp.clone(); - // deal with function dependencies + + let mut dependency_map = IndexMap::new(); + let mut dependency_vec = vec![]; + + // deal with function dependencies by sorting order in which we pop them. while let Some(dependency) = funt_comp.dependencies.pop() { - let mut insert_var_vec = vec![]; + let depend_comp = func_components.get(&dependency).unwrap(); + if dependency_map.contains_key(&dependency) { + dependency_map.shift_remove(&dependency); + } + dependency_map.insert(dependency, ()); + funt_comp + .dependencies + .extend(depend_comp.dependencies.clone().into_iter()); + } + + dependency_vec.extend(dependency_map.keys().cloned()); + dependency_vec.reverse(); + + while let Some(dependency) = dependency_vec.pop() { if (defined_functions.contains_key(&dependency) && !funt_comp.args.is_empty()) || func_components.get(&dependency).is_none() { @@ -1419,7 +1436,6 @@ pub fn handle_func_deps_ir( } let depend_comp = func_components.get(&dependency).unwrap(); - let dep_scope = func_index_map.get(&dependency).unwrap(); if get_common_ancestor(dep_scope, func_scope) == func_scope.to_vec() @@ -1427,39 +1443,8 @@ pub fn handle_func_deps_ir( { // we handle zero arg functions and their dependencies in a unique way if !depend_comp.args.is_empty() { - funt_comp - .dependencies - .extend(depend_comp.dependencies.clone()); - - for (index, ir) in depend_comp.ir.iter().enumerate().rev() { - match_ir_for_recursion( - ir.clone(), - &mut insert_var_vec, - &FunctionAccessKey { - function_name: dependency.function_name.clone(), - module_name: dependency.module_name.clone(), - variant_name: dependency.variant_name.clone(), - }, - index, - ); - } - - let mut recursion_ir = depend_comp.ir.clone(); - for (index, ir) in insert_var_vec.clone() { - recursion_ir.insert(index, ir); - - let current_call = recursion_ir[index - 1].clone(); - - match current_call { - Air::Call { scope, count } => { - recursion_ir[index - 1] = Air::Call { - scope, - count: count + 1, - } - } - _ => unreachable!(), - } - } + let mut recursion_ir = vec![]; + handle_recursion_ir(&dependency, depend_comp, &mut recursion_ir); let mut temp_ir = vec![Air::DefineFunc { scope: func_scope.to_vec(), @@ -1485,3 +1470,41 @@ pub fn handle_func_deps_ir( } } } + +pub fn handle_recursion_ir( + func_key: &FunctionAccessKey, + func_comp: &FuncComponents, + recursion_ir: &mut Vec, +) { + let mut insert_var_vec = vec![]; + + for (index, ir) in func_comp.ir.iter().enumerate().rev() { + match_ir_for_recursion( + ir.clone(), + &mut insert_var_vec, + &FunctionAccessKey { + function_name: func_key.function_name.clone(), + module_name: func_key.module_name.clone(), + variant_name: func_key.variant_name.clone(), + }, + index, + ); + } + *recursion_ir = func_comp.ir.clone(); + // Deals with self recursive function + for (index, ir) in insert_var_vec.clone() { + recursion_ir.insert(index, ir); + + let current_call = recursion_ir[index - 1].clone(); + + match current_call { + Air::Call { scope, count } => { + recursion_ir[index - 1] = Air::Call { + scope, + count: count + 1, + } + } + _ => unreachable!(), + } + } +} diff --git a/crates/aiken-lang/src/uplc.rs b/crates/aiken-lang/src/uplc.rs index 1ae444dc..a7e11971 100644 --- a/crates/aiken-lang/src/uplc.rs +++ b/crates/aiken-lang/src/uplc.rs @@ -29,7 +29,7 @@ use crate::{ builder::{ check_when_pattern_needs, constants_ir, convert_constants_to_data, convert_data_to_type, convert_type_to_data, get_common_ancestor, get_generics_and_type, get_variant_name, - handle_func_deps_ir, list_access_to_uplc, match_ir_for_recursion, monomorphize, + handle_func_deps_ir, handle_recursion_ir, list_access_to_uplc, monomorphize, rearrange_clauses, wrap_validator_args, ClauseProperties, DataTypeKey, FuncComponents, FunctionAccessKey, }, @@ -1845,12 +1845,30 @@ impl<'a> CodeGenerator<'a> { let mut final_func_dep_ir = IndexMap::new(); let mut zero_arg_defined_functions = HashMap::new(); let mut to_be_defined = HashMap::new(); - for func in func_index_map.clone() { - if self.defined_functions.contains_key(&func.0) { + + let mut dependency_map = IndexMap::new(); + let mut dependency_vec = vec![]; + + let mut func_keys = func_components.keys().cloned().collect_vec(); + + // deal with function dependencies by sorting order in which we iter over them. + while let Some(function) = func_keys.pop() { + let funct_comp = func_components.get(&function).unwrap(); + if dependency_map.contains_key(&function) { + dependency_map.shift_remove(&function); + } + dependency_map.insert(function, ()); + func_keys.extend(funct_comp.dependencies.clone().into_iter()); + } + + dependency_vec.extend(dependency_map.keys().cloned()); + + for func in dependency_vec { + if self.defined_functions.contains_key(&func) { continue; } - let funt_comp = func_components.get(&func.0).unwrap(); - let func_scope = func_index_map.get(&func.0).unwrap(); + let funt_comp = func_components.get(&func).unwrap(); + let func_scope = func_index_map.get(&func).unwrap(); let mut dep_ir = vec![]; @@ -1865,7 +1883,7 @@ impl<'a> CodeGenerator<'a> { func_scope, &mut to_be_defined, ); - final_func_dep_ir.insert(func.0, dep_ir); + final_func_dep_ir.insert(func, dep_ir); } else { // since zero arg functions are run at compile time we need to pull all deps let mut defined_functions = HashMap::new(); @@ -1882,7 +1900,7 @@ impl<'a> CodeGenerator<'a> { let mut final_zero_arg_ir = dep_ir; final_zero_arg_ir.extend(funt_comp.ir.clone()); - self.zero_arg_functions.insert(func.0, final_zero_arg_ir); + self.zero_arg_functions.insert(func, final_zero_arg_ir); for (key, val) in defined_functions.into_iter() { zero_arg_defined_functions.insert(key, val); @@ -1890,6 +1908,8 @@ impl<'a> CodeGenerator<'a> { } } + // handle functions that are used in zero arg funcs but also used by the validator + // or a func used by the validator for (key, val) in zero_arg_defined_functions.into_iter() { if !to_be_defined.contains_key(&key) { self.defined_functions.insert(key, val); @@ -1909,7 +1929,6 @@ impl<'a> CodeGenerator<'a> { .collect_vec(); for (function_access_key, scopes) in to_insert.into_iter() { - let mut insert_var_vec = vec![]; func_index_map.remove(&function_access_key); self.defined_functions @@ -1918,33 +1937,12 @@ impl<'a> CodeGenerator<'a> { let mut full_func_ir = final_func_dep_ir.get(&function_access_key).unwrap().clone(); - let mut func_comp = func_components.get(&function_access_key).unwrap().clone(); + let func_comp = func_components.get(&function_access_key).unwrap().clone(); + // zero arg functions are not recursive if !func_comp.args.is_empty() { - for (index, ir) in func_comp.ir.clone().iter().enumerate().rev() { - match_ir_for_recursion( - ir.clone(), - &mut insert_var_vec, - &function_access_key, - index, - ); - } - - for (index, ir) in insert_var_vec { - func_comp.ir.insert(index, ir.clone()); - - let current_call = func_comp.ir[index - 1].clone(); - - match current_call { - Air::Call { scope, count } => { - func_comp.ir[index - 1] = Air::Call { - scope, - count: count + 1, - } - } - _ => unreachable!("{current_call:#?}"), - } - } + let mut recursion_ir = vec![]; + handle_recursion_ir(&function_access_key, &func_comp, &mut recursion_ir); full_func_ir.push(Air::DefineFunc { scope: scopes.clone(), @@ -1955,7 +1953,7 @@ impl<'a> CodeGenerator<'a> { variant_name: function_access_key.variant_name.clone(), }); - full_func_ir.extend(func_comp.ir.clone()); + full_func_ir.extend(recursion_ir); for ir in full_func_ir.into_iter().rev() { ir_stack.insert(index, ir);