From 2bbc699a25ccb0816a8d6d3d1f19857c9643a5a9 Mon Sep 17 00:00:00 2001 From: microproofs Date: Fri, 20 Sep 2024 00:39:27 -0400 Subject: [PATCH] Refactor some code to make a define_const AirTree function --- crates/aiken-lang/src/gen_uplc.rs | 223 +++++++++++++------------ crates/aiken-lang/src/gen_uplc/air.rs | 22 +-- crates/aiken-lang/src/gen_uplc/tree.rs | 74 ++++++-- 3 files changed, 190 insertions(+), 129 deletions(-) diff --git a/crates/aiken-lang/src/gen_uplc.rs b/crates/aiken-lang/src/gen_uplc.rs index ba31d1bd..4551ea61 100644 --- a/crates/aiken-lang/src/gen_uplc.rs +++ b/crates/aiken-lang/src/gen_uplc.rs @@ -4988,134 +4988,149 @@ impl<'a> CodeGenerator<'a> { } Air::DefineFunc { func_name, - params, - recursive, - recursive_nonstatic_params, module_name, variant_name, + variant, } => { let func_name = if module_name.is_empty() { format!("{func_name}{variant_name}") } else { format!("{module_name}_{func_name}{variant_name}") }; - let mut func_body = arg_stack.pop().unwrap(); - let mut term = arg_stack.pop().unwrap(); + match variant { + air::FunctionVariants::Standard(params) => { + let mut func_body = arg_stack.pop().unwrap(); - // Introduce a parameter for each parameter - // NOTE: we use recursive_nonstatic_params here because - // if this is recursive, those are the ones that need to be passed - // each time - for param in recursive_nonstatic_params.iter().rev() { - func_body = func_body.lambda(param.clone()); - } + let term = arg_stack.pop().unwrap(); - if recursive_nonstatic_params.is_empty() || params.is_empty() { - func_body = func_body.delay(); - } - - if !recursive { - term = term.lambda(func_name).apply(func_body.lambda(NO_INLINE)); - - Some(term) - } else { - func_body = func_body.lambda(func_name.clone()); - - if recursive_nonstatic_params == params { - // If we don't have any recursive-static params, we can just emit the function as is - term = term - .lambda(func_name.clone()) - .apply(Term::var(func_name.clone()).apply(Term::var(func_name.clone()))) - .lambda(func_name) - .apply(func_body.lambda(NO_INLINE)); - } else { - // If we have parameters that remain static in each recursive call, - // we can construct an *outer* function to take those in - // and simplify the recursive part to only accept the non-static arguments - let mut recursive_func_body = - Term::var(&func_name).apply(Term::var(&func_name)); - for param in recursive_nonstatic_params.iter() { - recursive_func_body = recursive_func_body.apply(Term::var(param)); + if params.is_empty() { + func_body = func_body.delay(); } - if recursive_nonstatic_params.is_empty() { - recursive_func_body = recursive_func_body.force(); + let func_body = params + .into_iter() + .fold(func_body, |term, arg| term.lambda(arg)) + .lambda(NO_INLINE); + + Some(term.lambda(func_name).apply(func_body)) + } + air::FunctionVariants::Recursive { + params, + recursive_nonstatic_params, + } => { + let mut func_body = arg_stack.pop().unwrap(); + + let term = arg_stack.pop().unwrap(); + + let no_statics = recursive_nonstatic_params == params; + + if recursive_nonstatic_params.is_empty() || params.is_empty() { + func_body = func_body.delay(); } - // Then construct an outer function with *all* parameters, not just the nonstatic ones. - let mut outer_func_body = - recursive_func_body.lambda(&func_name).apply(func_body); + let func_body = recursive_nonstatic_params + .iter() + .rfold(func_body, |term, arg| term.lambda(arg)); - // Now, add *all* parameters, so that other call sites don't know the difference - for param in params.iter().rev() { - outer_func_body = outer_func_body.lambda(param); + let func_body = func_body.lambda(func_name.clone()); + + if no_statics { + // If we don't have any recursive-static params, we can just emit the function as is + Some( + term.lambda(func_name.clone()) + .apply( + Term::var(func_name.clone()) + .apply(Term::var(func_name.clone())), + ) + .lambda(func_name) + .apply(func_body.lambda(NO_INLINE)), + ) + } else { + // If we have parameters that remain static in each recursive call, + // we can construct an *outer* function to take those in + // and simplify the recursive part to only accept the non-static arguments + let mut recursive_func_body = + Term::var(&func_name).apply(Term::var(&func_name)); + + if recursive_nonstatic_params.is_empty() { + recursive_func_body = recursive_func_body.force(); + } + + // Introduce a parameter for each parameter + // NOTE: we use recursive_nonstatic_params here because + // if this is recursive, those are the ones that need to be passed + // each time + for param in recursive_nonstatic_params.into_iter() { + recursive_func_body = recursive_func_body.apply(Term::var(param)); + } + + // Then construct an outer function with *all* parameters, not just the nonstatic ones. + let mut outer_func_body = + recursive_func_body.lambda(&func_name).apply(func_body); + + // Now, add *all* parameters, so that other call sites don't know the difference + outer_func_body = params + .clone() + .into_iter() + .rfold(outer_func_body, |term, arg| term.lambda(arg)); + + // And finally, fold that definition into the rest of our program + Some( + term.lambda(&func_name) + .apply(outer_func_body.lambda(NO_INLINE)), + ) + } + } + air::FunctionVariants::Constant => todo!(), + air::FunctionVariants::Cyclic(contained_functions) => { + let mut cyclic_functions = vec![]; + + for params in contained_functions { + let func_body = arg_stack.pop().unwrap(); + + cyclic_functions.push((params, func_body)); + } + let mut term = arg_stack.pop().unwrap(); + + let mut cyclic_body = Term::var("__chooser"); + + for (params, func_body) in cyclic_functions.into_iter() { + let mut function = func_body; + if params.is_empty() { + function = function.delay(); + } else { + for param in params.iter().rev() { + function = function.lambda(param); + } + } + + // We basically Scott encode our function bodies and use the chooser function + // to determine which function body and params is run + // For example say there is a cycle of 3 function bodies + // Our choose function can look like this: + // \func1 -> \func2 -> \func3 -> func1 + // In this case our chooser is a function that takes in 3 functions + // and returns the first one to run + cyclic_body = cyclic_body.apply(function) } - // And finally, fold that definition into the rest of our program term = term .lambda(&func_name) - .apply(outer_func_body.lambda(NO_INLINE)); - } + .apply(Term::var(&func_name).apply(Term::var(&func_name))) + .lambda(&func_name) + .apply( + cyclic_body + .lambda("__chooser") + .lambda(func_name) + .lambda(NO_INLINE), + ); - Some(term) + Some(term) + } } } - Air::DefineCyclicFuncs { - func_name, - module_name, - variant_name, - contained_functions, - } => { - let func_name = if module_name.is_empty() { - format!("{func_name}{variant_name}") - } else { - format!("{module_name}_{func_name}{variant_name}") - }; - let mut cyclic_functions = vec![]; - for params in contained_functions { - let func_body = arg_stack.pop().unwrap(); - - cyclic_functions.push((params, func_body)); - } - let mut term = arg_stack.pop().unwrap(); - - let mut cyclic_body = Term::var("__chooser"); - - for (params, func_body) in cyclic_functions.into_iter() { - let mut function = func_body; - if params.is_empty() { - function = function.delay(); - } else { - for param in params.iter().rev() { - function = function.lambda(param); - } - } - - // We basically Scott encode our function bodies and use the chooser function - // to determine which function body and params is run - // For example say there is a cycle of 3 function bodies - // Our choose function can look like this: - // \func1 -> \func2 -> \func3 -> func1 - // In this case our chooser is a function that takes in 3 functions - // and returns the first one to run - cyclic_body = cyclic_body.apply(function) - } - - term = term - .lambda(&func_name) - .apply(Term::var(&func_name).apply(Term::var(&func_name))) - .lambda(&func_name) - .apply( - cyclic_body - .lambda("__chooser") - .lambda(func_name) - .lambda(NO_INLINE), - ); - - Some(term) - } Air::Let { name } => { let arg = arg_stack.pop().unwrap(); diff --git a/crates/aiken-lang/src/gen_uplc/air.rs b/crates/aiken-lang/src/gen_uplc/air.rs index 981ad76e..6845d8d0 100644 --- a/crates/aiken-lang/src/gen_uplc/air.rs +++ b/crates/aiken-lang/src/gen_uplc/air.rs @@ -23,6 +23,17 @@ impl From for ExpectLevel { } } +#[derive(Debug, Clone, PartialEq)] +pub enum FunctionVariants { + Standard(Vec), + Recursive { + params: Vec, + recursive_nonstatic_params: Vec, + }, + Cyclic(Vec>), + Constant, +} + #[derive(Debug, Clone, PartialEq)] pub enum Air { // Primitives @@ -67,17 +78,8 @@ pub enum Air { DefineFunc { func_name: String, module_name: String, - params: Vec, - recursive: bool, - recursive_nonstatic_params: Vec, variant_name: String, - }, - DefineCyclicFuncs { - func_name: String, - module_name: String, - variant_name: String, - // just the params - contained_functions: Vec>, + variant: FunctionVariants, }, Fn { params: Vec, diff --git a/crates/aiken-lang/src/gen_uplc/tree.rs b/crates/aiken-lang/src/gen_uplc/tree.rs index 6e0183e3..5342ef42 100644 --- a/crates/aiken-lang/src/gen_uplc/tree.rs +++ b/crates/aiken-lang/src/gen_uplc/tree.rs @@ -1,4 +1,4 @@ -use super::air::{Air, ExpectLevel}; +use super::air::{Air, ExpectLevel, FunctionVariants}; use crate::{ ast::{BinOp, Curve, Span, UnOp}, tipo::{Type, ValueConstructor, ValueConstructorVariant}, @@ -21,6 +21,7 @@ pub enum Fields { SixthField, SeventhField, EighthField, + NinthField, ArgsField(usize), } @@ -136,10 +137,12 @@ pub enum AirTree { DefineFunc { func_name: String, module_name: String, + variant_name: String, + //params and other parts of a function params: Vec, recursive: bool, recursive_nonstatic_params: Vec, - variant_name: String, + constant: bool, func_body: Box, then: Box, }, @@ -531,12 +534,33 @@ impl AirTree { params, recursive, recursive_nonstatic_params, + constant: false, variant_name: variant_name.to_string(), func_body: func_body.into(), then: then.into(), } } + #[allow(clippy::too_many_arguments)] + pub fn define_const( + func_name: impl ToString, + module_name: impl ToString, + func_body: AirTree, + then: AirTree, + ) -> AirTree { + AirTree::DefineFunc { + func_name: func_name.to_string(), + module_name: module_name.to_string(), + variant_name: "".to_string(), + params: vec![], + recursive: false, + recursive_nonstatic_params: vec![], + constant: true, + func_body: func_body.into(), + then: then.into(), + } + } + pub fn define_cyclic_func( func_name: impl ToString, module_name: impl ToString, @@ -1158,17 +1182,33 @@ impl AirTree { params, recursive, recursive_nonstatic_params, + constant, variant_name, func_body, then, } => { + let variant = if *constant { + assert!(!recursive); + assert!(params.is_empty()); + assert!(recursive_nonstatic_params.is_empty()); + + FunctionVariants::Constant + } else if *recursive { + FunctionVariants::Recursive { + params: params.clone(), + recursive_nonstatic_params: recursive_nonstatic_params.clone(), + } + } else { + assert_eq!(params, recursive_nonstatic_params); + FunctionVariants::Standard(params.clone()) + }; + air_vec.push(Air::DefineFunc { func_name: func_name.clone(), module_name: module_name.clone(), - params: params.clone(), - recursive: *recursive, - recursive_nonstatic_params: recursive_nonstatic_params.clone(), + variant_name: variant_name.clone(), + variant, }); func_body.create_air_vec(air_vec); then.create_air_vec(air_vec); @@ -1180,14 +1220,18 @@ impl AirTree { contained_functions, then, } => { - air_vec.push(Air::DefineCyclicFuncs { - func_name: func_name.clone(), - module_name: module_name.clone(), - variant_name: variant_name.clone(), - contained_functions: contained_functions + let variant = FunctionVariants::Cyclic( + contained_functions .iter() .map(|(params, _)| params.clone()) .collect_vec(), + ); + + air_vec.push(Air::DefineFunc { + func_name: func_name.clone(), + module_name: module_name.clone(), + variant_name: variant_name.clone(), + variant, }); for (_, func_body) in contained_functions { @@ -1834,8 +1878,8 @@ impl AirTree { ) { tree_path.push(current_depth, field_index); - // Assignments'/Statements' values get traversed here - // Then the body under these assignments/statements get traversed later on + // TODO: Merge together the 2 match statements + match self { AirTree::Let { name: _, @@ -2104,7 +2148,6 @@ impl AirTree { | AirTree::MultiValidator { .. } => {} } - // Expressions or an assignment that hoist over a expression are traversed here match self { AirTree::NoOp { then } => { then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::FirstField, with); @@ -2389,16 +2432,17 @@ impl AirTree { recursive: _, recursive_nonstatic_params: _, variant_name: _, + constant: _, func_body, then, } => { func_body.do_traverse_tree_with( tree_path, current_depth + 1, - Fields::SeventhField, + Fields::EighthField, with, ); - then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::EighthField, with) + then.do_traverse_tree_with(tree_path, current_depth + 1, Fields::NinthField, with) } AirTree::DefineCyclicFuncs { func_name: _,