Refactor some code to make a define_const AirTree function

This commit is contained in:
microproofs 2024-09-20 00:39:27 -04:00
parent 513ca27717
commit 2bbc699a25
No known key found for this signature in database
GPG Key ID: 14F93C84DE6AFD17
3 changed files with 190 additions and 129 deletions

View File

@ -4988,90 +4988,102 @@ impl<'a> CodeGenerator<'a> {
} }
Air::DefineFunc { Air::DefineFunc {
func_name, func_name,
params,
recursive,
recursive_nonstatic_params,
module_name, module_name,
variant_name, variant_name,
variant,
} => { } => {
let func_name = if module_name.is_empty() { let func_name = if module_name.is_empty() {
format!("{func_name}{variant_name}") format!("{func_name}{variant_name}")
} else { } else {
format!("{module_name}_{func_name}{variant_name}") format!("{module_name}_{func_name}{variant_name}")
}; };
match variant {
air::FunctionVariants::Standard(params) => {
let mut func_body = arg_stack.pop().unwrap(); let mut func_body = arg_stack.pop().unwrap();
let mut term = arg_stack.pop().unwrap(); let term = arg_stack.pop().unwrap();
// Introduce a parameter for each parameter if params.is_empty() {
// NOTE: we use recursive_nonstatic_params here because func_body = func_body.delay();
// if this is recursive, those are the ones that need to be passed
// each time
for param in recursive_nonstatic_params.iter().rev() {
func_body = func_body.lambda(param.clone());
} }
let func_body = params
.into_iter()
.fold(func_body, |term, arg| term.lambda(arg))
.lambda(NO_INLINE);
Some(term.lambda(func_name).apply(func_body))
}
air::FunctionVariants::Recursive {
params,
recursive_nonstatic_params,
} => {
let mut func_body = arg_stack.pop().unwrap();
let term = arg_stack.pop().unwrap();
let no_statics = recursive_nonstatic_params == params;
if recursive_nonstatic_params.is_empty() || params.is_empty() { if recursive_nonstatic_params.is_empty() || params.is_empty() {
func_body = func_body.delay(); func_body = func_body.delay();
} }
if !recursive { let func_body = recursive_nonstatic_params
term = term.lambda(func_name).apply(func_body.lambda(NO_INLINE)); .iter()
.rfold(func_body, |term, arg| term.lambda(arg));
Some(term) let func_body = func_body.lambda(func_name.clone());
} else {
func_body = func_body.lambda(func_name.clone());
if recursive_nonstatic_params == params { if no_statics {
// If we don't have any recursive-static params, we can just emit the function as is // If we don't have any recursive-static params, we can just emit the function as is
term = term Some(
.lambda(func_name.clone()) term.lambda(func_name.clone())
.apply(Term::var(func_name.clone()).apply(Term::var(func_name.clone()))) .apply(
Term::var(func_name.clone())
.apply(Term::var(func_name.clone())),
)
.lambda(func_name) .lambda(func_name)
.apply(func_body.lambda(NO_INLINE)); .apply(func_body.lambda(NO_INLINE)),
)
} else { } else {
// If we have parameters that remain static in each recursive call, // If we have parameters that remain static in each recursive call,
// we can construct an *outer* function to take those in // we can construct an *outer* function to take those in
// and simplify the recursive part to only accept the non-static arguments // and simplify the recursive part to only accept the non-static arguments
let mut recursive_func_body = let mut recursive_func_body =
Term::var(&func_name).apply(Term::var(&func_name)); Term::var(&func_name).apply(Term::var(&func_name));
for param in recursive_nonstatic_params.iter() {
recursive_func_body = recursive_func_body.apply(Term::var(param));
}
if recursive_nonstatic_params.is_empty() { if recursive_nonstatic_params.is_empty() {
recursive_func_body = recursive_func_body.force(); recursive_func_body = recursive_func_body.force();
} }
// Introduce a parameter for each parameter
// NOTE: we use recursive_nonstatic_params here because
// if this is recursive, those are the ones that need to be passed
// each time
for param in recursive_nonstatic_params.into_iter() {
recursive_func_body = recursive_func_body.apply(Term::var(param));
}
// Then construct an outer function with *all* parameters, not just the nonstatic ones. // Then construct an outer function with *all* parameters, not just the nonstatic ones.
let mut outer_func_body = let mut outer_func_body =
recursive_func_body.lambda(&func_name).apply(func_body); recursive_func_body.lambda(&func_name).apply(func_body);
// Now, add *all* parameters, so that other call sites don't know the difference // Now, add *all* parameters, so that other call sites don't know the difference
for param in params.iter().rev() { outer_func_body = params
outer_func_body = outer_func_body.lambda(param); .clone()
} .into_iter()
.rfold(outer_func_body, |term, arg| term.lambda(arg));
// And finally, fold that definition into the rest of our program // And finally, fold that definition into the rest of our program
term = term Some(
.lambda(&func_name) term.lambda(&func_name)
.apply(outer_func_body.lambda(NO_INLINE)); .apply(outer_func_body.lambda(NO_INLINE)),
} )
Some(term)
} }
} }
Air::DefineCyclicFuncs { air::FunctionVariants::Constant => todo!(),
func_name, air::FunctionVariants::Cyclic(contained_functions) => {
module_name,
variant_name,
contained_functions,
} => {
let func_name = if module_name.is_empty() {
format!("{func_name}{variant_name}")
} else {
format!("{module_name}_{func_name}{variant_name}")
};
let mut cyclic_functions = vec![]; let mut cyclic_functions = vec![];
for params in contained_functions { for params in contained_functions {
@ -5116,6 +5128,9 @@ impl<'a> CodeGenerator<'a> {
Some(term) Some(term)
} }
}
}
Air::Let { name } => { Air::Let { name } => {
let arg = arg_stack.pop().unwrap(); let arg = arg_stack.pop().unwrap();

View File

@ -23,6 +23,17 @@ impl From<bool> for ExpectLevel {
} }
} }
#[derive(Debug, Clone, PartialEq)]
pub enum FunctionVariants {
Standard(Vec<String>),
Recursive {
params: Vec<String>,
recursive_nonstatic_params: Vec<String>,
},
Cyclic(Vec<Vec<String>>),
Constant,
}
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum Air { pub enum Air {
// Primitives // Primitives
@ -67,17 +78,8 @@ pub enum Air {
DefineFunc { DefineFunc {
func_name: String, func_name: String,
module_name: String, module_name: String,
params: Vec<String>,
recursive: bool,
recursive_nonstatic_params: Vec<String>,
variant_name: String, variant_name: String,
}, variant: FunctionVariants,
DefineCyclicFuncs {
func_name: String,
module_name: String,
variant_name: String,
// just the params
contained_functions: Vec<Vec<String>>,
}, },
Fn { Fn {
params: Vec<String>, params: Vec<String>,

View File

@ -1,4 +1,4 @@
use super::air::{Air, ExpectLevel}; use super::air::{Air, ExpectLevel, FunctionVariants};
use crate::{ use crate::{
ast::{BinOp, Curve, Span, UnOp}, ast::{BinOp, Curve, Span, UnOp},
tipo::{Type, ValueConstructor, ValueConstructorVariant}, tipo::{Type, ValueConstructor, ValueConstructorVariant},
@ -21,6 +21,7 @@ pub enum Fields {
SixthField, SixthField,
SeventhField, SeventhField,
EighthField, EighthField,
NinthField,
ArgsField(usize), ArgsField(usize),
} }
@ -136,10 +137,12 @@ pub enum AirTree {
DefineFunc { DefineFunc {
func_name: String, func_name: String,
module_name: String, module_name: String,
variant_name: String,
//params and other parts of a function
params: Vec<String>, params: Vec<String>,
recursive: bool, recursive: bool,
recursive_nonstatic_params: Vec<String>, recursive_nonstatic_params: Vec<String>,
variant_name: String, constant: bool,
func_body: Box<AirTree>, func_body: Box<AirTree>,
then: Box<AirTree>, then: Box<AirTree>,
}, },
@ -531,12 +534,33 @@ impl AirTree {
params, params,
recursive, recursive,
recursive_nonstatic_params, recursive_nonstatic_params,
constant: false,
variant_name: variant_name.to_string(), variant_name: variant_name.to_string(),
func_body: func_body.into(), func_body: func_body.into(),
then: then.into(), then: then.into(),
} }
} }
#[allow(clippy::too_many_arguments)]
pub fn define_const(
func_name: impl ToString,
module_name: impl ToString,
func_body: AirTree,
then: AirTree,
) -> AirTree {
AirTree::DefineFunc {
func_name: func_name.to_string(),
module_name: module_name.to_string(),
variant_name: "".to_string(),
params: vec![],
recursive: false,
recursive_nonstatic_params: vec![],
constant: true,
func_body: func_body.into(),
then: then.into(),
}
}
pub fn define_cyclic_func( pub fn define_cyclic_func(
func_name: impl ToString, func_name: impl ToString,
module_name: impl ToString, module_name: impl ToString,
@ -1158,17 +1182,33 @@ impl AirTree {
params, params,
recursive, recursive,
recursive_nonstatic_params, recursive_nonstatic_params,
constant,
variant_name, variant_name,
func_body, func_body,
then, then,
} => { } => {
let variant = if *constant {
assert!(!recursive);
assert!(params.is_empty());
assert!(recursive_nonstatic_params.is_empty());
FunctionVariants::Constant
} else if *recursive {
FunctionVariants::Recursive {
params: params.clone(),
recursive_nonstatic_params: recursive_nonstatic_params.clone(),
}
} else {
assert_eq!(params, recursive_nonstatic_params);
FunctionVariants::Standard(params.clone())
};
air_vec.push(Air::DefineFunc { air_vec.push(Air::DefineFunc {
func_name: func_name.clone(), func_name: func_name.clone(),
module_name: module_name.clone(), module_name: module_name.clone(),
params: params.clone(),
recursive: *recursive,
recursive_nonstatic_params: recursive_nonstatic_params.clone(),
variant_name: variant_name.clone(), variant_name: variant_name.clone(),
variant,
}); });
func_body.create_air_vec(air_vec); func_body.create_air_vec(air_vec);
then.create_air_vec(air_vec); then.create_air_vec(air_vec);
@ -1180,14 +1220,18 @@ impl AirTree {
contained_functions, contained_functions,
then, then,
} => { } => {
air_vec.push(Air::DefineCyclicFuncs { let variant = FunctionVariants::Cyclic(
func_name: func_name.clone(), contained_functions
module_name: module_name.clone(),
variant_name: variant_name.clone(),
contained_functions: contained_functions
.iter() .iter()
.map(|(params, _)| params.clone()) .map(|(params, _)| params.clone())
.collect_vec(), .collect_vec(),
);
air_vec.push(Air::DefineFunc {
func_name: func_name.clone(),
module_name: module_name.clone(),
variant_name: variant_name.clone(),
variant,
}); });
for (_, func_body) in contained_functions { for (_, func_body) in contained_functions {
@ -1834,8 +1878,8 @@ impl AirTree {
) { ) {
tree_path.push(current_depth, field_index); tree_path.push(current_depth, field_index);
// Assignments'/Statements' values get traversed here // TODO: Merge together the 2 match statements
// Then the body under these assignments/statements get traversed later on
match self { match self {
AirTree::Let { AirTree::Let {
name: _, name: _,
@ -2104,7 +2148,6 @@ impl AirTree {
| AirTree::MultiValidator { .. } => {} | AirTree::MultiValidator { .. } => {}
} }
// Expressions or an assignment that hoist over a expression are traversed here
match self { match self {
AirTree::NoOp { then } => { AirTree::NoOp { then } => {
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::FirstField, with); then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::FirstField, with);
@ -2389,16 +2432,17 @@ impl AirTree {
recursive: _, recursive: _,
recursive_nonstatic_params: _, recursive_nonstatic_params: _,
variant_name: _, variant_name: _,
constant: _,
func_body, func_body,
then, then,
} => { } => {
func_body.do_traverse_tree_with( func_body.do_traverse_tree_with(
tree_path, tree_path,
current_depth + 1, current_depth + 1,
Fields::SeventhField, Fields::EighthField,
with, with,
); );
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::EighthField, with) then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::NinthField, with)
} }
AirTree::DefineCyclicFuncs { AirTree::DefineCyclicFuncs {
func_name: _, func_name: _,