Add codegen for recursive statics case
We also flip the recursive_statics fields to recursive_nonstatics; This makes the codegen a little easier. It also has a hacky way to hard code in some recursive statics for testing
This commit is contained in:
		
							parent
							
								
									586a2d7972
								
							
						
					
					
						commit
						09f889b121
					
				|  | @ -2690,6 +2690,19 @@ impl<'a> CodeGenerator<'a> { | |||
|             // first grab dependencies
 | ||||
|             let func_params = params; | ||||
| 
 | ||||
|             // HACK: partition params into the "static recursives" and otherwise
 | ||||
|             // for now, we just do this based on the name, but it should be detected
 | ||||
|             // as an optimization pass
 | ||||
|             let recursive_static_indexes: Vec<usize> = func_params.iter().enumerate().filter(|(idx, p)| { | ||||
|                 p.starts_with("pi_recursive_hack_") | ||||
|             }).map(|(idx, _)| idx).collect(); | ||||
|             let recursive_nonstatics: Vec<String> = func_params.iter().cloned().filter(|p| { | ||||
|                 !p.starts_with("pi_recursive_hack_") | ||||
|             }).collect(); | ||||
| 
 | ||||
|             println!("~~ recursive_nonstatics: {:?}", recursive_nonstatics); | ||||
|             println!("~~ func_params: {:?}", func_params); | ||||
| 
 | ||||
|             let params_empty = func_params.is_empty(); | ||||
| 
 | ||||
|             let deps = (tree_path, func_deps.clone()); | ||||
|  | @ -2697,7 +2710,7 @@ impl<'a> CodeGenerator<'a> { | |||
|             if !params_empty { | ||||
|                 if is_recursive { | ||||
|                     body.traverse_tree_with(&mut |air_tree: &mut AirTree, _| { | ||||
|                         modify_self_calls(air_tree, key, variant); | ||||
|                         modify_self_calls(air_tree, key, variant, &recursive_static_indexes); | ||||
|                     }); | ||||
|                 } | ||||
| 
 | ||||
|  | @ -2707,7 +2720,7 @@ impl<'a> CodeGenerator<'a> { | |||
|                     variant, | ||||
|                     func_params.clone(), | ||||
|                     is_recursive, | ||||
|                     vec![], | ||||
|                     recursive_nonstatics, | ||||
|                     body, | ||||
|                 ); | ||||
| 
 | ||||
|  | @ -2832,7 +2845,7 @@ impl<'a> CodeGenerator<'a> { | |||
| 
 | ||||
|                 if is_dependent_recursive { | ||||
|                     dep_air_tree.traverse_tree_with(&mut |air_tree: &mut AirTree, _| { | ||||
|                         modify_self_calls(air_tree, &dep_key, &dep_variant); | ||||
|                         modify_self_calls(air_tree, &dep_key, &dep_variant, &vec![]); | ||||
|                     }); | ||||
|                 } | ||||
| 
 | ||||
|  | @ -2840,9 +2853,9 @@ impl<'a> CodeGenerator<'a> { | |||
|                     &dep_key.function_name, | ||||
|                     &dep_key.module_name, | ||||
|                     &dep_variant, | ||||
|                     dependent_params, | ||||
|                     dependent_params.clone(), | ||||
|                     is_dependent_recursive, | ||||
|                     vec![], | ||||
|                     dependent_params, | ||||
|                     dep_air_tree, | ||||
|                 )); | ||||
| 
 | ||||
|  | @ -3706,7 +3719,7 @@ impl<'a> CodeGenerator<'a> { | |||
|                 func_name, | ||||
|                 params, | ||||
|                 recursive, | ||||
|                 recursive_static_params, | ||||
|                 recursive_nonstatic_params, | ||||
|                 module_name, | ||||
|                 variant_name, | ||||
|             } => { | ||||
|  | @ -3719,7 +3732,11 @@ impl<'a> CodeGenerator<'a> { | |||
| 
 | ||||
|                 let mut term = arg_stack.pop().unwrap(); | ||||
| 
 | ||||
|                 for param in params.iter().rev() { | ||||
|                 // Introduce a parameter for each parameter
 | ||||
|                 // NOTE: we use recursive_nonstatic_params here because
 | ||||
|                 // if this is recursive, those are the ones that need to be passed
 | ||||
|                 // each time
 | ||||
|                 for param in recursive_nonstatic_params.iter().rev() { | ||||
|                     func_body = func_body.lambda(param.clone()); | ||||
|                 } | ||||
| 
 | ||||
|  | @ -3730,13 +3747,31 @@ impl<'a> CodeGenerator<'a> { | |||
|                 } else { | ||||
|                     func_body = func_body.lambda(func_name.clone()); | ||||
| 
 | ||||
|                     if recursive_nonstatic_params == params { | ||||
|                         // If we don't have any recursive-static params, we can just emit the function as is
 | ||||
|                         term = term | ||||
|                             .lambda(func_name.clone()) | ||||
|                             .apply(Term::var(func_name.clone()).apply(Term::var(func_name.clone()))) | ||||
|                             .lambda(func_name) | ||||
|                             .apply(func_body); | ||||
|                     } else { | ||||
|                         // If we have parameters that remain static in each recursive call,
 | ||||
|                         // we can construct an *outer* function to take those in
 | ||||
|                         // and simplify the recursive part to only accept the non-static arguments
 | ||||
|                         let mut outer_func_body = Term::var(&func_name).apply(Term::var(&func_name)); | ||||
|                         for param in recursive_nonstatic_params.iter() { | ||||
|                             outer_func_body = outer_func_body.apply(Term::var(param)); | ||||
|                         } | ||||
| 
 | ||||
|                     // TODO: use recursive_static_params here
 | ||||
|                         outer_func_body = outer_func_body.lambda(&func_name).apply(func_body); | ||||
| 
 | ||||
|                         // Now, add *all* parameters, so that other call sites don't know the difference
 | ||||
|                         for param in params.iter().rev() { | ||||
|                             outer_func_body = outer_func_body.lambda(param); | ||||
|                         } | ||||
| 
 | ||||
|                         term = term.lambda(&func_name).apply(outer_func_body); | ||||
|                     } | ||||
| 
 | ||||
|                     arg_stack.push(term); | ||||
|                 } | ||||
|  |  | |||
|  | @ -47,7 +47,7 @@ pub enum Air { | |||
|         module_name: String, | ||||
|         params: Vec<String>, | ||||
|         recursive: bool, | ||||
|         recursive_static_params: Vec<String>, | ||||
|         recursive_nonstatic_params: Vec<String>, | ||||
|         variant_name: String, | ||||
|     }, | ||||
|     Fn { | ||||
|  |  | |||
|  | @ -583,7 +583,7 @@ pub fn erase_opaque_type_operations( | |||
|     } | ||||
| } | ||||
| 
 | ||||
| pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, variant: &String) { | ||||
| pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, variant: &String, static_recursive_params: &Vec<usize>) { | ||||
|     if let AirTree::Expression(AirExpression::Call { func, args, .. }) = air_tree { | ||||
|         if let AirTree::Expression(AirExpression::Var { | ||||
|             constructor: | ||||
|  | @ -599,6 +599,11 @@ pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, v | |||
|                 && module == &func_key.module_name | ||||
|                 && variant == variant_name | ||||
|             { | ||||
|                 // Remove any static-recursive-parameters, because they'll be bound statically
 | ||||
|                 // above the recursive part of the function
 | ||||
|                 for arg in static_recursive_params.iter().rev() { | ||||
|                     args.remove(*arg); | ||||
|                 } | ||||
|                 let mut new_args = vec![func.as_ref().clone()]; | ||||
|                 new_args.append(args); | ||||
|                 *args = new_args; | ||||
|  |  | |||
|  | @ -129,7 +129,7 @@ pub enum AirStatement { | |||
|         module_name: String, | ||||
|         params: Vec<String>, | ||||
|         recursive: bool, | ||||
|         recursive_static_params: Vec<String>, | ||||
|         recursive_nonstatic_params: Vec<String>, | ||||
|         variant_name: String, | ||||
|         func_body: Box<AirTree>, | ||||
|     }, | ||||
|  | @ -424,7 +424,7 @@ impl AirTree { | |||
|         variant_name: impl ToString, | ||||
|         params: Vec<String>, | ||||
|         recursive: bool, | ||||
|         recursive_static_params: Vec<String>, | ||||
|         recursive_nonstatic_params: Vec<String>, | ||||
|         func_body: AirTree, | ||||
|     ) -> AirTree { | ||||
|         AirTree::Statement { | ||||
|  | @ -433,7 +433,7 @@ impl AirTree { | |||
|                 module_name: module_name.to_string(), | ||||
|                 params, | ||||
|                 recursive, | ||||
|                 recursive_static_params, | ||||
|                 recursive_nonstatic_params, | ||||
|                 variant_name: variant_name.to_string(), | ||||
|                 func_body: func_body.into(), | ||||
|             }, | ||||
|  | @ -878,7 +878,7 @@ impl AirTree { | |||
|                         module_name, | ||||
|                         params, | ||||
|                         recursive, | ||||
|                         recursive_static_params, | ||||
|                         recursive_nonstatic_params, | ||||
|                         variant_name, | ||||
|                         func_body, | ||||
|                     } => { | ||||
|  | @ -887,7 +887,7 @@ impl AirTree { | |||
|                             module_name: module_name.clone(), | ||||
|                             params: params.clone(), | ||||
|                             recursive: *recursive, | ||||
|                             recursive_static_params: recursive_static_params.clone(), | ||||
|                             recursive_nonstatic_params: recursive_nonstatic_params.clone(), | ||||
|                             variant_name: variant_name.clone(), | ||||
|                         }); | ||||
|                         func_body.create_air_vec(air_vec); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Pi Lanningham
						Pi Lanningham