Refactor some code to make a define_const AirTree function

This commit is contained in:
microproofs 2024-09-20 00:39:27 -04:00
parent 513ca27717
commit 2bbc699a25
No known key found for this signature in database
GPG Key ID: 14F93C84DE6AFD17
3 changed files with 190 additions and 129 deletions

View File

@ -4988,90 +4988,102 @@ impl<'a> CodeGenerator<'a> {
}
Air::DefineFunc {
func_name,
params,
recursive,
recursive_nonstatic_params,
module_name,
variant_name,
variant,
} => {
let func_name = if module_name.is_empty() {
format!("{func_name}{variant_name}")
} else {
format!("{module_name}_{func_name}{variant_name}")
};
match variant {
air::FunctionVariants::Standard(params) => {
let mut func_body = arg_stack.pop().unwrap();
let mut term = arg_stack.pop().unwrap();
let term = arg_stack.pop().unwrap();
// Introduce a parameter for each parameter
// NOTE: we use recursive_nonstatic_params here because
// if this is recursive, those are the ones that need to be passed
// each time
for param in recursive_nonstatic_params.iter().rev() {
func_body = func_body.lambda(param.clone());
if params.is_empty() {
func_body = func_body.delay();
}
let func_body = params
.into_iter()
.fold(func_body, |term, arg| term.lambda(arg))
.lambda(NO_INLINE);
Some(term.lambda(func_name).apply(func_body))
}
air::FunctionVariants::Recursive {
params,
recursive_nonstatic_params,
} => {
let mut func_body = arg_stack.pop().unwrap();
let term = arg_stack.pop().unwrap();
let no_statics = recursive_nonstatic_params == params;
if recursive_nonstatic_params.is_empty() || params.is_empty() {
func_body = func_body.delay();
}
if !recursive {
term = term.lambda(func_name).apply(func_body.lambda(NO_INLINE));
let func_body = recursive_nonstatic_params
.iter()
.rfold(func_body, |term, arg| term.lambda(arg));
Some(term)
} else {
func_body = func_body.lambda(func_name.clone());
let func_body = func_body.lambda(func_name.clone());
if recursive_nonstatic_params == params {
if no_statics {
// If we don't have any recursive-static params, we can just emit the function as is
term = term
.lambda(func_name.clone())
.apply(Term::var(func_name.clone()).apply(Term::var(func_name.clone())))
Some(
term.lambda(func_name.clone())
.apply(
Term::var(func_name.clone())
.apply(Term::var(func_name.clone())),
)
.lambda(func_name)
.apply(func_body.lambda(NO_INLINE));
.apply(func_body.lambda(NO_INLINE)),
)
} else {
// If we have parameters that remain static in each recursive call,
// we can construct an *outer* function to take those in
// and simplify the recursive part to only accept the non-static arguments
let mut recursive_func_body =
Term::var(&func_name).apply(Term::var(&func_name));
for param in recursive_nonstatic_params.iter() {
recursive_func_body = recursive_func_body.apply(Term::var(param));
}
if recursive_nonstatic_params.is_empty() {
recursive_func_body = recursive_func_body.force();
}
// Introduce a parameter for each parameter
// NOTE: we use recursive_nonstatic_params here because
// if this is recursive, those are the ones that need to be passed
// each time
for param in recursive_nonstatic_params.into_iter() {
recursive_func_body = recursive_func_body.apply(Term::var(param));
}
// Then construct an outer function with *all* parameters, not just the nonstatic ones.
let mut outer_func_body =
recursive_func_body.lambda(&func_name).apply(func_body);
// Now, add *all* parameters, so that other call sites don't know the difference
for param in params.iter().rev() {
outer_func_body = outer_func_body.lambda(param);
}
outer_func_body = params
.clone()
.into_iter()
.rfold(outer_func_body, |term, arg| term.lambda(arg));
// And finally, fold that definition into the rest of our program
term = term
.lambda(&func_name)
.apply(outer_func_body.lambda(NO_INLINE));
}
Some(term)
Some(
term.lambda(&func_name)
.apply(outer_func_body.lambda(NO_INLINE)),
)
}
}
Air::DefineCyclicFuncs {
func_name,
module_name,
variant_name,
contained_functions,
} => {
let func_name = if module_name.is_empty() {
format!("{func_name}{variant_name}")
} else {
format!("{module_name}_{func_name}{variant_name}")
};
air::FunctionVariants::Constant => todo!(),
air::FunctionVariants::Cyclic(contained_functions) => {
let mut cyclic_functions = vec![];
for params in contained_functions {
@ -5116,6 +5128,9 @@ impl<'a> CodeGenerator<'a> {
Some(term)
}
}
}
Air::Let { name } => {
let arg = arg_stack.pop().unwrap();

View File

@ -23,6 +23,17 @@ impl From<bool> for ExpectLevel {
}
}
#[derive(Debug, Clone, PartialEq)]
pub enum FunctionVariants {
Standard(Vec<String>),
Recursive {
params: Vec<String>,
recursive_nonstatic_params: Vec<String>,
},
Cyclic(Vec<Vec<String>>),
Constant,
}
#[derive(Debug, Clone, PartialEq)]
pub enum Air {
// Primitives
@ -67,17 +78,8 @@ pub enum Air {
DefineFunc {
func_name: String,
module_name: String,
params: Vec<String>,
recursive: bool,
recursive_nonstatic_params: Vec<String>,
variant_name: String,
},
DefineCyclicFuncs {
func_name: String,
module_name: String,
variant_name: String,
// just the params
contained_functions: Vec<Vec<String>>,
variant: FunctionVariants,
},
Fn {
params: Vec<String>,

View File

@ -1,4 +1,4 @@
use super::air::{Air, ExpectLevel};
use super::air::{Air, ExpectLevel, FunctionVariants};
use crate::{
ast::{BinOp, Curve, Span, UnOp},
tipo::{Type, ValueConstructor, ValueConstructorVariant},
@ -21,6 +21,7 @@ pub enum Fields {
SixthField,
SeventhField,
EighthField,
NinthField,
ArgsField(usize),
}
@ -136,10 +137,12 @@ pub enum AirTree {
DefineFunc {
func_name: String,
module_name: String,
variant_name: String,
//params and other parts of a function
params: Vec<String>,
recursive: bool,
recursive_nonstatic_params: Vec<String>,
variant_name: String,
constant: bool,
func_body: Box<AirTree>,
then: Box<AirTree>,
},
@ -531,12 +534,33 @@ impl AirTree {
params,
recursive,
recursive_nonstatic_params,
constant: false,
variant_name: variant_name.to_string(),
func_body: func_body.into(),
then: then.into(),
}
}
#[allow(clippy::too_many_arguments)]
pub fn define_const(
func_name: impl ToString,
module_name: impl ToString,
func_body: AirTree,
then: AirTree,
) -> AirTree {
AirTree::DefineFunc {
func_name: func_name.to_string(),
module_name: module_name.to_string(),
variant_name: "".to_string(),
params: vec![],
recursive: false,
recursive_nonstatic_params: vec![],
constant: true,
func_body: func_body.into(),
then: then.into(),
}
}
pub fn define_cyclic_func(
func_name: impl ToString,
module_name: impl ToString,
@ -1158,17 +1182,33 @@ impl AirTree {
params,
recursive,
recursive_nonstatic_params,
constant,
variant_name,
func_body,
then,
} => {
let variant = if *constant {
assert!(!recursive);
assert!(params.is_empty());
assert!(recursive_nonstatic_params.is_empty());
FunctionVariants::Constant
} else if *recursive {
FunctionVariants::Recursive {
params: params.clone(),
recursive_nonstatic_params: recursive_nonstatic_params.clone(),
}
} else {
assert_eq!(params, recursive_nonstatic_params);
FunctionVariants::Standard(params.clone())
};
air_vec.push(Air::DefineFunc {
func_name: func_name.clone(),
module_name: module_name.clone(),
params: params.clone(),
recursive: *recursive,
recursive_nonstatic_params: recursive_nonstatic_params.clone(),
variant_name: variant_name.clone(),
variant,
});
func_body.create_air_vec(air_vec);
then.create_air_vec(air_vec);
@ -1180,14 +1220,18 @@ impl AirTree {
contained_functions,
then,
} => {
air_vec.push(Air::DefineCyclicFuncs {
func_name: func_name.clone(),
module_name: module_name.clone(),
variant_name: variant_name.clone(),
contained_functions: contained_functions
let variant = FunctionVariants::Cyclic(
contained_functions
.iter()
.map(|(params, _)| params.clone())
.collect_vec(),
);
air_vec.push(Air::DefineFunc {
func_name: func_name.clone(),
module_name: module_name.clone(),
variant_name: variant_name.clone(),
variant,
});
for (_, func_body) in contained_functions {
@ -1834,8 +1878,8 @@ impl AirTree {
) {
tree_path.push(current_depth, field_index);
// Assignments'/Statements' values get traversed here
// Then the body under these assignments/statements get traversed later on
// TODO: Merge together the 2 match statements
match self {
AirTree::Let {
name: _,
@ -2104,7 +2148,6 @@ impl AirTree {
| AirTree::MultiValidator { .. } => {}
}
// Expressions or an assignment that hoist over a expression are traversed here
match self {
AirTree::NoOp { then } => {
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::FirstField, with);
@ -2389,16 +2432,17 @@ impl AirTree {
recursive: _,
recursive_nonstatic_params: _,
variant_name: _,
constant: _,
func_body,
then,
} => {
func_body.do_traverse_tree_with(
tree_path,
current_depth + 1,
Fields::SeventhField,
Fields::EighthField,
with,
);
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::EighthField, with)
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::NinthField, with)
}
AirTree::DefineCyclicFuncs {
func_name: _,