Refactor some code to make a define_const AirTree function

This commit is contained in:
microproofs 2024-09-20 00:39:27 -04:00
parent 513ca27717
commit 2bbc699a25
No known key found for this signature in database
GPG Key ID: 14F93C84DE6AFD17
3 changed files with 190 additions and 129 deletions

View File

@ -4988,134 +4988,149 @@ impl<'a> CodeGenerator<'a> {
} }
Air::DefineFunc { Air::DefineFunc {
func_name, func_name,
params,
recursive,
recursive_nonstatic_params,
module_name, module_name,
variant_name, variant_name,
variant,
} => { } => {
let func_name = if module_name.is_empty() { let func_name = if module_name.is_empty() {
format!("{func_name}{variant_name}") format!("{func_name}{variant_name}")
} else { } else {
format!("{module_name}_{func_name}{variant_name}") format!("{module_name}_{func_name}{variant_name}")
}; };
let mut func_body = arg_stack.pop().unwrap();
let mut term = arg_stack.pop().unwrap(); match variant {
air::FunctionVariants::Standard(params) => {
let mut func_body = arg_stack.pop().unwrap();
// Introduce a parameter for each parameter let term = arg_stack.pop().unwrap();
// NOTE: we use recursive_nonstatic_params here because
// if this is recursive, those are the ones that need to be passed
// each time
for param in recursive_nonstatic_params.iter().rev() {
func_body = func_body.lambda(param.clone());
}
if recursive_nonstatic_params.is_empty() || params.is_empty() { if params.is_empty() {
func_body = func_body.delay(); func_body = func_body.delay();
}
if !recursive {
term = term.lambda(func_name).apply(func_body.lambda(NO_INLINE));
Some(term)
} else {
func_body = func_body.lambda(func_name.clone());
if recursive_nonstatic_params == params {
// If we don't have any recursive-static params, we can just emit the function as is
term = term
.lambda(func_name.clone())
.apply(Term::var(func_name.clone()).apply(Term::var(func_name.clone())))
.lambda(func_name)
.apply(func_body.lambda(NO_INLINE));
} else {
// If we have parameters that remain static in each recursive call,
// we can construct an *outer* function to take those in
// and simplify the recursive part to only accept the non-static arguments
let mut recursive_func_body =
Term::var(&func_name).apply(Term::var(&func_name));
for param in recursive_nonstatic_params.iter() {
recursive_func_body = recursive_func_body.apply(Term::var(param));
} }
if recursive_nonstatic_params.is_empty() { let func_body = params
recursive_func_body = recursive_func_body.force(); .into_iter()
.fold(func_body, |term, arg| term.lambda(arg))
.lambda(NO_INLINE);
Some(term.lambda(func_name).apply(func_body))
}
air::FunctionVariants::Recursive {
params,
recursive_nonstatic_params,
} => {
let mut func_body = arg_stack.pop().unwrap();
let term = arg_stack.pop().unwrap();
let no_statics = recursive_nonstatic_params == params;
if recursive_nonstatic_params.is_empty() || params.is_empty() {
func_body = func_body.delay();
} }
// Then construct an outer function with *all* parameters, not just the nonstatic ones. let func_body = recursive_nonstatic_params
let mut outer_func_body = .iter()
recursive_func_body.lambda(&func_name).apply(func_body); .rfold(func_body, |term, arg| term.lambda(arg));
// Now, add *all* parameters, so that other call sites don't know the difference let func_body = func_body.lambda(func_name.clone());
for param in params.iter().rev() {
outer_func_body = outer_func_body.lambda(param); if no_statics {
// If we don't have any recursive-static params, we can just emit the function as is
Some(
term.lambda(func_name.clone())
.apply(
Term::var(func_name.clone())
.apply(Term::var(func_name.clone())),
)
.lambda(func_name)
.apply(func_body.lambda(NO_INLINE)),
)
} else {
// If we have parameters that remain static in each recursive call,
// we can construct an *outer* function to take those in
// and simplify the recursive part to only accept the non-static arguments
let mut recursive_func_body =
Term::var(&func_name).apply(Term::var(&func_name));
if recursive_nonstatic_params.is_empty() {
recursive_func_body = recursive_func_body.force();
}
// Introduce a parameter for each parameter
// NOTE: we use recursive_nonstatic_params here because
// if this is recursive, those are the ones that need to be passed
// each time
for param in recursive_nonstatic_params.into_iter() {
recursive_func_body = recursive_func_body.apply(Term::var(param));
}
// Then construct an outer function with *all* parameters, not just the nonstatic ones.
let mut outer_func_body =
recursive_func_body.lambda(&func_name).apply(func_body);
// Now, add *all* parameters, so that other call sites don't know the difference
outer_func_body = params
.clone()
.into_iter()
.rfold(outer_func_body, |term, arg| term.lambda(arg));
// And finally, fold that definition into the rest of our program
Some(
term.lambda(&func_name)
.apply(outer_func_body.lambda(NO_INLINE)),
)
}
}
air::FunctionVariants::Constant => todo!(),
air::FunctionVariants::Cyclic(contained_functions) => {
let mut cyclic_functions = vec![];
for params in contained_functions {
let func_body = arg_stack.pop().unwrap();
cyclic_functions.push((params, func_body));
}
let mut term = arg_stack.pop().unwrap();
let mut cyclic_body = Term::var("__chooser");
for (params, func_body) in cyclic_functions.into_iter() {
let mut function = func_body;
if params.is_empty() {
function = function.delay();
} else {
for param in params.iter().rev() {
function = function.lambda(param);
}
}
// We basically Scott encode our function bodies and use the chooser function
// to determine which function body and params is run
// For example say there is a cycle of 3 function bodies
// Our choose function can look like this:
// \func1 -> \func2 -> \func3 -> func1
// In this case our chooser is a function that takes in 3 functions
// and returns the first one to run
cyclic_body = cyclic_body.apply(function)
} }
// And finally, fold that definition into the rest of our program
term = term term = term
.lambda(&func_name) .lambda(&func_name)
.apply(outer_func_body.lambda(NO_INLINE)); .apply(Term::var(&func_name).apply(Term::var(&func_name)))
} .lambda(&func_name)
.apply(
cyclic_body
.lambda("__chooser")
.lambda(func_name)
.lambda(NO_INLINE),
);
Some(term) Some(term)
}
} }
} }
Air::DefineCyclicFuncs {
func_name,
module_name,
variant_name,
contained_functions,
} => {
let func_name = if module_name.is_empty() {
format!("{func_name}{variant_name}")
} else {
format!("{module_name}_{func_name}{variant_name}")
};
let mut cyclic_functions = vec![];
for params in contained_functions {
let func_body = arg_stack.pop().unwrap();
cyclic_functions.push((params, func_body));
}
let mut term = arg_stack.pop().unwrap();
let mut cyclic_body = Term::var("__chooser");
for (params, func_body) in cyclic_functions.into_iter() {
let mut function = func_body;
if params.is_empty() {
function = function.delay();
} else {
for param in params.iter().rev() {
function = function.lambda(param);
}
}
// We basically Scott encode our function bodies and use the chooser function
// to determine which function body and params is run
// For example say there is a cycle of 3 function bodies
// Our choose function can look like this:
// \func1 -> \func2 -> \func3 -> func1
// In this case our chooser is a function that takes in 3 functions
// and returns the first one to run
cyclic_body = cyclic_body.apply(function)
}
term = term
.lambda(&func_name)
.apply(Term::var(&func_name).apply(Term::var(&func_name)))
.lambda(&func_name)
.apply(
cyclic_body
.lambda("__chooser")
.lambda(func_name)
.lambda(NO_INLINE),
);
Some(term)
}
Air::Let { name } => { Air::Let { name } => {
let arg = arg_stack.pop().unwrap(); let arg = arg_stack.pop().unwrap();

View File

@ -23,6 +23,17 @@ impl From<bool> for ExpectLevel {
} }
} }
#[derive(Debug, Clone, PartialEq)]
pub enum FunctionVariants {
Standard(Vec<String>),
Recursive {
params: Vec<String>,
recursive_nonstatic_params: Vec<String>,
},
Cyclic(Vec<Vec<String>>),
Constant,
}
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum Air { pub enum Air {
// Primitives // Primitives
@ -67,17 +78,8 @@ pub enum Air {
DefineFunc { DefineFunc {
func_name: String, func_name: String,
module_name: String, module_name: String,
params: Vec<String>,
recursive: bool,
recursive_nonstatic_params: Vec<String>,
variant_name: String, variant_name: String,
}, variant: FunctionVariants,
DefineCyclicFuncs {
func_name: String,
module_name: String,
variant_name: String,
// just the params
contained_functions: Vec<Vec<String>>,
}, },
Fn { Fn {
params: Vec<String>, params: Vec<String>,

View File

@ -1,4 +1,4 @@
use super::air::{Air, ExpectLevel}; use super::air::{Air, ExpectLevel, FunctionVariants};
use crate::{ use crate::{
ast::{BinOp, Curve, Span, UnOp}, ast::{BinOp, Curve, Span, UnOp},
tipo::{Type, ValueConstructor, ValueConstructorVariant}, tipo::{Type, ValueConstructor, ValueConstructorVariant},
@ -21,6 +21,7 @@ pub enum Fields {
SixthField, SixthField,
SeventhField, SeventhField,
EighthField, EighthField,
NinthField,
ArgsField(usize), ArgsField(usize),
} }
@ -136,10 +137,12 @@ pub enum AirTree {
DefineFunc { DefineFunc {
func_name: String, func_name: String,
module_name: String, module_name: String,
variant_name: String,
//params and other parts of a function
params: Vec<String>, params: Vec<String>,
recursive: bool, recursive: bool,
recursive_nonstatic_params: Vec<String>, recursive_nonstatic_params: Vec<String>,
variant_name: String, constant: bool,
func_body: Box<AirTree>, func_body: Box<AirTree>,
then: Box<AirTree>, then: Box<AirTree>,
}, },
@ -531,12 +534,33 @@ impl AirTree {
params, params,
recursive, recursive,
recursive_nonstatic_params, recursive_nonstatic_params,
constant: false,
variant_name: variant_name.to_string(), variant_name: variant_name.to_string(),
func_body: func_body.into(), func_body: func_body.into(),
then: then.into(), then: then.into(),
} }
} }
#[allow(clippy::too_many_arguments)]
pub fn define_const(
func_name: impl ToString,
module_name: impl ToString,
func_body: AirTree,
then: AirTree,
) -> AirTree {
AirTree::DefineFunc {
func_name: func_name.to_string(),
module_name: module_name.to_string(),
variant_name: "".to_string(),
params: vec![],
recursive: false,
recursive_nonstatic_params: vec![],
constant: true,
func_body: func_body.into(),
then: then.into(),
}
}
pub fn define_cyclic_func( pub fn define_cyclic_func(
func_name: impl ToString, func_name: impl ToString,
module_name: impl ToString, module_name: impl ToString,
@ -1158,17 +1182,33 @@ impl AirTree {
params, params,
recursive, recursive,
recursive_nonstatic_params, recursive_nonstatic_params,
constant,
variant_name, variant_name,
func_body, func_body,
then, then,
} => { } => {
let variant = if *constant {
assert!(!recursive);
assert!(params.is_empty());
assert!(recursive_nonstatic_params.is_empty());
FunctionVariants::Constant
} else if *recursive {
FunctionVariants::Recursive {
params: params.clone(),
recursive_nonstatic_params: recursive_nonstatic_params.clone(),
}
} else {
assert_eq!(params, recursive_nonstatic_params);
FunctionVariants::Standard(params.clone())
};
air_vec.push(Air::DefineFunc { air_vec.push(Air::DefineFunc {
func_name: func_name.clone(), func_name: func_name.clone(),
module_name: module_name.clone(), module_name: module_name.clone(),
params: params.clone(),
recursive: *recursive,
recursive_nonstatic_params: recursive_nonstatic_params.clone(),
variant_name: variant_name.clone(), variant_name: variant_name.clone(),
variant,
}); });
func_body.create_air_vec(air_vec); func_body.create_air_vec(air_vec);
then.create_air_vec(air_vec); then.create_air_vec(air_vec);
@ -1180,14 +1220,18 @@ impl AirTree {
contained_functions, contained_functions,
then, then,
} => { } => {
air_vec.push(Air::DefineCyclicFuncs { let variant = FunctionVariants::Cyclic(
func_name: func_name.clone(), contained_functions
module_name: module_name.clone(),
variant_name: variant_name.clone(),
contained_functions: contained_functions
.iter() .iter()
.map(|(params, _)| params.clone()) .map(|(params, _)| params.clone())
.collect_vec(), .collect_vec(),
);
air_vec.push(Air::DefineFunc {
func_name: func_name.clone(),
module_name: module_name.clone(),
variant_name: variant_name.clone(),
variant,
}); });
for (_, func_body) in contained_functions { for (_, func_body) in contained_functions {
@ -1834,8 +1878,8 @@ impl AirTree {
) { ) {
tree_path.push(current_depth, field_index); tree_path.push(current_depth, field_index);
// Assignments'/Statements' values get traversed here // TODO: Merge together the 2 match statements
// Then the body under these assignments/statements get traversed later on
match self { match self {
AirTree::Let { AirTree::Let {
name: _, name: _,
@ -2104,7 +2148,6 @@ impl AirTree {
| AirTree::MultiValidator { .. } => {} | AirTree::MultiValidator { .. } => {}
} }
// Expressions or an assignment that hoist over a expression are traversed here
match self { match self {
AirTree::NoOp { then } => { AirTree::NoOp { then } => {
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::FirstField, with); then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::FirstField, with);
@ -2389,16 +2432,17 @@ impl AirTree {
recursive: _, recursive: _,
recursive_nonstatic_params: _, recursive_nonstatic_params: _,
variant_name: _, variant_name: _,
constant: _,
func_body, func_body,
then, then,
} => { } => {
func_body.do_traverse_tree_with( func_body.do_traverse_tree_with(
tree_path, tree_path,
current_depth + 1, current_depth + 1,
Fields::SeventhField, Fields::EighthField,
with, with,
); );
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::EighthField, with) then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::NinthField, with)
} }
AirTree::DefineCyclicFuncs { AirTree::DefineCyclicFuncs {
func_name: _, func_name: _,