Add codegen for recursive statics case
We also flip the recursive_statics fields to recursive_nonstatics; This makes the codegen a little easier. It also has a hacky way to hard code in some recursive statics for testing
This commit is contained in:
		
							parent
							
								
									586a2d7972
								
							
						
					
					
						commit
						09f889b121
					
				|  | @ -2690,6 +2690,19 @@ impl<'a> CodeGenerator<'a> { | ||||||
|             // first grab dependencies
 |             // first grab dependencies
 | ||||||
|             let func_params = params; |             let func_params = params; | ||||||
| 
 | 
 | ||||||
|  |             // HACK: partition params into the "static recursives" and otherwise
 | ||||||
|  |             // for now, we just do this based on the name, but it should be detected
 | ||||||
|  |             // as an optimization pass
 | ||||||
|  |             let recursive_static_indexes: Vec<usize> = func_params.iter().enumerate().filter(|(idx, p)| { | ||||||
|  |                 p.starts_with("pi_recursive_hack_") | ||||||
|  |             }).map(|(idx, _)| idx).collect(); | ||||||
|  |             let recursive_nonstatics: Vec<String> = func_params.iter().cloned().filter(|p| { | ||||||
|  |                 !p.starts_with("pi_recursive_hack_") | ||||||
|  |             }).collect(); | ||||||
|  | 
 | ||||||
|  |             println!("~~ recursive_nonstatics: {:?}", recursive_nonstatics); | ||||||
|  |             println!("~~ func_params: {:?}", func_params); | ||||||
|  | 
 | ||||||
|             let params_empty = func_params.is_empty(); |             let params_empty = func_params.is_empty(); | ||||||
| 
 | 
 | ||||||
|             let deps = (tree_path, func_deps.clone()); |             let deps = (tree_path, func_deps.clone()); | ||||||
|  | @ -2697,7 +2710,7 @@ impl<'a> CodeGenerator<'a> { | ||||||
|             if !params_empty { |             if !params_empty { | ||||||
|                 if is_recursive { |                 if is_recursive { | ||||||
|                     body.traverse_tree_with(&mut |air_tree: &mut AirTree, _| { |                     body.traverse_tree_with(&mut |air_tree: &mut AirTree, _| { | ||||||
|                         modify_self_calls(air_tree, key, variant); |                         modify_self_calls(air_tree, key, variant, &recursive_static_indexes); | ||||||
|                     }); |                     }); | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
|  | @ -2707,7 +2720,7 @@ impl<'a> CodeGenerator<'a> { | ||||||
|                     variant, |                     variant, | ||||||
|                     func_params.clone(), |                     func_params.clone(), | ||||||
|                     is_recursive, |                     is_recursive, | ||||||
|                     vec![], |                     recursive_nonstatics, | ||||||
|                     body, |                     body, | ||||||
|                 ); |                 ); | ||||||
| 
 | 
 | ||||||
|  | @ -2832,7 +2845,7 @@ impl<'a> CodeGenerator<'a> { | ||||||
| 
 | 
 | ||||||
|                 if is_dependent_recursive { |                 if is_dependent_recursive { | ||||||
|                     dep_air_tree.traverse_tree_with(&mut |air_tree: &mut AirTree, _| { |                     dep_air_tree.traverse_tree_with(&mut |air_tree: &mut AirTree, _| { | ||||||
|                         modify_self_calls(air_tree, &dep_key, &dep_variant); |                         modify_self_calls(air_tree, &dep_key, &dep_variant, &vec![]); | ||||||
|                     }); |                     }); | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
|  | @ -2840,9 +2853,9 @@ impl<'a> CodeGenerator<'a> { | ||||||
|                     &dep_key.function_name, |                     &dep_key.function_name, | ||||||
|                     &dep_key.module_name, |                     &dep_key.module_name, | ||||||
|                     &dep_variant, |                     &dep_variant, | ||||||
|                     dependent_params, |                     dependent_params.clone(), | ||||||
|                     is_dependent_recursive, |                     is_dependent_recursive, | ||||||
|                     vec![], |                     dependent_params, | ||||||
|                     dep_air_tree, |                     dep_air_tree, | ||||||
|                 )); |                 )); | ||||||
| 
 | 
 | ||||||
|  | @ -3706,7 +3719,7 @@ impl<'a> CodeGenerator<'a> { | ||||||
|                 func_name, |                 func_name, | ||||||
|                 params, |                 params, | ||||||
|                 recursive, |                 recursive, | ||||||
|                 recursive_static_params, |                 recursive_nonstatic_params, | ||||||
|                 module_name, |                 module_name, | ||||||
|                 variant_name, |                 variant_name, | ||||||
|             } => { |             } => { | ||||||
|  | @ -3719,7 +3732,11 @@ impl<'a> CodeGenerator<'a> { | ||||||
| 
 | 
 | ||||||
|                 let mut term = arg_stack.pop().unwrap(); |                 let mut term = arg_stack.pop().unwrap(); | ||||||
| 
 | 
 | ||||||
|                 for param in params.iter().rev() { |                 // Introduce a parameter for each parameter
 | ||||||
|  |                 // NOTE: we use recursive_nonstatic_params here because
 | ||||||
|  |                 // if this is recursive, those are the ones that need to be passed
 | ||||||
|  |                 // each time
 | ||||||
|  |                 for param in recursive_nonstatic_params.iter().rev() { | ||||||
|                     func_body = func_body.lambda(param.clone()); |                     func_body = func_body.lambda(param.clone()); | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
|  | @ -3730,13 +3747,31 @@ impl<'a> CodeGenerator<'a> { | ||||||
|                 } else { |                 } else { | ||||||
|                     func_body = func_body.lambda(func_name.clone()); |                     func_body = func_body.lambda(func_name.clone()); | ||||||
| 
 | 
 | ||||||
|                     term = term |                     if recursive_nonstatic_params == params { | ||||||
|                         .lambda(func_name.clone()) |                         // If we don't have any recursive-static params, we can just emit the function as is
 | ||||||
|                         .apply(Term::var(func_name.clone()).apply(Term::var(func_name.clone()))) |                         term = term | ||||||
|                         .lambda(func_name) |                             .lambda(func_name.clone()) | ||||||
|                         .apply(func_body); |                             .apply(Term::var(func_name.clone()).apply(Term::var(func_name.clone()))) | ||||||
|  |                             .lambda(func_name) | ||||||
|  |                             .apply(func_body); | ||||||
|  |                     } else { | ||||||
|  |                         // If we have parameters that remain static in each recursive call,
 | ||||||
|  |                         // we can construct an *outer* function to take those in
 | ||||||
|  |                         // and simplify the recursive part to only accept the non-static arguments
 | ||||||
|  |                         let mut outer_func_body = Term::var(&func_name).apply(Term::var(&func_name)); | ||||||
|  |                         for param in recursive_nonstatic_params.iter() { | ||||||
|  |                             outer_func_body = outer_func_body.apply(Term::var(param)); | ||||||
|  |                         } | ||||||
| 
 | 
 | ||||||
|                     // TODO: use recursive_static_params here
 |                         outer_func_body = outer_func_body.lambda(&func_name).apply(func_body); | ||||||
|  | 
 | ||||||
|  |                         // Now, add *all* parameters, so that other call sites don't know the difference
 | ||||||
|  |                         for param in params.iter().rev() { | ||||||
|  |                             outer_func_body = outer_func_body.lambda(param); | ||||||
|  |                         } | ||||||
|  | 
 | ||||||
|  |                         term = term.lambda(&func_name).apply(outer_func_body); | ||||||
|  |                     } | ||||||
| 
 | 
 | ||||||
|                     arg_stack.push(term); |                     arg_stack.push(term); | ||||||
|                 } |                 } | ||||||
|  |  | ||||||
|  | @ -47,7 +47,7 @@ pub enum Air { | ||||||
|         module_name: String, |         module_name: String, | ||||||
|         params: Vec<String>, |         params: Vec<String>, | ||||||
|         recursive: bool, |         recursive: bool, | ||||||
|         recursive_static_params: Vec<String>, |         recursive_nonstatic_params: Vec<String>, | ||||||
|         variant_name: String, |         variant_name: String, | ||||||
|     }, |     }, | ||||||
|     Fn { |     Fn { | ||||||
|  |  | ||||||
|  | @ -583,7 +583,7 @@ pub fn erase_opaque_type_operations( | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, variant: &String) { | pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, variant: &String, static_recursive_params: &Vec<usize>) { | ||||||
|     if let AirTree::Expression(AirExpression::Call { func, args, .. }) = air_tree { |     if let AirTree::Expression(AirExpression::Call { func, args, .. }) = air_tree { | ||||||
|         if let AirTree::Expression(AirExpression::Var { |         if let AirTree::Expression(AirExpression::Var { | ||||||
|             constructor: |             constructor: | ||||||
|  | @ -599,6 +599,11 @@ pub fn modify_self_calls(air_tree: &mut AirTree, func_key: &FunctionAccessKey, v | ||||||
|                 && module == &func_key.module_name |                 && module == &func_key.module_name | ||||||
|                 && variant == variant_name |                 && variant == variant_name | ||||||
|             { |             { | ||||||
|  |                 // Remove any static-recursive-parameters, because they'll be bound statically
 | ||||||
|  |                 // above the recursive part of the function
 | ||||||
|  |                 for arg in static_recursive_params.iter().rev() { | ||||||
|  |                     args.remove(*arg); | ||||||
|  |                 } | ||||||
|                 let mut new_args = vec![func.as_ref().clone()]; |                 let mut new_args = vec![func.as_ref().clone()]; | ||||||
|                 new_args.append(args); |                 new_args.append(args); | ||||||
|                 *args = new_args; |                 *args = new_args; | ||||||
|  |  | ||||||
|  | @ -129,7 +129,7 @@ pub enum AirStatement { | ||||||
|         module_name: String, |         module_name: String, | ||||||
|         params: Vec<String>, |         params: Vec<String>, | ||||||
|         recursive: bool, |         recursive: bool, | ||||||
|         recursive_static_params: Vec<String>, |         recursive_nonstatic_params: Vec<String>, | ||||||
|         variant_name: String, |         variant_name: String, | ||||||
|         func_body: Box<AirTree>, |         func_body: Box<AirTree>, | ||||||
|     }, |     }, | ||||||
|  | @ -424,7 +424,7 @@ impl AirTree { | ||||||
|         variant_name: impl ToString, |         variant_name: impl ToString, | ||||||
|         params: Vec<String>, |         params: Vec<String>, | ||||||
|         recursive: bool, |         recursive: bool, | ||||||
|         recursive_static_params: Vec<String>, |         recursive_nonstatic_params: Vec<String>, | ||||||
|         func_body: AirTree, |         func_body: AirTree, | ||||||
|     ) -> AirTree { |     ) -> AirTree { | ||||||
|         AirTree::Statement { |         AirTree::Statement { | ||||||
|  | @ -433,7 +433,7 @@ impl AirTree { | ||||||
|                 module_name: module_name.to_string(), |                 module_name: module_name.to_string(), | ||||||
|                 params, |                 params, | ||||||
|                 recursive, |                 recursive, | ||||||
|                 recursive_static_params, |                 recursive_nonstatic_params, | ||||||
|                 variant_name: variant_name.to_string(), |                 variant_name: variant_name.to_string(), | ||||||
|                 func_body: func_body.into(), |                 func_body: func_body.into(), | ||||||
|             }, |             }, | ||||||
|  | @ -878,7 +878,7 @@ impl AirTree { | ||||||
|                         module_name, |                         module_name, | ||||||
|                         params, |                         params, | ||||||
|                         recursive, |                         recursive, | ||||||
|                         recursive_static_params, |                         recursive_nonstatic_params, | ||||||
|                         variant_name, |                         variant_name, | ||||||
|                         func_body, |                         func_body, | ||||||
|                     } => { |                     } => { | ||||||
|  | @ -887,7 +887,7 @@ impl AirTree { | ||||||
|                             module_name: module_name.clone(), |                             module_name: module_name.clone(), | ||||||
|                             params: params.clone(), |                             params: params.clone(), | ||||||
|                             recursive: *recursive, |                             recursive: *recursive, | ||||||
|                             recursive_static_params: recursive_static_params.clone(), |                             recursive_nonstatic_params: recursive_nonstatic_params.clone(), | ||||||
|                             variant_name: variant_name.clone(), |                             variant_name: variant_name.clone(), | ||||||
|                         }); |                         }); | ||||||
|                         func_body.create_air_vec(air_vec); |                         func_body.create_air_vec(air_vec); | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 Pi Lanningham
						Pi Lanningham