Refactor some code to make a define_const AirTree function

This commit is contained in:
microproofs
2024-09-20 00:39:27 -04:00
parent 513ca27717
commit 2bbc699a25
3 changed files with 190 additions and 129 deletions

View File

@@ -23,6 +23,17 @@ impl From<bool> for ExpectLevel {
}
}
#[derive(Debug, Clone, PartialEq)]
pub enum FunctionVariants {
Standard(Vec<String>),
Recursive {
params: Vec<String>,
recursive_nonstatic_params: Vec<String>,
},
Cyclic(Vec<Vec<String>>),
Constant,
}
#[derive(Debug, Clone, PartialEq)]
pub enum Air {
// Primitives
@@ -67,17 +78,8 @@ pub enum Air {
DefineFunc {
func_name: String,
module_name: String,
params: Vec<String>,
recursive: bool,
recursive_nonstatic_params: Vec<String>,
variant_name: String,
},
DefineCyclicFuncs {
func_name: String,
module_name: String,
variant_name: String,
// just the params
contained_functions: Vec<Vec<String>>,
variant: FunctionVariants,
},
Fn {
params: Vec<String>,

View File

@@ -1,4 +1,4 @@
use super::air::{Air, ExpectLevel};
use super::air::{Air, ExpectLevel, FunctionVariants};
use crate::{
ast::{BinOp, Curve, Span, UnOp},
tipo::{Type, ValueConstructor, ValueConstructorVariant},
@@ -21,6 +21,7 @@ pub enum Fields {
SixthField,
SeventhField,
EighthField,
NinthField,
ArgsField(usize),
}
@@ -136,10 +137,12 @@ pub enum AirTree {
DefineFunc {
func_name: String,
module_name: String,
variant_name: String,
//params and other parts of a function
params: Vec<String>,
recursive: bool,
recursive_nonstatic_params: Vec<String>,
variant_name: String,
constant: bool,
func_body: Box<AirTree>,
then: Box<AirTree>,
},
@@ -531,12 +534,33 @@ impl AirTree {
params,
recursive,
recursive_nonstatic_params,
constant: false,
variant_name: variant_name.to_string(),
func_body: func_body.into(),
then: then.into(),
}
}
#[allow(clippy::too_many_arguments)]
pub fn define_const(
func_name: impl ToString,
module_name: impl ToString,
func_body: AirTree,
then: AirTree,
) -> AirTree {
AirTree::DefineFunc {
func_name: func_name.to_string(),
module_name: module_name.to_string(),
variant_name: "".to_string(),
params: vec![],
recursive: false,
recursive_nonstatic_params: vec![],
constant: true,
func_body: func_body.into(),
then: then.into(),
}
}
pub fn define_cyclic_func(
func_name: impl ToString,
module_name: impl ToString,
@@ -1158,17 +1182,33 @@ impl AirTree {
params,
recursive,
recursive_nonstatic_params,
constant,
variant_name,
func_body,
then,
} => {
let variant = if *constant {
assert!(!recursive);
assert!(params.is_empty());
assert!(recursive_nonstatic_params.is_empty());
FunctionVariants::Constant
} else if *recursive {
FunctionVariants::Recursive {
params: params.clone(),
recursive_nonstatic_params: recursive_nonstatic_params.clone(),
}
} else {
assert_eq!(params, recursive_nonstatic_params);
FunctionVariants::Standard(params.clone())
};
air_vec.push(Air::DefineFunc {
func_name: func_name.clone(),
module_name: module_name.clone(),
params: params.clone(),
recursive: *recursive,
recursive_nonstatic_params: recursive_nonstatic_params.clone(),
variant_name: variant_name.clone(),
variant,
});
func_body.create_air_vec(air_vec);
then.create_air_vec(air_vec);
@@ -1180,14 +1220,18 @@ impl AirTree {
contained_functions,
then,
} => {
air_vec.push(Air::DefineCyclicFuncs {
func_name: func_name.clone(),
module_name: module_name.clone(),
variant_name: variant_name.clone(),
contained_functions: contained_functions
let variant = FunctionVariants::Cyclic(
contained_functions
.iter()
.map(|(params, _)| params.clone())
.collect_vec(),
);
air_vec.push(Air::DefineFunc {
func_name: func_name.clone(),
module_name: module_name.clone(),
variant_name: variant_name.clone(),
variant,
});
for (_, func_body) in contained_functions {
@@ -1834,8 +1878,8 @@ impl AirTree {
) {
tree_path.push(current_depth, field_index);
// Assignments'/Statements' values get traversed here
// Then the body under these assignments/statements get traversed later on
// TODO: Merge together the 2 match statements
match self {
AirTree::Let {
name: _,
@@ -2104,7 +2148,6 @@ impl AirTree {
| AirTree::MultiValidator { .. } => {}
}
// Expressions or an assignment that hoist over a expression are traversed here
match self {
AirTree::NoOp { then } => {
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::FirstField, with);
@@ -2389,16 +2432,17 @@ impl AirTree {
recursive: _,
recursive_nonstatic_params: _,
variant_name: _,
constant: _,
func_body,
then,
} => {
func_body.do_traverse_tree_with(
tree_path,
current_depth + 1,
Fields::SeventhField,
Fields::EighthField,
with,
);
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::EighthField, with)
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::NinthField, with)
}
AirTree::DefineCyclicFuncs {
func_name: _,