Refactor some code to make a define_const AirTree function
This commit is contained in:
parent
513ca27717
commit
2bbc699a25
|
@ -4988,134 +4988,149 @@ impl<'a> CodeGenerator<'a> {
|
|||
}
|
||||
Air::DefineFunc {
|
||||
func_name,
|
||||
params,
|
||||
recursive,
|
||||
recursive_nonstatic_params,
|
||||
module_name,
|
||||
variant_name,
|
||||
variant,
|
||||
} => {
|
||||
let func_name = if module_name.is_empty() {
|
||||
format!("{func_name}{variant_name}")
|
||||
} else {
|
||||
format!("{module_name}_{func_name}{variant_name}")
|
||||
};
|
||||
let mut func_body = arg_stack.pop().unwrap();
|
||||
|
||||
let mut term = arg_stack.pop().unwrap();
|
||||
match variant {
|
||||
air::FunctionVariants::Standard(params) => {
|
||||
let mut func_body = arg_stack.pop().unwrap();
|
||||
|
||||
// Introduce a parameter for each parameter
|
||||
// NOTE: we use recursive_nonstatic_params here because
|
||||
// if this is recursive, those are the ones that need to be passed
|
||||
// each time
|
||||
for param in recursive_nonstatic_params.iter().rev() {
|
||||
func_body = func_body.lambda(param.clone());
|
||||
}
|
||||
let term = arg_stack.pop().unwrap();
|
||||
|
||||
if recursive_nonstatic_params.is_empty() || params.is_empty() {
|
||||
func_body = func_body.delay();
|
||||
}
|
||||
|
||||
if !recursive {
|
||||
term = term.lambda(func_name).apply(func_body.lambda(NO_INLINE));
|
||||
|
||||
Some(term)
|
||||
} else {
|
||||
func_body = func_body.lambda(func_name.clone());
|
||||
|
||||
if recursive_nonstatic_params == params {
|
||||
// If we don't have any recursive-static params, we can just emit the function as is
|
||||
term = term
|
||||
.lambda(func_name.clone())
|
||||
.apply(Term::var(func_name.clone()).apply(Term::var(func_name.clone())))
|
||||
.lambda(func_name)
|
||||
.apply(func_body.lambda(NO_INLINE));
|
||||
} else {
|
||||
// If we have parameters that remain static in each recursive call,
|
||||
// we can construct an *outer* function to take those in
|
||||
// and simplify the recursive part to only accept the non-static arguments
|
||||
let mut recursive_func_body =
|
||||
Term::var(&func_name).apply(Term::var(&func_name));
|
||||
for param in recursive_nonstatic_params.iter() {
|
||||
recursive_func_body = recursive_func_body.apply(Term::var(param));
|
||||
if params.is_empty() {
|
||||
func_body = func_body.delay();
|
||||
}
|
||||
|
||||
if recursive_nonstatic_params.is_empty() {
|
||||
recursive_func_body = recursive_func_body.force();
|
||||
let func_body = params
|
||||
.into_iter()
|
||||
.fold(func_body, |term, arg| term.lambda(arg))
|
||||
.lambda(NO_INLINE);
|
||||
|
||||
Some(term.lambda(func_name).apply(func_body))
|
||||
}
|
||||
air::FunctionVariants::Recursive {
|
||||
params,
|
||||
recursive_nonstatic_params,
|
||||
} => {
|
||||
let mut func_body = arg_stack.pop().unwrap();
|
||||
|
||||
let term = arg_stack.pop().unwrap();
|
||||
|
||||
let no_statics = recursive_nonstatic_params == params;
|
||||
|
||||
if recursive_nonstatic_params.is_empty() || params.is_empty() {
|
||||
func_body = func_body.delay();
|
||||
}
|
||||
|
||||
// Then construct an outer function with *all* parameters, not just the nonstatic ones.
|
||||
let mut outer_func_body =
|
||||
recursive_func_body.lambda(&func_name).apply(func_body);
|
||||
let func_body = recursive_nonstatic_params
|
||||
.iter()
|
||||
.rfold(func_body, |term, arg| term.lambda(arg));
|
||||
|
||||
// Now, add *all* parameters, so that other call sites don't know the difference
|
||||
for param in params.iter().rev() {
|
||||
outer_func_body = outer_func_body.lambda(param);
|
||||
let func_body = func_body.lambda(func_name.clone());
|
||||
|
||||
if no_statics {
|
||||
// If we don't have any recursive-static params, we can just emit the function as is
|
||||
Some(
|
||||
term.lambda(func_name.clone())
|
||||
.apply(
|
||||
Term::var(func_name.clone())
|
||||
.apply(Term::var(func_name.clone())),
|
||||
)
|
||||
.lambda(func_name)
|
||||
.apply(func_body.lambda(NO_INLINE)),
|
||||
)
|
||||
} else {
|
||||
// If we have parameters that remain static in each recursive call,
|
||||
// we can construct an *outer* function to take those in
|
||||
// and simplify the recursive part to only accept the non-static arguments
|
||||
let mut recursive_func_body =
|
||||
Term::var(&func_name).apply(Term::var(&func_name));
|
||||
|
||||
if recursive_nonstatic_params.is_empty() {
|
||||
recursive_func_body = recursive_func_body.force();
|
||||
}
|
||||
|
||||
// Introduce a parameter for each parameter
|
||||
// NOTE: we use recursive_nonstatic_params here because
|
||||
// if this is recursive, those are the ones that need to be passed
|
||||
// each time
|
||||
for param in recursive_nonstatic_params.into_iter() {
|
||||
recursive_func_body = recursive_func_body.apply(Term::var(param));
|
||||
}
|
||||
|
||||
// Then construct an outer function with *all* parameters, not just the nonstatic ones.
|
||||
let mut outer_func_body =
|
||||
recursive_func_body.lambda(&func_name).apply(func_body);
|
||||
|
||||
// Now, add *all* parameters, so that other call sites don't know the difference
|
||||
outer_func_body = params
|
||||
.clone()
|
||||
.into_iter()
|
||||
.rfold(outer_func_body, |term, arg| term.lambda(arg));
|
||||
|
||||
// And finally, fold that definition into the rest of our program
|
||||
Some(
|
||||
term.lambda(&func_name)
|
||||
.apply(outer_func_body.lambda(NO_INLINE)),
|
||||
)
|
||||
}
|
||||
}
|
||||
air::FunctionVariants::Constant => todo!(),
|
||||
air::FunctionVariants::Cyclic(contained_functions) => {
|
||||
let mut cyclic_functions = vec![];
|
||||
|
||||
for params in contained_functions {
|
||||
let func_body = arg_stack.pop().unwrap();
|
||||
|
||||
cyclic_functions.push((params, func_body));
|
||||
}
|
||||
let mut term = arg_stack.pop().unwrap();
|
||||
|
||||
let mut cyclic_body = Term::var("__chooser");
|
||||
|
||||
for (params, func_body) in cyclic_functions.into_iter() {
|
||||
let mut function = func_body;
|
||||
if params.is_empty() {
|
||||
function = function.delay();
|
||||
} else {
|
||||
for param in params.iter().rev() {
|
||||
function = function.lambda(param);
|
||||
}
|
||||
}
|
||||
|
||||
// We basically Scott encode our function bodies and use the chooser function
|
||||
// to determine which function body and params is run
|
||||
// For example say there is a cycle of 3 function bodies
|
||||
// Our choose function can look like this:
|
||||
// \func1 -> \func2 -> \func3 -> func1
|
||||
// In this case our chooser is a function that takes in 3 functions
|
||||
// and returns the first one to run
|
||||
cyclic_body = cyclic_body.apply(function)
|
||||
}
|
||||
|
||||
// And finally, fold that definition into the rest of our program
|
||||
term = term
|
||||
.lambda(&func_name)
|
||||
.apply(outer_func_body.lambda(NO_INLINE));
|
||||
}
|
||||
.apply(Term::var(&func_name).apply(Term::var(&func_name)))
|
||||
.lambda(&func_name)
|
||||
.apply(
|
||||
cyclic_body
|
||||
.lambda("__chooser")
|
||||
.lambda(func_name)
|
||||
.lambda(NO_INLINE),
|
||||
);
|
||||
|
||||
Some(term)
|
||||
Some(term)
|
||||
}
|
||||
}
|
||||
}
|
||||
Air::DefineCyclicFuncs {
|
||||
func_name,
|
||||
module_name,
|
||||
variant_name,
|
||||
contained_functions,
|
||||
} => {
|
||||
let func_name = if module_name.is_empty() {
|
||||
format!("{func_name}{variant_name}")
|
||||
} else {
|
||||
format!("{module_name}_{func_name}{variant_name}")
|
||||
};
|
||||
let mut cyclic_functions = vec![];
|
||||
|
||||
for params in contained_functions {
|
||||
let func_body = arg_stack.pop().unwrap();
|
||||
|
||||
cyclic_functions.push((params, func_body));
|
||||
}
|
||||
let mut term = arg_stack.pop().unwrap();
|
||||
|
||||
let mut cyclic_body = Term::var("__chooser");
|
||||
|
||||
for (params, func_body) in cyclic_functions.into_iter() {
|
||||
let mut function = func_body;
|
||||
if params.is_empty() {
|
||||
function = function.delay();
|
||||
} else {
|
||||
for param in params.iter().rev() {
|
||||
function = function.lambda(param);
|
||||
}
|
||||
}
|
||||
|
||||
// We basically Scott encode our function bodies and use the chooser function
|
||||
// to determine which function body and params is run
|
||||
// For example say there is a cycle of 3 function bodies
|
||||
// Our choose function can look like this:
|
||||
// \func1 -> \func2 -> \func3 -> func1
|
||||
// In this case our chooser is a function that takes in 3 functions
|
||||
// and returns the first one to run
|
||||
cyclic_body = cyclic_body.apply(function)
|
||||
}
|
||||
|
||||
term = term
|
||||
.lambda(&func_name)
|
||||
.apply(Term::var(&func_name).apply(Term::var(&func_name)))
|
||||
.lambda(&func_name)
|
||||
.apply(
|
||||
cyclic_body
|
||||
.lambda("__chooser")
|
||||
.lambda(func_name)
|
||||
.lambda(NO_INLINE),
|
||||
);
|
||||
|
||||
Some(term)
|
||||
}
|
||||
Air::Let { name } => {
|
||||
let arg = arg_stack.pop().unwrap();
|
||||
|
||||
|
|
|
@ -23,6 +23,17 @@ impl From<bool> for ExpectLevel {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum FunctionVariants {
|
||||
Standard(Vec<String>),
|
||||
Recursive {
|
||||
params: Vec<String>,
|
||||
recursive_nonstatic_params: Vec<String>,
|
||||
},
|
||||
Cyclic(Vec<Vec<String>>),
|
||||
Constant,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum Air {
|
||||
// Primitives
|
||||
|
@ -67,17 +78,8 @@ pub enum Air {
|
|||
DefineFunc {
|
||||
func_name: String,
|
||||
module_name: String,
|
||||
params: Vec<String>,
|
||||
recursive: bool,
|
||||
recursive_nonstatic_params: Vec<String>,
|
||||
variant_name: String,
|
||||
},
|
||||
DefineCyclicFuncs {
|
||||
func_name: String,
|
||||
module_name: String,
|
||||
variant_name: String,
|
||||
// just the params
|
||||
contained_functions: Vec<Vec<String>>,
|
||||
variant: FunctionVariants,
|
||||
},
|
||||
Fn {
|
||||
params: Vec<String>,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::air::{Air, ExpectLevel};
|
||||
use super::air::{Air, ExpectLevel, FunctionVariants};
|
||||
use crate::{
|
||||
ast::{BinOp, Curve, Span, UnOp},
|
||||
tipo::{Type, ValueConstructor, ValueConstructorVariant},
|
||||
|
@ -21,6 +21,7 @@ pub enum Fields {
|
|||
SixthField,
|
||||
SeventhField,
|
||||
EighthField,
|
||||
NinthField,
|
||||
ArgsField(usize),
|
||||
}
|
||||
|
||||
|
@ -136,10 +137,12 @@ pub enum AirTree {
|
|||
DefineFunc {
|
||||
func_name: String,
|
||||
module_name: String,
|
||||
variant_name: String,
|
||||
//params and other parts of a function
|
||||
params: Vec<String>,
|
||||
recursive: bool,
|
||||
recursive_nonstatic_params: Vec<String>,
|
||||
variant_name: String,
|
||||
constant: bool,
|
||||
func_body: Box<AirTree>,
|
||||
then: Box<AirTree>,
|
||||
},
|
||||
|
@ -531,12 +534,33 @@ impl AirTree {
|
|||
params,
|
||||
recursive,
|
||||
recursive_nonstatic_params,
|
||||
constant: false,
|
||||
variant_name: variant_name.to_string(),
|
||||
func_body: func_body.into(),
|
||||
then: then.into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn define_const(
|
||||
func_name: impl ToString,
|
||||
module_name: impl ToString,
|
||||
func_body: AirTree,
|
||||
then: AirTree,
|
||||
) -> AirTree {
|
||||
AirTree::DefineFunc {
|
||||
func_name: func_name.to_string(),
|
||||
module_name: module_name.to_string(),
|
||||
variant_name: "".to_string(),
|
||||
params: vec![],
|
||||
recursive: false,
|
||||
recursive_nonstatic_params: vec![],
|
||||
constant: true,
|
||||
func_body: func_body.into(),
|
||||
then: then.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn define_cyclic_func(
|
||||
func_name: impl ToString,
|
||||
module_name: impl ToString,
|
||||
|
@ -1158,17 +1182,33 @@ impl AirTree {
|
|||
params,
|
||||
recursive,
|
||||
recursive_nonstatic_params,
|
||||
constant,
|
||||
variant_name,
|
||||
func_body,
|
||||
then,
|
||||
} => {
|
||||
let variant = if *constant {
|
||||
assert!(!recursive);
|
||||
assert!(params.is_empty());
|
||||
assert!(recursive_nonstatic_params.is_empty());
|
||||
|
||||
FunctionVariants::Constant
|
||||
} else if *recursive {
|
||||
FunctionVariants::Recursive {
|
||||
params: params.clone(),
|
||||
recursive_nonstatic_params: recursive_nonstatic_params.clone(),
|
||||
}
|
||||
} else {
|
||||
assert_eq!(params, recursive_nonstatic_params);
|
||||
FunctionVariants::Standard(params.clone())
|
||||
};
|
||||
|
||||
air_vec.push(Air::DefineFunc {
|
||||
func_name: func_name.clone(),
|
||||
module_name: module_name.clone(),
|
||||
params: params.clone(),
|
||||
recursive: *recursive,
|
||||
recursive_nonstatic_params: recursive_nonstatic_params.clone(),
|
||||
|
||||
variant_name: variant_name.clone(),
|
||||
variant,
|
||||
});
|
||||
func_body.create_air_vec(air_vec);
|
||||
then.create_air_vec(air_vec);
|
||||
|
@ -1180,14 +1220,18 @@ impl AirTree {
|
|||
contained_functions,
|
||||
then,
|
||||
} => {
|
||||
air_vec.push(Air::DefineCyclicFuncs {
|
||||
func_name: func_name.clone(),
|
||||
module_name: module_name.clone(),
|
||||
variant_name: variant_name.clone(),
|
||||
contained_functions: contained_functions
|
||||
let variant = FunctionVariants::Cyclic(
|
||||
contained_functions
|
||||
.iter()
|
||||
.map(|(params, _)| params.clone())
|
||||
.collect_vec(),
|
||||
);
|
||||
|
||||
air_vec.push(Air::DefineFunc {
|
||||
func_name: func_name.clone(),
|
||||
module_name: module_name.clone(),
|
||||
variant_name: variant_name.clone(),
|
||||
variant,
|
||||
});
|
||||
|
||||
for (_, func_body) in contained_functions {
|
||||
|
@ -1834,8 +1878,8 @@ impl AirTree {
|
|||
) {
|
||||
tree_path.push(current_depth, field_index);
|
||||
|
||||
// Assignments'/Statements' values get traversed here
|
||||
// Then the body under these assignments/statements get traversed later on
|
||||
// TODO: Merge together the 2 match statements
|
||||
|
||||
match self {
|
||||
AirTree::Let {
|
||||
name: _,
|
||||
|
@ -2104,7 +2148,6 @@ impl AirTree {
|
|||
| AirTree::MultiValidator { .. } => {}
|
||||
}
|
||||
|
||||
// Expressions or an assignment that hoist over a expression are traversed here
|
||||
match self {
|
||||
AirTree::NoOp { then } => {
|
||||
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::FirstField, with);
|
||||
|
@ -2389,16 +2432,17 @@ impl AirTree {
|
|||
recursive: _,
|
||||
recursive_nonstatic_params: _,
|
||||
variant_name: _,
|
||||
constant: _,
|
||||
func_body,
|
||||
then,
|
||||
} => {
|
||||
func_body.do_traverse_tree_with(
|
||||
tree_path,
|
||||
current_depth + 1,
|
||||
Fields::SeventhField,
|
||||
Fields::EighthField,
|
||||
with,
|
||||
);
|
||||
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::EighthField, with)
|
||||
then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::NinthField, with)
|
||||
}
|
||||
AirTree::DefineCyclicFuncs {
|
||||
func_name: _,
|
||||
|
|
Loading…
Reference in New Issue