diff --git a/crates/aiken-lang/src/gen_uplc.rs b/crates/aiken-lang/src/gen_uplc.rs index 0a1fc282..d250209b 100644 --- a/crates/aiken-lang/src/gen_uplc.rs +++ b/crates/aiken-lang/src/gen_uplc.rs @@ -42,7 +42,7 @@ use self::{ builder::{ cast_validator_args, constants_ir, convert_type_to_data, extract_constant, lookup_data_type_by_tipo, modify_self_calls, rearrange_list_clauses, AssignmentProperties, - ClauseProperties, DataTypeKey, FunctionAccessKey, HoistableFunction, + ClauseProperties, DataTypeKey, FunctionAccessKey, HoistableFunction, Variant, }, tree::{AirExpression, AirTree, TreePath}, }; @@ -55,7 +55,8 @@ pub struct CodeGenerator<'a> { module_types: IndexMap<&'a String, &'a TypeInfo>, needs_field_access: bool, code_gen_functions: IndexMap, - zero_arg_functions: IndexMap<(FunctionAccessKey, String), Vec>, + zero_arg_functions: IndexMap<(FunctionAccessKey, Variant), Vec>, + cyclic_functions: IndexMap<(FunctionAccessKey, Variant), (usize, FunctionAccessKey)>, tracing: bool, id_gen: IdGenerator, } @@ -75,6 +76,7 @@ impl<'a> CodeGenerator<'a> { needs_field_access: false, code_gen_functions: IndexMap::new(), zero_arg_functions: IndexMap::new(), + cyclic_functions: IndexMap::new(), tracing, id_gen: IdGenerator::new(), } @@ -85,6 +87,7 @@ impl<'a> CodeGenerator<'a> { self.zero_arg_functions = IndexMap::new(); self.needs_field_access = false; self.defined_functions = IndexMap::new(); + self.cyclic_functions = IndexMap::new(); self.id_gen = IdGenerator::new(); } @@ -2679,55 +2682,88 @@ impl<'a> CodeGenerator<'a> { continue; } - let function_names = connections + let cyclic_function_names = connections .iter() .map(|index| values.get(index).unwrap()) .collect_vec(); + // TODO: Maybe I could come up with a name based off the functions involved? let function_key = FunctionAccessKey { function_name: format!("__cyclic_function_{}", index), module_name: "".to_string(), }; - let (functions, mut deps) = function_names - .into_iter() - .map(|(func_key, variant)| { - let (_, func) = functions_to_hoist - .get(func_key) - .expect("Missing Function Definition") - .get(variant) - .expect("Missing Function Variant Definition"); + let mut path = TreePath::new(); + let mut cycle_of_functions = vec![]; + let mut cycle_deps = vec![]; - match func { - HoistableFunction::Function { body, deps, params } => { - (params.clone(), body.clone(), deps.clone()) - } + // By doing this any vars that "call" into a function in the cycle will be + // redirected to call the cyclic function instead with the proper index + for (index, (func_name, variant)) in cyclic_function_names.iter().enumerate() { + self.cyclic_functions.insert( + (func_name.clone(), variant.clone()), + (index, function_key.clone()), + ); - _ => unreachable!(), + let (tree_path, func) = functions_to_hoist + .get_mut(func_name) + .expect("Missing Function Definition") + .get_mut(variant) + .expect("Missing Function Variant Definition"); + + match func { + HoistableFunction::Function { params, body, deps } => { + cycle_of_functions.push((params.clone(), body.clone())); + cycle_deps.push(deps.clone()); } - }) - .fold((vec![], vec![]), |mut acc, f| { - acc.0.push((f.0, f.1)); - acc.1.push(f.2); + _ => unreachable!(), + } - acc - }); + if path.is_empty() { + path = tree_path.clone(); + } else { + path = path.common_ancestor(tree_path); + } + + // Here we change function to be link so all functions that depend on it know its a + // cyclic function + *func = HoistableFunction::CyclicLink(function_key.clone()); + } let cyclic_function = HoistableFunction::CyclicFunction { - functions, - deps: deps.into_iter().flatten().dedup().collect_vec(), + functions: cycle_of_functions, + deps: cycle_deps + .into_iter() + .flatten() + .dedup() + // Make sure to filter out cyclic dependencies + .filter(|dependency| { + !cyclic_function_names.iter().any(|(func_name, variant)| { + func_name == &dependency.0 && variant == &dependency.1 + }) + }) + .collect_vec(), }; + + let mut cyclic_map = IndexMap::new(); + cyclic_map.insert("".to_string(), (path, cyclic_function)); + + functions_to_hoist.insert(function_key, cyclic_map); } - todo!(); - // Rest of code is for hoisting functions - let mut sorted_function_vec = vec![]; + let mut sorted_function_vec: Vec<(FunctionAccessKey, String)> = vec![]; let functions_to_hoist_cloned = functions_to_hoist.clone(); + let mut sorting_attempts: u64 = 0; while let Some((generic_func, variant)) = validator_hoistable.pop() { + assert!( + sorting_attempts < 5_000_000_000, + "Sorting dependency attempts exceeded" + ); + let function_variants = functions_to_hoist_cloned .get(&generic_func) .unwrap_or_else(|| panic!("Missing Function Definition")); @@ -2770,14 +2806,67 @@ impl<'a> CodeGenerator<'a> { } } HoistableFunction::Link(_) => todo!("Deal with Link later"), - HoistableFunction::CyclicLink(_) => {} - HoistableFunction::CyclicFunction { functions, deps } => todo!(), + HoistableFunction::CyclicLink(cyclic_name) => { + validator_hoistable.insert(0, (cyclic_name.clone(), "".to_string())); + + sorted_function_vec.retain(|(generic_func, variant)| { + !(generic_func == cyclic_name && variant.is_empty()) + }); + + let (func_tree_path, _) = functions_to_hoist + .get(&generic_func) + .unwrap() + .get(&variant) + .unwrap() + .clone(); + + let (dep_path, _) = functions_to_hoist + .get_mut(cyclic_name) + .unwrap() + .get_mut("") + .unwrap(); + + *dep_path = func_tree_path.common_ancestor(dep_path); + } + HoistableFunction::CyclicFunction { deps, .. } => { + for (dep_generic_func, dep_variant) in deps.iter() { + if !(dep_generic_func == &generic_func && dep_variant == &variant) { + validator_hoistable + .insert(0, (dep_generic_func.clone(), dep_variant.clone())); + + sorted_function_vec.retain(|(generic_func, variant)| { + !(generic_func == dep_generic_func && variant == dep_variant) + }); + } + } + + // Fix dependencies path to be updated to common ancestor + for (dep_key, dep_variant) in deps { + let (func_tree_path, _) = functions_to_hoist + .get(&generic_func) + .unwrap() + .get(&variant) + .unwrap() + .clone(); + + let (dep_path, _) = functions_to_hoist + .get_mut(dep_key) + .unwrap() + .get_mut(dep_variant) + .unwrap(); + + *dep_path = func_tree_path.common_ancestor(dep_path); + } + } } sorted_function_vec.push((generic_func, variant)); + sorting_attempts += 1; } sorted_function_vec.dedup(); + todo!(); + // Now we need to hoist the functions to the top of the validator for (key, variant) in sorted_function_vec { if hoisted_functions @@ -2820,6 +2909,13 @@ impl<'a> CodeGenerator<'a> { >, hoisted_functions: &mut Vec<(FunctionAccessKey, String)>, ) { + match function { + HoistableFunction::Function { body, deps, params } => todo!(), + HoistableFunction::CyclicFunction { functions, deps } => todo!(), + HoistableFunction::Link(_) => todo!(), + HoistableFunction::CyclicLink(_) => todo!(), + } + if let HoistableFunction::Function { body, deps: func_deps, diff --git a/crates/aiken-lang/src/gen_uplc/builder.rs b/crates/aiken-lang/src/gen_uplc/builder.rs index 6c2b0369..298b7cd5 100644 --- a/crates/aiken-lang/src/gen_uplc/builder.rs +++ b/crates/aiken-lang/src/gen_uplc/builder.rs @@ -52,7 +52,7 @@ pub enum HoistableFunction { deps: Vec<(FunctionAccessKey, Variant)>, }, Link((FunctionAccessKey, Variant)), - CyclicLink((FunctionAccessKey, Variant)), + CyclicLink(FunctionAccessKey), } #[derive(Clone, Debug, Eq, PartialEq, Hash)] diff --git a/crates/aiken-lang/src/gen_uplc/tree.rs b/crates/aiken-lang/src/gen_uplc/tree.rs index 20af5917..f3c39536 100644 --- a/crates/aiken-lang/src/gen_uplc/tree.rs +++ b/crates/aiken-lang/src/gen_uplc/tree.rs @@ -21,6 +21,10 @@ impl TreePath { TreePath { path: vec![] } } + pub fn is_empty(&self) -> bool { + self.path.is_empty() + } + pub fn push(&mut self, depth: usize, index: usize) { self.path.push((depth, index)); } diff --git a/examples/acceptance_tests/066/aiken.lock b/examples/acceptance_tests/066/aiken.lock index 3a78b1e7..6e350cda 100644 --- a/examples/acceptance_tests/066/aiken.lock +++ b/examples/acceptance_tests/066/aiken.lock @@ -3,3 +3,5 @@ requirements = [] packages = [] + +[etags] diff --git a/examples/acceptance_tests/066/lib/tests.ak b/examples/acceptance_tests/066/lib/tests.ak index a4981ef3..69e69e01 100644 --- a/examples/acceptance_tests/066/lib/tests.ak +++ b/examples/acceptance_tests/066/lib/tests.ak @@ -18,8 +18,8 @@ fn sum_list(list: List) -> Int { } test foo() { - False + // False // Can't enable the "real" test because it puts the UPLC evaluator in an infinite loop. // - - // sum(List([List([Integer(1), Integer(2)]), Integer(3), Integer(4)])) == 10 + sum(List([List([Integer(1), Integer(2)]), Integer(3), Integer(4)])) == 10 }