Refactor some code to make a define_const AirTree function
This commit is contained in:
parent
513ca27717
commit
2bbc699a25
|
@ -4988,90 +4988,102 @@ impl<'a> CodeGenerator<'a> {
|
||||||
}
|
}
|
||||||
Air::DefineFunc {
|
Air::DefineFunc {
|
||||||
func_name,
|
func_name,
|
||||||
params,
|
|
||||||
recursive,
|
|
||||||
recursive_nonstatic_params,
|
|
||||||
module_name,
|
module_name,
|
||||||
variant_name,
|
variant_name,
|
||||||
|
variant,
|
||||||
} => {
|
} => {
|
||||||
let func_name = if module_name.is_empty() {
|
let func_name = if module_name.is_empty() {
|
||||||
format!("{func_name}{variant_name}")
|
format!("{func_name}{variant_name}")
|
||||||
} else {
|
} else {
|
||||||
format!("{module_name}_{func_name}{variant_name}")
|
format!("{module_name}_{func_name}{variant_name}")
|
||||||
};
|
};
|
||||||
|
|
||||||
|
match variant {
|
||||||
|
air::FunctionVariants::Standard(params) => {
|
||||||
let mut func_body = arg_stack.pop().unwrap();
|
let mut func_body = arg_stack.pop().unwrap();
|
||||||
|
|
||||||
let mut term = arg_stack.pop().unwrap();
|
let term = arg_stack.pop().unwrap();
|
||||||
|
|
||||||
// Introduce a parameter for each parameter
|
if params.is_empty() {
|
||||||
// NOTE: we use recursive_nonstatic_params here because
|
func_body = func_body.delay();
|
||||||
// if this is recursive, those are the ones that need to be passed
|
|
||||||
// each time
|
|
||||||
for param in recursive_nonstatic_params.iter().rev() {
|
|
||||||
func_body = func_body.lambda(param.clone());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let func_body = params
|
||||||
|
.into_iter()
|
||||||
|
.fold(func_body, |term, arg| term.lambda(arg))
|
||||||
|
.lambda(NO_INLINE);
|
||||||
|
|
||||||
|
Some(term.lambda(func_name).apply(func_body))
|
||||||
|
}
|
||||||
|
air::FunctionVariants::Recursive {
|
||||||
|
params,
|
||||||
|
recursive_nonstatic_params,
|
||||||
|
} => {
|
||||||
|
let mut func_body = arg_stack.pop().unwrap();
|
||||||
|
|
||||||
|
let term = arg_stack.pop().unwrap();
|
||||||
|
|
||||||
|
let no_statics = recursive_nonstatic_params == params;
|
||||||
|
|
||||||
if recursive_nonstatic_params.is_empty() || params.is_empty() {
|
if recursive_nonstatic_params.is_empty() || params.is_empty() {
|
||||||
func_body = func_body.delay();
|
func_body = func_body.delay();
|
||||||
}
|
}
|
||||||
|
|
||||||
if !recursive {
|
let func_body = recursive_nonstatic_params
|
||||||
term = term.lambda(func_name).apply(func_body.lambda(NO_INLINE));
|
.iter()
|
||||||
|
.rfold(func_body, |term, arg| term.lambda(arg));
|
||||||
|
|
||||||
Some(term)
|
let func_body = func_body.lambda(func_name.clone());
|
||||||
} else {
|
|
||||||
func_body = func_body.lambda(func_name.clone());
|
|
||||||
|
|
||||||
if recursive_nonstatic_params == params {
|
if no_statics {
|
||||||
// If we don't have any recursive-static params, we can just emit the function as is
|
// If we don't have any recursive-static params, we can just emit the function as is
|
||||||
term = term
|
Some(
|
||||||
.lambda(func_name.clone())
|
term.lambda(func_name.clone())
|
||||||
.apply(Term::var(func_name.clone()).apply(Term::var(func_name.clone())))
|
.apply(
|
||||||
|
Term::var(func_name.clone())
|
||||||
|
.apply(Term::var(func_name.clone())),
|
||||||
|
)
|
||||||
.lambda(func_name)
|
.lambda(func_name)
|
||||||
.apply(func_body.lambda(NO_INLINE));
|
.apply(func_body.lambda(NO_INLINE)),
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
// If we have parameters that remain static in each recursive call,
|
// If we have parameters that remain static in each recursive call,
|
||||||
// we can construct an *outer* function to take those in
|
// we can construct an *outer* function to take those in
|
||||||
// and simplify the recursive part to only accept the non-static arguments
|
// and simplify the recursive part to only accept the non-static arguments
|
||||||
let mut recursive_func_body =
|
let mut recursive_func_body =
|
||||||
Term::var(&func_name).apply(Term::var(&func_name));
|
Term::var(&func_name).apply(Term::var(&func_name));
|
||||||
for param in recursive_nonstatic_params.iter() {
|
|
||||||
recursive_func_body = recursive_func_body.apply(Term::var(param));
|
|
||||||
}
|
|
||||||
|
|
||||||
if recursive_nonstatic_params.is_empty() {
|
if recursive_nonstatic_params.is_empty() {
|
||||||
recursive_func_body = recursive_func_body.force();
|
recursive_func_body = recursive_func_body.force();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Introduce a parameter for each parameter
|
||||||
|
// NOTE: we use recursive_nonstatic_params here because
|
||||||
|
// if this is recursive, those are the ones that need to be passed
|
||||||
|
// each time
|
||||||
|
for param in recursive_nonstatic_params.into_iter() {
|
||||||
|
recursive_func_body = recursive_func_body.apply(Term::var(param));
|
||||||
|
}
|
||||||
|
|
||||||
// Then construct an outer function with *all* parameters, not just the nonstatic ones.
|
// Then construct an outer function with *all* parameters, not just the nonstatic ones.
|
||||||
let mut outer_func_body =
|
let mut outer_func_body =
|
||||||
recursive_func_body.lambda(&func_name).apply(func_body);
|
recursive_func_body.lambda(&func_name).apply(func_body);
|
||||||
|
|
||||||
// Now, add *all* parameters, so that other call sites don't know the difference
|
// Now, add *all* parameters, so that other call sites don't know the difference
|
||||||
for param in params.iter().rev() {
|
outer_func_body = params
|
||||||
outer_func_body = outer_func_body.lambda(param);
|
.clone()
|
||||||
}
|
.into_iter()
|
||||||
|
.rfold(outer_func_body, |term, arg| term.lambda(arg));
|
||||||
|
|
||||||
// And finally, fold that definition into the rest of our program
|
// And finally, fold that definition into the rest of our program
|
||||||
term = term
|
Some(
|
||||||
.lambda(&func_name)
|
term.lambda(&func_name)
|
||||||
.apply(outer_func_body.lambda(NO_INLINE));
|
.apply(outer_func_body.lambda(NO_INLINE)),
|
||||||
}
|
)
|
||||||
|
|
||||||
Some(term)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Air::DefineCyclicFuncs {
|
air::FunctionVariants::Constant => todo!(),
|
||||||
func_name,
|
air::FunctionVariants::Cyclic(contained_functions) => {
|
||||||
module_name,
|
|
||||||
variant_name,
|
|
||||||
contained_functions,
|
|
||||||
} => {
|
|
||||||
let func_name = if module_name.is_empty() {
|
|
||||||
format!("{func_name}{variant_name}")
|
|
||||||
} else {
|
|
||||||
format!("{module_name}_{func_name}{variant_name}")
|
|
||||||
};
|
|
||||||
let mut cyclic_functions = vec![];
|
let mut cyclic_functions = vec![];
|
||||||
|
|
||||||
for params in contained_functions {
|
for params in contained_functions {
|
||||||
|
@ -5116,6 +5128,9 @@ impl<'a> CodeGenerator<'a> {
|
||||||
|
|
||||||
Some(term)
|
Some(term)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Air::Let { name } => {
|
Air::Let { name } => {
|
||||||
let arg = arg_stack.pop().unwrap();
|
let arg = arg_stack.pop().unwrap();
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,17 @@ impl From<bool> for ExpectLevel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub enum FunctionVariants {
|
||||||
|
Standard(Vec<String>),
|
||||||
|
Recursive {
|
||||||
|
params: Vec<String>,
|
||||||
|
recursive_nonstatic_params: Vec<String>,
|
||||||
|
},
|
||||||
|
Cyclic(Vec<Vec<String>>),
|
||||||
|
Constant,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub enum Air {
|
pub enum Air {
|
||||||
// Primitives
|
// Primitives
|
||||||
|
@ -67,17 +78,8 @@ pub enum Air {
|
||||||
DefineFunc {
|
DefineFunc {
|
||||||
func_name: String,
|
func_name: String,
|
||||||
module_name: String,
|
module_name: String,
|
||||||
params: Vec<String>,
|
|
||||||
recursive: bool,
|
|
||||||
recursive_nonstatic_params: Vec<String>,
|
|
||||||
variant_name: String,
|
variant_name: String,
|
||||||
},
|
variant: FunctionVariants,
|
||||||
DefineCyclicFuncs {
|
|
||||||
func_name: String,
|
|
||||||
module_name: String,
|
|
||||||
variant_name: String,
|
|
||||||
// just the params
|
|
||||||
contained_functions: Vec<Vec<String>>,
|
|
||||||
},
|
},
|
||||||
Fn {
|
Fn {
|
||||||
params: Vec<String>,
|
params: Vec<String>,
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use super::air::{Air, ExpectLevel};
|
use super::air::{Air, ExpectLevel, FunctionVariants};
|
||||||
use crate::{
|
use crate::{
|
||||||
ast::{BinOp, Curve, Span, UnOp},
|
ast::{BinOp, Curve, Span, UnOp},
|
||||||
tipo::{Type, ValueConstructor, ValueConstructorVariant},
|
tipo::{Type, ValueConstructor, ValueConstructorVariant},
|
||||||
|
@ -21,6 +21,7 @@ pub enum Fields {
|
||||||
SixthField,
|
SixthField,
|
||||||
SeventhField,
|
SeventhField,
|
||||||
EighthField,
|
EighthField,
|
||||||
|
NinthField,
|
||||||
ArgsField(usize),
|
ArgsField(usize),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -136,10 +137,12 @@ pub enum AirTree {
|
||||||
DefineFunc {
|
DefineFunc {
|
||||||
func_name: String,
|
func_name: String,
|
||||||
module_name: String,
|
module_name: String,
|
||||||
|
variant_name: String,
|
||||||
|
//params and other parts of a function
|
||||||
params: Vec<String>,
|
params: Vec<String>,
|
||||||
recursive: bool,
|
recursive: bool,
|
||||||
recursive_nonstatic_params: Vec<String>,
|
recursive_nonstatic_params: Vec<String>,
|
||||||
variant_name: String,
|
constant: bool,
|
||||||
func_body: Box<AirTree>,
|
func_body: Box<AirTree>,
|
||||||
then: Box<AirTree>,
|
then: Box<AirTree>,
|
||||||
},
|
},
|
||||||
|
@ -531,12 +534,33 @@ impl AirTree {
|
||||||
params,
|
params,
|
||||||
recursive,
|
recursive,
|
||||||
recursive_nonstatic_params,
|
recursive_nonstatic_params,
|
||||||
|
constant: false,
|
||||||
variant_name: variant_name.to_string(),
|
variant_name: variant_name.to_string(),
|
||||||
func_body: func_body.into(),
|
func_body: func_body.into(),
|
||||||
then: then.into(),
|
then: then.into(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn define_const(
|
||||||
|
func_name: impl ToString,
|
||||||
|
module_name: impl ToString,
|
||||||
|
func_body: AirTree,
|
||||||
|
then: AirTree,
|
||||||
|
) -> AirTree {
|
||||||
|
AirTree::DefineFunc {
|
||||||
|
func_name: func_name.to_string(),
|
||||||
|
module_name: module_name.to_string(),
|
||||||
|
variant_name: "".to_string(),
|
||||||
|
params: vec![],
|
||||||
|
recursive: false,
|
||||||
|
recursive_nonstatic_params: vec![],
|
||||||
|
constant: true,
|
||||||
|
func_body: func_body.into(),
|
||||||
|
then: then.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn define_cyclic_func(
|
pub fn define_cyclic_func(
|
||||||
func_name: impl ToString,
|
func_name: impl ToString,
|
||||||
module_name: impl ToString,
|
module_name: impl ToString,
|
||||||
|
@ -1158,17 +1182,33 @@ impl AirTree {
|
||||||
params,
|
params,
|
||||||
recursive,
|
recursive,
|
||||||
recursive_nonstatic_params,
|
recursive_nonstatic_params,
|
||||||
|
constant,
|
||||||
variant_name,
|
variant_name,
|
||||||
func_body,
|
func_body,
|
||||||
then,
|
then,
|
||||||
} => {
|
} => {
|
||||||
|
let variant = if *constant {
|
||||||
|
assert!(!recursive);
|
||||||
|
assert!(params.is_empty());
|
||||||
|
assert!(recursive_nonstatic_params.is_empty());
|
||||||
|
|
||||||
|
FunctionVariants::Constant
|
||||||
|
} else if *recursive {
|
||||||
|
FunctionVariants::Recursive {
|
||||||
|
params: params.clone(),
|
||||||
|
recursive_nonstatic_params: recursive_nonstatic_params.clone(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert_eq!(params, recursive_nonstatic_params);
|
||||||
|
FunctionVariants::Standard(params.clone())
|
||||||
|
};
|
||||||
|
|
||||||
air_vec.push(Air::DefineFunc {
|
air_vec.push(Air::DefineFunc {
|
||||||
func_name: func_name.clone(),
|
func_name: func_name.clone(),
|
||||||
module_name: module_name.clone(),
|
module_name: module_name.clone(),
|
||||||
params: params.clone(),
|
|
||||||
recursive: *recursive,
|
|
||||||
recursive_nonstatic_params: recursive_nonstatic_params.clone(),
|
|
||||||
variant_name: variant_name.clone(),
|
variant_name: variant_name.clone(),
|
||||||
|
variant,
|
||||||
});
|
});
|
||||||
func_body.create_air_vec(air_vec);
|
func_body.create_air_vec(air_vec);
|
||||||
then.create_air_vec(air_vec);
|
then.create_air_vec(air_vec);
|
||||||
|
@ -1180,14 +1220,18 @@ impl AirTree {
|
||||||
contained_functions,
|
contained_functions,
|
||||||
then,
|
then,
|
||||||
} => {
|
} => {
|
||||||
air_vec.push(Air::DefineCyclicFuncs {
|
let variant = FunctionVariants::Cyclic(
|
||||||
func_name: func_name.clone(),
|
contained_functions
|
||||||
module_name: module_name.clone(),
|
|
||||||
variant_name: variant_name.clone(),
|
|
||||||
contained_functions: contained_functions
|
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(params, _)| params.clone())
|
.map(|(params, _)| params.clone())
|
||||||
.collect_vec(),
|
.collect_vec(),
|
||||||
|
);
|
||||||
|
|
||||||
|
air_vec.push(Air::DefineFunc {
|
||||||
|
func_name: func_name.clone(),
|
||||||
|
module_name: module_name.clone(),
|
||||||
|
variant_name: variant_name.clone(),
|
||||||
|
variant,
|
||||||
});
|
});
|
||||||
|
|
||||||
for (_, func_body) in contained_functions {
|
for (_, func_body) in contained_functions {
|
||||||
|
@ -1834,8 +1878,8 @@ impl AirTree {
|
||||||
) {
|
) {
|
||||||
tree_path.push(current_depth, field_index);
|
tree_path.push(current_depth, field_index);
|
||||||
|
|
||||||
// Assignments'/Statements' values get traversed here
|
// TODO: Merge together the 2 match statements
|
||||||
// Then the body under these assignments/statements get traversed later on
|
|
||||||
match self {
|
match self {
|
||||||
AirTree::Let {
|
AirTree::Let {
|
||||||
name: _,
|
name: _,
|
||||||
|
@ -2104,7 +2148,6 @@ impl AirTree {
|
||||||
| AirTree::MultiValidator { .. } => {}
|
| AirTree::MultiValidator { .. } => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expressions or an assignment that hoist over a expression are traversed here
|
|
||||||
match self {
|
match self {
|
||||||
AirTree::NoOp { then } => {
|
AirTree::NoOp { then } => {
|
||||||
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::FirstField, with);
|
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::FirstField, with);
|
||||||
|
@ -2389,16 +2432,17 @@ impl AirTree {
|
||||||
recursive: _,
|
recursive: _,
|
||||||
recursive_nonstatic_params: _,
|
recursive_nonstatic_params: _,
|
||||||
variant_name: _,
|
variant_name: _,
|
||||||
|
constant: _,
|
||||||
func_body,
|
func_body,
|
||||||
then,
|
then,
|
||||||
} => {
|
} => {
|
||||||
func_body.do_traverse_tree_with(
|
func_body.do_traverse_tree_with(
|
||||||
tree_path,
|
tree_path,
|
||||||
current_depth + 1,
|
current_depth + 1,
|
||||||
Fields::SeventhField,
|
Fields::EighthField,
|
||||||
with,
|
with,
|
||||||
);
|
);
|
||||||
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::EighthField, with)
|
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::NinthField, with)
|
||||||
}
|
}
|
||||||
AirTree::DefineCyclicFuncs {
|
AirTree::DefineCyclicFuncs {
|
||||||
func_name: _,
|
func_name: _,
|
||||||
|
|
Loading…
Reference in New Issue