WIP: first part of mutual recursion is done.

This involves creating the function definition and detecting cycles.
The remaining part is to "fix" the call sites
of the mutually recursive functions
This commit is contained in:
microproofs 2023-09-03 16:29:27 -04:00 committed by Kasey
parent ecc5769c64
commit a4aa51ed2d
5 changed files with 152 additions and 23 deletions

1
Cargo.lock generated vendored
View File

@ -93,6 +93,7 @@ dependencies = [
"num-bigint", "num-bigint",
"ordinal", "ordinal",
"owo-colors", "owo-colors",
"petgraph",
"pretty_assertions", "pretty_assertions",
"strum", "strum",
"thiserror", "thiserror",

View File

@ -26,6 +26,7 @@ thiserror = "1.0.39"
vec1 = "1.10.1" vec1 = "1.10.1"
uplc = { path = '../uplc', version = "1.0.17-alpha" } uplc = { path = '../uplc', version = "1.0.17-alpha" }
num-bigint = "0.4.3" num-bigint = "0.4.3"
petgraph = "0.6.3"
[target.'cfg(not(target_family="wasm"))'.dependencies] [target.'cfg(not(target_family="wasm"))'.dependencies]
chumsky = "0.9.2" chumsky = "0.9.2"

View File

@ -2,6 +2,8 @@ pub mod air;
pub mod builder; pub mod builder;
pub mod tree; pub mod tree;
use petgraph::{algo, Graph};
use std::collections::HashMap;
use std::rc::Rc; use std::rc::Rc;
use indexmap::{IndexMap, IndexSet}; use indexmap::{IndexMap, IndexSet};
@ -2626,6 +2628,63 @@ impl<'a> CodeGenerator<'a> {
// First we need to sort functions by dependencies // First we need to sort functions by dependencies
// here's also where we deal with mutual recursion // here's also where we deal with mutual recursion
// Mutual Recursion
let inputs = functions_to_hoist
.iter()
.flat_map(|(function_name, val)| {
val.into_iter()
.map(|(variant, (_, function))| {
if let UserFunction::Function { deps, .. } = function {
((function_name.clone(), variant.clone()), deps)
} else {
todo!("Deal with Link later")
}
})
.collect_vec()
})
.collect_vec();
let capacity = inputs.len();
let mut graph = Graph::<(), ()>::with_capacity(capacity, capacity * 5);
let mut indices = HashMap::with_capacity(capacity);
let mut values = HashMap::with_capacity(capacity);
for (value, _) in &inputs {
let index = graph.add_node(());
indices.insert(value.clone(), index);
values.insert(index, value.clone());
}
for (value, deps) in inputs {
if let Some(from_index) = indices.get(&value) {
let deps = deps.iter().filter_map(|dep| indices.get(dep));
for to_index in deps {
graph.add_edge(*from_index, *to_index, ());
}
}
}
let strong_connections = algo::tarjan_scc(&graph);
for connections in strong_connections {
// If there's only one function, then it's only self recursive
if connections.len() < 2 {
continue;
}
let function_names = connections
.iter()
.map(|index| values.get(index).unwrap())
.collect_vec();
}
// Rest of code is for hoisting functions
let mut sorted_function_vec = vec![]; let mut sorted_function_vec = vec![];
let functions_to_hoist_cloned = functions_to_hoist.clone(); let functions_to_hoist_cloned = functions_to_hoist.clone();
@ -3840,6 +3899,45 @@ impl<'a> CodeGenerator<'a> {
arg_stack.push(term); arg_stack.push(term);
} }
} }
Air::DefineCyclicFuncs {
func_name,
module_name,
variant_name,
contained_functions,
} => {
let func_name = if module_name.is_empty() {
format!("{func_name}{variant_name}")
} else {
format!("{module_name}_{func_name}{variant_name}")
};
let mut cyclic_functions = vec![];
for params in contained_functions {
let func_body = arg_stack.pop().unwrap();
cyclic_functions.push((params, func_body));
}
let mut term = arg_stack.pop().unwrap();
let mut cyclic_body = Term::var("__chooser");
for (params, func_body) in cyclic_functions.into_iter() {
let mut function = func_body;
for param in params.iter().rev() {
function = function.lambda(param);
}
cyclic_body = cyclic_body.apply(function)
}
term = term
.lambda(&func_name)
.apply(Term::var(&func_name).apply(Term::var(&func_name)))
.lambda(&func_name)
.apply(cyclic_body.lambda("__chooser").lambda(func_name));
arg_stack.push(term);
}
Air::Let { name } => { Air::Let { name } => {
let arg = arg_stack.pop().unwrap(); let arg = arg_stack.pop().unwrap();

View File

@ -50,6 +50,13 @@ pub enum Air {
recursive_nonstatic_params: Vec<String>, recursive_nonstatic_params: Vec<String>,
variant_name: String, variant_name: String,
}, },
DefineCyclicFuncs {
func_name: String,
module_name: String,
variant_name: String,
// just the params
contained_functions: Vec<Vec<String>>,
},
Fn { Fn {
params: Vec<String>, params: Vec<String>,
}, },

View File

@ -54,27 +54,6 @@ impl TreePath {
common_ancestor common_ancestor
} }
pub fn diff_ancestor(&self, ancestor: &Self) -> Self {
let mut self_iter = self.path.iter();
let ancestor_iter = ancestor.path.iter();
for ancestor in ancestor_iter {
if let Some(self_path) = self_iter.next() {
if self_path == ancestor {
continue;
} else {
unreachable!("Other path is not a common ancestor self path.")
}
} else {
unreachable!("Other path is longer than self path.")
}
}
TreePath {
path: self_iter.cloned().collect_vec(),
}
}
} }
impl Default for TreePath { impl Default for TreePath {
@ -133,6 +112,13 @@ pub enum AirStatement {
variant_name: String, variant_name: String,
func_body: Box<AirTree>, func_body: Box<AirTree>,
}, },
DefineCyclicFuncs {
func_name: String,
module_name: String,
variant_name: String,
// params and body
contained_functions: Vec<(Vec<String>, AirTree)>,
},
// Assertions // Assertions
AssertConstr { AssertConstr {
constr_index: usize, constr_index: usize,
@ -892,6 +878,26 @@ impl AirTree {
}); });
func_body.create_air_vec(air_vec); func_body.create_air_vec(air_vec);
} }
AirStatement::DefineCyclicFuncs {
func_name,
module_name,
variant_name,
contained_functions,
} => {
air_vec.push(Air::DefineCyclicFuncs {
func_name: func_name.clone(),
module_name: module_name.clone(),
variant_name: variant_name.clone(),
contained_functions: contained_functions
.into_iter()
.map(|(params, _)| params.clone())
.collect_vec(),
});
for (_, func_body) in contained_functions {
func_body.create_air_vec(air_vec);
}
}
AirStatement::AssertConstr { AirStatement::AssertConstr {
constr, constr,
constr_index, constr_index,
@ -1413,6 +1419,21 @@ impl AirTree {
apply_with_last, apply_with_last,
); );
} }
AirStatement::DefineCyclicFuncs {
contained_functions,
..
} => {
for (_, func_body) in contained_functions {
func_body.do_traverse_tree_with(
tree_path,
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
);
}
}
AirStatement::AssertConstr { constr, .. } => { AirStatement::AssertConstr { constr, .. } => {
constr.do_traverse_tree_with( constr.do_traverse_tree_with(
tree_path, tree_path,
@ -1852,9 +1873,9 @@ impl AirTree {
&'a mut self, &'a mut self,
tree_path_iter: &mut Iter<(usize, usize)>, tree_path_iter: &mut Iter<(usize, usize)>,
) -> &'a mut AirTree { ) -> &'a mut AirTree {
// For Finding the air node we skip over the define func ops since those are added later on. // For finding the air node we skip over the define func ops since those are added later on.
if let AirTree::Statement { if let AirTree::Statement {
statement: AirStatement::DefineFunc { .. }, statement: AirStatement::DefineFunc { .. } | AirStatement::DefineCyclicFuncs { .. },
hoisted_over: Some(hoisted_over), hoisted_over: Some(hoisted_over),
} = self } = self
{ {
@ -1958,6 +1979,7 @@ impl AirTree {
} }
} }
AirStatement::DefineFunc { .. } => unreachable!(), AirStatement::DefineFunc { .. } => unreachable!(),
AirStatement::DefineCyclicFuncs { .. } => unreachable!(),
AirStatement::FieldsEmpty { constr } => { AirStatement::FieldsEmpty { constr } => {
if *index == 0 { if *index == 0 {
constr.as_mut().do_find_air_tree_node(tree_path_iter) constr.as_mut().do_find_air_tree_node(tree_path_iter)