diff --git a/Cargo.lock b/Cargo.lock index 0a043bbc..ba7e900a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -93,6 +93,7 @@ dependencies = [ "num-bigint", "ordinal", "owo-colors", + "petgraph", "pretty_assertions", "strum", "thiserror", diff --git a/crates/aiken-lang/Cargo.toml b/crates/aiken-lang/Cargo.toml index e85ead2e..3fc87918 100644 --- a/crates/aiken-lang/Cargo.toml +++ b/crates/aiken-lang/Cargo.toml @@ -26,6 +26,7 @@ thiserror = "1.0.39" vec1 = "1.10.1" uplc = { path = '../uplc', version = "1.0.17-alpha" } num-bigint = "0.4.3" +petgraph = "0.6.3" [target.'cfg(not(target_family="wasm"))'.dependencies] chumsky = "0.9.2" diff --git a/crates/aiken-lang/src/gen_uplc.rs b/crates/aiken-lang/src/gen_uplc.rs index 97b9a302..3ca17e7f 100644 --- a/crates/aiken-lang/src/gen_uplc.rs +++ b/crates/aiken-lang/src/gen_uplc.rs @@ -2,6 +2,8 @@ pub mod air; pub mod builder; pub mod tree; +use petgraph::{algo, Graph}; +use std::collections::HashMap; use std::rc::Rc; use indexmap::{IndexMap, IndexSet}; @@ -2626,6 +2628,63 @@ impl<'a> CodeGenerator<'a> { // First we need to sort functions by dependencies // here's also where we deal with mutual recursion + + // Mutual Recursion + let inputs = functions_to_hoist + .iter() + .flat_map(|(function_name, val)| { + val.into_iter() + .map(|(variant, (_, function))| { + if let UserFunction::Function { deps, .. } = function { + ((function_name.clone(), variant.clone()), deps) + } else { + todo!("Deal with Link later") + } + }) + .collect_vec() + }) + .collect_vec(); + + let capacity = inputs.len(); + + let mut graph = Graph::<(), ()>::with_capacity(capacity, capacity * 5); + + let mut indices = HashMap::with_capacity(capacity); + let mut values = HashMap::with_capacity(capacity); + + for (value, _) in &inputs { + let index = graph.add_node(()); + + indices.insert(value.clone(), index); + + values.insert(index, value.clone()); + } + + for (value, deps) in inputs { + if let Some(from_index) = indices.get(&value) { + let deps = deps.iter().filter_map(|dep| indices.get(dep)); + + for to_index in deps { + graph.add_edge(*from_index, *to_index, ()); + } + } + } + + let strong_connections = algo::tarjan_scc(&graph); + + for connections in strong_connections { + // If there's only one function, then it's only self recursive + if connections.len() < 2 { + continue; + } + + let function_names = connections + .iter() + .map(|index| values.get(index).unwrap()) + .collect_vec(); + } + + // Rest of code is for hoisting functions let mut sorted_function_vec = vec![]; let functions_to_hoist_cloned = functions_to_hoist.clone(); @@ -3840,6 +3899,45 @@ impl<'a> CodeGenerator<'a> { arg_stack.push(term); } } + Air::DefineCyclicFuncs { + func_name, + module_name, + variant_name, + contained_functions, + } => { + let func_name = if module_name.is_empty() { + format!("{func_name}{variant_name}") + } else { + format!("{module_name}_{func_name}{variant_name}") + }; + let mut cyclic_functions = vec![]; + + for params in contained_functions { + let func_body = arg_stack.pop().unwrap(); + + cyclic_functions.push((params, func_body)); + } + let mut term = arg_stack.pop().unwrap(); + + let mut cyclic_body = Term::var("__chooser"); + + for (params, func_body) in cyclic_functions.into_iter() { + let mut function = func_body; + for param in params.iter().rev() { + function = function.lambda(param); + } + + cyclic_body = cyclic_body.apply(function) + } + + term = term + .lambda(&func_name) + .apply(Term::var(&func_name).apply(Term::var(&func_name))) + .lambda(&func_name) + .apply(cyclic_body.lambda("__chooser").lambda(func_name)); + + arg_stack.push(term); + } Air::Let { name } => { let arg = arg_stack.pop().unwrap(); diff --git a/crates/aiken-lang/src/gen_uplc/air.rs b/crates/aiken-lang/src/gen_uplc/air.rs index 277dbf4a..d70cc1be 100644 --- a/crates/aiken-lang/src/gen_uplc/air.rs +++ b/crates/aiken-lang/src/gen_uplc/air.rs @@ -50,6 +50,13 @@ pub enum Air { recursive_nonstatic_params: Vec, variant_name: String, }, + DefineCyclicFuncs { + func_name: String, + module_name: String, + variant_name: String, + // just the params + contained_functions: Vec>, + }, Fn { params: Vec, }, diff --git a/crates/aiken-lang/src/gen_uplc/tree.rs b/crates/aiken-lang/src/gen_uplc/tree.rs index 98145795..570a9d1c 100644 --- a/crates/aiken-lang/src/gen_uplc/tree.rs +++ b/crates/aiken-lang/src/gen_uplc/tree.rs @@ -54,27 +54,6 @@ impl TreePath { common_ancestor } - - pub fn diff_ancestor(&self, ancestor: &Self) -> Self { - let mut self_iter = self.path.iter(); - let ancestor_iter = ancestor.path.iter(); - - for ancestor in ancestor_iter { - if let Some(self_path) = self_iter.next() { - if self_path == ancestor { - continue; - } else { - unreachable!("Other path is not a common ancestor self path.") - } - } else { - unreachable!("Other path is longer than self path.") - } - } - - TreePath { - path: self_iter.cloned().collect_vec(), - } - } } impl Default for TreePath { @@ -133,6 +112,13 @@ pub enum AirStatement { variant_name: String, func_body: Box, }, + DefineCyclicFuncs { + func_name: String, + module_name: String, + variant_name: String, + // params and body + contained_functions: Vec<(Vec, AirTree)>, + }, // Assertions AssertConstr { constr_index: usize, @@ -892,6 +878,26 @@ impl AirTree { }); func_body.create_air_vec(air_vec); } + AirStatement::DefineCyclicFuncs { + func_name, + module_name, + variant_name, + contained_functions, + } => { + air_vec.push(Air::DefineCyclicFuncs { + func_name: func_name.clone(), + module_name: module_name.clone(), + variant_name: variant_name.clone(), + contained_functions: contained_functions + .into_iter() + .map(|(params, _)| params.clone()) + .collect_vec(), + }); + + for (_, func_body) in contained_functions { + func_body.create_air_vec(air_vec); + } + } AirStatement::AssertConstr { constr, constr_index, @@ -1413,6 +1419,21 @@ impl AirTree { apply_with_last, ); } + AirStatement::DefineCyclicFuncs { + contained_functions, + .. + } => { + for (_, func_body) in contained_functions { + func_body.do_traverse_tree_with( + tree_path, + current_depth + 1, + index_count.next_number(), + with, + apply_with_last, + ); + } + } + AirStatement::AssertConstr { constr, .. } => { constr.do_traverse_tree_with( tree_path, @@ -1852,9 +1873,9 @@ impl AirTree { &'a mut self, tree_path_iter: &mut Iter<(usize, usize)>, ) -> &'a mut AirTree { - // For Finding the air node we skip over the define func ops since those are added later on. + // For finding the air node we skip over the define func ops since those are added later on. if let AirTree::Statement { - statement: AirStatement::DefineFunc { .. }, + statement: AirStatement::DefineFunc { .. } | AirStatement::DefineCyclicFuncs { .. }, hoisted_over: Some(hoisted_over), } = self { @@ -1958,6 +1979,7 @@ impl AirTree { } } AirStatement::DefineFunc { .. } => unreachable!(), + AirStatement::DefineCyclicFuncs { .. } => unreachable!(), AirStatement::FieldsEmpty { constr } => { if *index == 0 { constr.as_mut().do_find_air_tree_node(tree_path_iter)