feat: finished up mutual recursion

Now we "handle" vars that call the cyclic function.
That includes vars in the cyclic function as well as in other functions
"handle" meaning we modify the var to be a call that takes in more arguments.
This commit is contained in:
microproofs 2023-09-22 17:22:19 -04:00 committed by Kasey
parent ae3053522e
commit f4310bcf33
3 changed files with 358 additions and 157 deletions

View File

@ -41,8 +41,9 @@ use self::{
air::Air,
builder::{
cast_validator_args, constants_ir, convert_type_to_data, extract_constant,
lookup_data_type_by_tipo, modify_self_calls, rearrange_list_clauses, AssignmentProperties,
ClauseProperties, DataTypeKey, FunctionAccessKey, HoistableFunction, Variant,
lookup_data_type_by_tipo, modify_cyclic_calls, modify_self_calls, rearrange_list_clauses,
AssignmentProperties, ClauseProperties, DataTypeKey, FunctionAccessKey, HoistableFunction,
Variant,
},
tree::{AirExpression, AirTree, TreePath},
};
@ -56,7 +57,8 @@ pub struct CodeGenerator<'a> {
needs_field_access: bool,
code_gen_functions: IndexMap<String, CodeGenFunction>,
zero_arg_functions: IndexMap<(FunctionAccessKey, Variant), Vec<Air>>,
cyclic_functions: IndexMap<(FunctionAccessKey, Variant), (usize, FunctionAccessKey)>,
cyclic_functions:
IndexMap<(FunctionAccessKey, Variant), (Vec<String>, usize, FunctionAccessKey)>,
tracing: bool,
id_gen: IdGenerator,
}
@ -2697,12 +2699,28 @@ impl<'a> CodeGenerator<'a> {
let mut cycle_of_functions = vec![];
let mut cycle_deps = vec![];
let function_list = cyclic_function_names
.iter()
.map(|(key, variant)| {
format!(
"{}{}{}",
key.module_name,
key.function_name,
if variant.is_empty() {
"".to_string()
} else {
format!("_{}", variant)
}
)
})
.collect_vec();
// By doing this any vars that "call" into a function in the cycle will be
// redirected to call the cyclic function instead with the proper index
for (index, (func_name, variant)) in cyclic_function_names.iter().enumerate() {
self.cyclic_functions.insert(
(func_name.clone(), variant.clone()),
(index, function_key.clone()),
(function_list.clone(), index, function_key.clone()),
);
let (tree_path, func) = functions_to_hoist
@ -2753,6 +2771,7 @@ impl<'a> CodeGenerator<'a> {
}
// Rest of code is for hoisting functions
// TODO: replace with graph implementation of sorting
let mut sorted_function_vec: Vec<(FunctionAccessKey, String)> = vec![];
let functions_to_hoist_cloned = functions_to_hoist.clone();
@ -2865,8 +2884,6 @@ impl<'a> CodeGenerator<'a> {
}
sorted_function_vec.dedup();
todo!();
// Now we need to hoist the functions to the top of the validator
for (key, variant) in sorted_function_vec {
if hoisted_functions
@ -2910,18 +2927,11 @@ impl<'a> CodeGenerator<'a> {
hoisted_functions: &mut Vec<(FunctionAccessKey, String)>,
) {
match function {
HoistableFunction::Function { body, deps, params } => todo!(),
HoistableFunction::CyclicFunction { functions, deps } => todo!(),
HoistableFunction::Link(_) => todo!(),
HoistableFunction::CyclicLink(_) => todo!(),
}
if let HoistableFunction::Function {
HoistableFunction::Function {
body,
deps: func_deps,
params,
} = function
{
} => {
let mut body = body.clone();
let (key, variant) = key_var;
@ -2984,8 +2994,51 @@ impl<'a> CodeGenerator<'a> {
self.zero_arg_functions
.insert((key.clone(), variant.clone()), body.to_vec());
}
} else {
todo!()
}
HoistableFunction::CyclicFunction {
functions,
deps: func_deps,
} => {
let (key, variant) = key_var;
let deps = (tree_path, func_deps.clone());
let mut functions = functions.clone();
for (_, body) in functions.iter_mut() {
modify_cyclic_calls(body, key, &self.cyclic_functions);
}
let cyclic_body = AirTree::define_cyclic_func(
&key.function_name,
&key.module_name,
variant,
functions,
);
let function_deps = self.hoist_dependent_functions(
deps,
// cyclic functions always have params
false,
key,
variant,
hoisted_functions,
functions_to_hoist,
);
let node_to_edit = air_tree.find_air_tree_node(tree_path);
// now hoist full function onto validator tree
*node_to_edit =
function_deps.hoist_over(cyclic_body.hoist_over(node_to_edit.clone()));
hoisted_functions.push((key.clone(), variant.clone()));
}
HoistableFunction::Link(_) => {
todo!("This should probably be unreachable when I get to it")
}
HoistableFunction::CyclicLink(_) => {
unreachable!("Sorted functions should not contain cyclic links")
}
}
}
@ -3017,7 +3070,8 @@ impl<'a> CodeGenerator<'a> {
.get(&dep.1)
.unwrap_or_else(|| panic!("Missing Function Variant Definition"));
if let HoistableFunction::Function { deps, params, .. } = function {
match function {
HoistableFunction::Function { deps, params, .. } => {
if !params.is_empty() {
for (dep_generic_func, dep_variant) in deps.iter() {
if !(dep_generic_func == &dep.0 && dep_variant == &dep.1) {
@ -3029,9 +3083,41 @@ impl<'a> CodeGenerator<'a> {
}
}
}
} else {
todo!("Deal with Link later")
}
HoistableFunction::CyclicFunction { deps, .. } => {
for (dep_generic_func, dep_variant) in deps.iter() {
if !(dep_generic_func == &dep.0 && dep_variant == &dep.1) {
sorted_dep_vec.retain(|(generic_func, variant)| {
!(generic_func == dep_generic_func && variant == dep_variant)
});
deps_vec.insert(0, (dep_generic_func.clone(), dep_variant.clone()));
}
}
}
HoistableFunction::Link(_) => todo!("Deal with Link later"),
HoistableFunction::CyclicLink(cyclic_func) => {
let (_, HoistableFunction::CyclicFunction { deps, .. }) = functions_to_hoist
.get(cyclic_func)
.unwrap()
.get("")
.unwrap()
else {
unreachable!()
};
for (dep_generic_func, dep_variant) in deps.iter() {
if !(dep_generic_func == &dep.0 && dep_variant == &dep.1) {
sorted_dep_vec.retain(|(generic_func, variant)| {
!(generic_func == dep_generic_func && variant == dep_variant)
});
deps_vec.insert(0, (dep_generic_func.clone(), dep_variant.clone()));
}
}
}
}
sorted_dep_vec.push((dep.0.clone(), dep.1.clone()));
}
@ -3061,15 +3147,12 @@ impl<'a> CodeGenerator<'a> {
// In the case of zero args, we need to hoist the dependency function to the top of the zero arg function
if &dep_path.common_ancestor(func_path) == func_path || params_empty {
let HoistableFunction::Function {
match dep_function.clone() {
HoistableFunction::Function {
body: mut dep_air_tree,
deps: dependency_deps,
params: dependent_params,
} = dep_function.clone()
else {
unreachable!()
};
} => {
if dependent_params.is_empty() {
// continue for zero arg functions. They are treated like global hoists.
continue;
@ -3080,7 +3163,12 @@ impl<'a> CodeGenerator<'a> {
.any(|(key, variant)| &dep_key == key && &dep_variant == variant);
let recursive_nonstatics = if is_dependent_recursive {
modify_self_calls(&mut dep_air_tree, &dep_key, &dep_variant, &dependent_params)
modify_self_calls(
&mut dep_air_tree,
&dep_key,
&dep_variant,
&dependent_params,
)
} else {
dependent_params.clone()
};
@ -3099,6 +3187,28 @@ impl<'a> CodeGenerator<'a> {
hoisted_functions.push((dep_key.clone(), dep_variant.clone()));
}
}
HoistableFunction::CyclicFunction { functions, .. } => {
let mut functions = functions.clone();
for (_, body) in functions.iter_mut() {
modify_cyclic_calls(body, key, &self.cyclic_functions);
}
dep_insertions.push(AirTree::define_cyclic_func(
&dep_key.function_name,
&dep_key.module_name,
&dep_variant,
functions,
));
if !params_empty {
hoisted_functions.push((dep_key.clone(), dep_variant.clone()));
}
}
HoistableFunction::Link(_) => unreachable!(),
HoistableFunction::CyclicLink(_) => unreachable!(),
}
}
}
dep_insertions.reverse();
@ -3441,6 +3551,31 @@ impl<'a> CodeGenerator<'a> {
module,
..
} => {
if let Some((names, index, cyclic_name)) = self.cyclic_functions.get(&(
FunctionAccessKey {
module_name: module.clone(),
function_name: func_name.clone(),
},
variant_name.clone(),
)) {
let cyclic_var_name = if cyclic_name.module_name.is_empty() {
cyclic_name.function_name.to_string()
} else {
format!("{}_{}", cyclic_name.module_name, cyclic_name.function_name)
};
let index_name = names[*index].clone();
let mut arg_var = Term::var(index_name.clone());
for name in names.iter().rev() {
arg_var = arg_var.lambda(name);
}
let term = Term::var(cyclic_var_name).apply(arg_var);
arg_stack.push(term);
} else {
let name = if (*func_name == name
|| name == format!("{module}_{func_name}"))
&& !module.is_empty()
@ -3458,6 +3593,7 @@ impl<'a> CodeGenerator<'a> {
.into(),
));
}
}
ValueConstructorVariant::Record {
name: constr_name, ..
} => {

View File

@ -751,6 +751,55 @@ pub fn modify_self_calls(
recursive_nonstatics
}
pub fn modify_cyclic_calls(
body: &mut AirTree,
func_key: &FunctionAccessKey,
cyclic_links: &IndexMap<(FunctionAccessKey, Variant), (Vec<String>, usize, FunctionAccessKey)>,
) {
body.traverse_tree_with(
&mut |air_tree: &mut AirTree, _| {
if let AirTree::Expression(AirExpression::Var {
constructor:
ValueConstructor {
variant: ValueConstructorVariant::ModuleFn { name, module, .. },
tipo,
..
},
variant_name,
..
}) = air_tree
{
let tipo = tipo.clone();
let var_key = FunctionAccessKey {
module_name: module.clone(),
function_name: name.clone(),
};
if let Some((names, index, cyclic_name)) =
cyclic_links.get(&(var_key.clone(), variant_name.to_string()))
{
if *cyclic_name == *func_key {
let index_name = names[*index].clone();
*air_tree = AirTree::call(
air_tree.clone(),
tipo.clone(),
vec![
air_tree.clone(),
AirTree::anon_func(
names.clone(),
AirTree::local_var(index_name, tipo),
),
],
);
}
}
}
},
true,
);
}
pub fn pattern_has_conditions(
pattern: &TypedPattern,
data_types: &IndexMap<DataTypeKey, &TypedDataType>,

View File

@ -430,6 +430,22 @@ impl AirTree {
hoisted_over: None,
}
}
pub fn define_cyclic_func(
func_name: impl ToString,
module_name: impl ToString,
variant_name: impl ToString,
contained_functions: Vec<(Vec<String>, AirTree)>,
) -> AirTree {
AirTree::Statement {
statement: AirStatement::DefineCyclicFuncs {
func_name: func_name.to_string(),
module_name: module_name.to_string(),
variant_name: variant_name.to_string(),
contained_functions,
},
hoisted_over: None,
}
}
pub fn anon_func(params: Vec<String>, func_body: AirTree) -> AirTree {
AirTree::Expression(AirExpression::Fn {
params,
@ -1375,10 +1391,10 @@ impl AirTree {
pub fn traverse_tree_with(
&mut self,
with: &mut impl FnMut(&mut AirTree, &TreePath),
apply_with_last: bool,
apply_with_func_last: bool,
) {
let mut tree_path = TreePath::new();
self.do_traverse_tree_with(&mut tree_path, 0, 0, with, apply_with_last);
self.do_traverse_tree_with(&mut tree_path, 0, 0, with, apply_with_func_last);
}
pub fn traverse_tree_with_path(
@ -1387,9 +1403,9 @@ impl AirTree {
current_depth: usize,
depth_index: usize,
with: &mut impl FnMut(&mut AirTree, &TreePath),
apply_with_last: bool,
apply_with_func_last: bool,
) {
self.do_traverse_tree_with(path, current_depth, depth_index, with, apply_with_last);
self.do_traverse_tree_with(path, current_depth, depth_index, with, apply_with_func_last);
}
fn do_traverse_tree_with(
@ -1398,7 +1414,7 @@ impl AirTree {
current_depth: usize,
depth_index: usize,
with: &mut impl FnMut(&mut AirTree, &TreePath),
apply_with_last: bool,
apply_with_func_last: bool,
) {
let mut index_count = IndexCounter::new();
tree_path.push(current_depth, depth_index);
@ -1411,7 +1427,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirStatement::DefineFunc { func_body, .. } => {
@ -1420,7 +1436,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirStatement::DefineCyclicFuncs {
@ -1433,7 +1449,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
}
@ -1444,7 +1460,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirStatement::AssertBool { value, .. } => {
@ -1453,7 +1469,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirStatement::ClauseGuard { pattern, .. } => {
@ -1462,7 +1478,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirStatement::ListClauseGuard { .. } => {}
@ -1473,7 +1489,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirStatement::ListAccessor { list, .. } => {
@ -1482,7 +1498,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirStatement::ListExpose { .. } => {}
@ -1492,7 +1508,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirStatement::NoOp => {}
@ -1502,7 +1518,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirStatement::ListEmpty { list } => {
@ -1511,13 +1527,13 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
};
}
if !apply_with_last {
if !apply_with_func_last {
with(self, tree_path);
}
@ -1531,7 +1547,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirTree::Expression(e) => match e {
@ -1542,7 +1558,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
}
@ -1553,7 +1569,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
}
@ -1563,7 +1579,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
for arg in args {
@ -1572,7 +1588,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
}
@ -1582,7 +1598,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirExpression::Builtin { args, .. } => {
@ -1592,7 +1608,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
}
@ -1602,7 +1618,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
right.do_traverse_tree_with(
@ -1610,7 +1626,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirExpression::UnOp { arg, .. } => {
@ -1619,7 +1635,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirExpression::CastFromData { value, .. } => {
@ -1628,7 +1644,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirExpression::CastToData { value, .. } => {
@ -1637,7 +1653,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirExpression::When {
@ -1648,7 +1664,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
clauses.do_traverse_tree_with(
@ -1656,7 +1672,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirExpression::Clause {
@ -1670,7 +1686,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
then.do_traverse_tree_with(
@ -1678,7 +1694,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
otherwise.do_traverse_tree_with(
@ -1686,7 +1702,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirExpression::ListClause {
@ -1697,7 +1713,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
otherwise.do_traverse_tree_with(
@ -1705,7 +1721,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirExpression::WrapClause { then, otherwise } => {
@ -1714,7 +1730,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
otherwise.do_traverse_tree_with(
@ -1722,7 +1738,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirExpression::TupleClause {
@ -1733,7 +1749,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
otherwise.do_traverse_tree_with(
@ -1741,7 +1757,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirExpression::Finally { pattern, then } => {
@ -1750,7 +1766,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
then.do_traverse_tree_with(
@ -1758,7 +1774,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirExpression::If {
@ -1772,7 +1788,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
then.do_traverse_tree_with(
@ -1780,7 +1796,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
otherwise.do_traverse_tree_with(
@ -1788,7 +1804,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirExpression::Constr { args, .. } => {
@ -1798,7 +1814,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
}
@ -1808,7 +1824,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
for arg in args {
arg.do_traverse_tree_with(
@ -1816,7 +1832,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
}
@ -1826,7 +1842,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirExpression::TupleIndex { tuple, .. } => {
@ -1835,7 +1851,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
AirExpression::Trace { msg, then, .. } => {
@ -1844,7 +1860,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
then.do_traverse_tree_with(
@ -1852,7 +1868,7 @@ impl AirTree {
current_depth + 1,
index_count.next_number(),
with,
apply_with_last,
apply_with_func_last,
);
}
_ => {}
@ -1860,7 +1876,7 @@ impl AirTree {
a => unreachable!("GOT THIS {:#?}", a),
}
if apply_with_last {
if apply_with_func_last {
with(self, tree_path);
}