Code gen now handles expecting on validator args in the air stack.

Thus allowing us to use code gen created functions to expect on data types including recursive ones.
Some minor tweaks to the air.
Added a uplc optimization for later.
This commit is contained in:
Kasey White 2023-04-07 02:25:18 -04:00 committed by Kasey
parent 4e4eed13e1
commit f8483da4e0
5 changed files with 147 additions and 94 deletions

View File

@ -97,7 +97,13 @@ impl<'a> CodeGenerator<'a> {
) -> Program<Name> {
let mut ir_stack = AirStack::new(self.id_gen.clone());
ir_stack.validator(fun.arguments.clone());
ir_stack.noop();
let mut args_stack = ir_stack.empty_with_scope();
self.wrap_validator_args(&mut args_stack, &fun.arguments, true);
ir_stack.merge_child(args_stack);
self.build(&fun.body, &mut ir_stack);
@ -113,7 +119,14 @@ impl<'a> CodeGenerator<'a> {
self.reset();
let mut other_ir_stack = AirStack::new(self.id_gen.clone());
other_ir_stack.validator(other.arguments.clone());
other_ir_stack.noop();
let mut other_args_stack = other_ir_stack.empty_with_scope();
self.wrap_validator_args(&mut other_args_stack, &other.arguments, true);
other_ir_stack.merge_child(other_args_stack);
self.build(&other.body, &mut other_ir_stack);
@ -132,6 +145,7 @@ impl<'a> CodeGenerator<'a> {
};
term = builder::wrap_as_multi_validator(spend, mint);
self.needs_field_access = true;
}
@ -143,7 +157,7 @@ impl<'a> CodeGenerator<'a> {
pub fn generate_test(&mut self, test_body: &TypedExpr) -> Program<Name> {
let mut ir_stack = AirStack::new(self.id_gen.clone());
ir_stack.validator(vec![]);
ir_stack.noop();
self.build(test_body, &mut ir_stack);
@ -2287,10 +2301,7 @@ impl<'a> CodeGenerator<'a> {
expect_list_stack.expect_on_list();
self.code_gen_functions.insert(
EXPECT_ON_LIST.to_string(),
CodeGenFunction::Function(
expect_list_stack.complete(),
vec!["__list_to_check".to_string(), "__check_with".to_string()],
),
CodeGenFunction::Function(expect_list_stack.complete(), vec![]),
);
}
@ -2329,10 +2340,7 @@ impl<'a> CodeGenerator<'a> {
expect_list_stack.expect_on_list();
self.code_gen_functions.insert(
EXPECT_ON_LIST.to_string(),
CodeGenFunction::Function(
expect_list_stack.complete(),
vec!["__list_to_check".to_string(), "__check_with".to_string()],
),
CodeGenFunction::Function(expect_list_stack.complete(), vec![]),
);
}
@ -2386,6 +2394,9 @@ impl<'a> CodeGenerator<'a> {
if function.is_none() && defined_data_types.get(&data_type_name).is_none() {
defined_data_types.insert(data_type_name.clone(), 1);
let current_defined_state = defined_data_types.clone();
let mut diff_defined_types = IndexMap::new();
let mut clause_stack = expect_stack.empty_with_scope();
let mut when_stack = expect_stack.empty_with_scope();
let mut trace_stack = expect_stack.empty_with_scope();
@ -2433,17 +2444,20 @@ impl<'a> CodeGenerator<'a> {
for (_index, name, tipo) in arg_indices.clone() {
let mut call_stack = expect_stack.empty_with_scope();
let mut inner_defined_types = IndexMap::new();
self.expect_type(&tipo, &mut call_stack, &name, &mut inner_defined_types);
self.expect_type(&tipo, &mut call_stack, &name, defined_data_types);
arg_stack.merge_child(call_stack);
for (inner_data_type, inner_count) in inner_defined_types {
if let Some(count) = defined_data_types.get_mut(&inner_data_type) {
*count += inner_count
for (inner_data_type, inner_count) in defined_data_types.iter() {
if let Some(prev_count) = current_defined_state.get(inner_data_type) {
diff_defined_types.insert(
inner_data_type.to_string(),
*inner_count - *prev_count,
);
} else {
defined_data_types.insert(inner_data_type, inner_count);
diff_defined_types
.insert(inner_data_type.to_string(), *inner_count);
}
}
}
@ -2474,7 +2488,7 @@ impl<'a> CodeGenerator<'a> {
trace_stack,
);
let recursive = *defined_data_types.get(&data_type_name).unwrap() >= 1;
let recursive = *diff_defined_types.get(&data_type_name).unwrap() > 0;
data_type_stack.define_func(
&data_type_name,
@ -2489,13 +2503,20 @@ impl<'a> CodeGenerator<'a> {
data_type_name.clone(),
CodeGenFunction::Function(
data_type_stack.complete(),
defined_data_types
.keys()
.cloned()
.filter(|x| x != &data_type_name)
diff_defined_types
.into_iter()
.filter(|(dt, counter)| dt != &data_type_name && *counter > 0)
.map(|(x, _)| x)
.collect_vec(),
),
);
} else if defined_data_types.get(&data_type_name).is_some() && function.is_none() {
let Some(counter) = defined_data_types.get_mut(&data_type_name)
else {
unreachable!();
};
*counter += 1;
}
func_stack.var(
@ -2830,14 +2851,18 @@ impl<'a> CodeGenerator<'a> {
air: vec![],
};
func_stack.define_func(
function_access_key.function_name.clone(),
function_access_key.module_name.clone(),
function_access_key.variant_name.clone(),
func_comp.args.clone(),
func_comp.recursive,
recursion_stack,
);
if func_comp.is_code_gen_func {
func_stack = recursion_stack
} else {
func_stack.define_func(
function_access_key.function_name.clone(),
function_access_key.module_name.clone(),
function_access_key.variant_name.clone(),
func_comp.args.clone(),
func_comp.recursive,
recursion_stack,
);
}
full_func_ir.extend(func_stack.complete());
@ -2892,12 +2917,7 @@ impl<'a> CodeGenerator<'a> {
variant_name: variant_name.clone(),
};
if recursion_func_map.contains_key(&FunctionAccessKey {
module_name: module.clone(),
function_name: func_name.clone(),
variant_name: variant_name.clone(),
}) && func == &ir_function_key
{
if recursion_func_map.contains_key(&ir_function_key) && func == &ir_function_key {
skip = true;
} else if func == &ir_function_key {
recursion_func_map_to_add.insert(ir_function_key, ());
@ -3157,12 +3177,14 @@ impl<'a> CodeGenerator<'a> {
}
};
let function_key = FunctionAccessKey {
module_name: "".to_string(),
function_name: name.to_string(),
variant_name: "".to_string(),
};
func_components.insert(
FunctionAccessKey {
module_name: "".to_string(),
function_name: name.to_string(),
variant_name: "".to_string(),
},
function_key.clone(),
FuncComponents {
ir: func_ir,
dependencies: dependencies
@ -3174,11 +3196,13 @@ impl<'a> CodeGenerator<'a> {
})
.collect_vec(),
recursive: false,
args: vec![],
args: vec!["__one".to_string()],
defined_by_zero_arg: in_zero_arg_func,
is_code_gen_func: true,
},
);
to_be_defined_map.insert(function_key, scope.clone());
} else {
unreachable!("We found a function with no definitions");
}
@ -4886,33 +4910,22 @@ impl<'a> CodeGenerator<'a> {
}
arg_stack.push(term);
}
Air::Validator { params, .. } => {
// Wrap the validator body if ifThenElse term unit error
let mut term = arg_stack.pop().unwrap();
term = term.final_wrapper();
term = self.wrap_validator_args(term, &params, true);
arg_stack.push(term);
}
Air::NoOp { .. } => {}
}
}
pub fn wrap_validator_args(
&mut self,
term: Term<Name>,
validator_stack: &mut AirStack,
arguments: &[TypedArg],
has_context: bool,
) -> Term<Name> {
let mut term = term;
) {
let mut arg_stack = validator_stack.empty_with_scope();
for (index, arg) in arguments.iter().enumerate().rev() {
if !(has_context && index == arguments.len() - 1)
&& arg.arg_name.get_variable_name().unwrap_or("_") != "_"
{
let mut air_stack = AirStack::new(self.id_gen.clone());
let mut param_stack = air_stack.empty_with_scope();
let arg_name = arg.arg_name.get_variable_name().unwrap_or("_").to_string();
if !(has_context && index == arguments.len() - 1) && &arg_name != "_" {
let mut param_stack = validator_stack.empty_with_scope();
let mut value_stack = validator_stack.empty_with_scope();
param_stack.local_var(data(), arg.arg_name.get_variable_name().unwrap_or("_"));
@ -4923,9 +4936,9 @@ impl<'a> CodeGenerator<'a> {
self.assignment(
&Pattern::Var {
location: Span::empty(),
name: arg.arg_name.get_variable_name().unwrap_or("_").to_string(),
name: arg_name.to_string(),
},
&mut air_stack,
&mut value_stack,
param_stack,
&actual_type,
AssignmentProperties {
@ -4933,20 +4946,19 @@ impl<'a> CodeGenerator<'a> {
kind: AssignmentKind::Expect,
},
);
air_stack.local_var(
actual_type,
arg.arg_name.get_variable_name().unwrap_or("_").to_string(),
);
value_stack.local_var(actual_type, &arg_name);
let mut air_vec = air_stack.complete();
term = term
.lambda(arg.arg_name.get_variable_name().unwrap_or("_"))
.apply(self.uplc_code_gen(&mut air_vec));
arg_stack.let_assignment(arg_name, value_stack);
}
term = term.lambda(arg.arg_name.get_variable_name().unwrap_or("_"))
}
term
validator_stack.anonymous_function(
arguments
.iter()
.map(|arg| arg.arg_name.get_variable_name().unwrap_or("_").to_string())
.collect_vec(),
arg_stack,
)
}
}

View File

@ -3,7 +3,7 @@ use std::sync::Arc;
use uplc::builtins::DefaultFunction;
use crate::{
ast::{Arg, BinOp, UnOp},
ast::{BinOp, UnOp},
tipo::{Type, ValueConstructor},
};
@ -213,9 +213,8 @@ pub enum Air {
scope: Scope,
tipo: Arc<Type>,
},
Validator {
NoOp {
scope: Scope,
params: Vec<Arg<Arc<Type>>>,
},
FieldsEmpty {
scope: Scope,
@ -264,7 +263,7 @@ impl Air {
| Air::TupleIndex { scope, .. }
| Air::ErrorTerm { scope, .. }
| Air::Trace { scope, .. }
| Air::Validator { scope, .. } => scope.clone(),
| Air::NoOp { scope, .. } => scope.clone(),
}
}
pub fn scope_mut(&mut self) -> &mut Scope {
@ -308,7 +307,7 @@ impl Air {
| Air::TupleIndex { scope, .. }
| Air::ErrorTerm { scope, .. }
| Air::Trace { scope, .. }
| Air::Validator { scope, .. } => scope,
| Air::NoOp { scope, .. } => scope,
}
}
pub fn tipo(&self) -> Option<Arc<Type>> {
@ -399,7 +398,7 @@ impl Air {
| Air::Finally { .. }
| Air::FieldsExpose { .. }
| Air::FieldsEmpty { .. }
| Air::Validator { .. } => None,
| Air::NoOp { .. } => None,
Air::UnOp { op, .. } => match op {
UnOp::Not => Some(
Type::App {

View File

@ -1492,14 +1492,18 @@ pub fn handle_func_dependencies(
air: recursion_ir,
};
temp_stack.define_func(
dependency.function_name.clone(),
dependency.module_name.clone(),
dependency.variant_name.clone(),
depend_comp.args.clone(),
depend_comp.recursive,
recursion_stack,
);
if depend_comp.is_code_gen_func {
temp_stack = recursion_stack;
} else {
temp_stack.define_func(
dependency.function_name.clone(),
dependency.module_name.clone(),
dependency.variant_name.clone(),
depend_comp.args.clone(),
depend_comp.recursive,
recursion_stack,
);
}
let mut temp_ir = temp_stack.complete();
@ -1550,7 +1554,7 @@ pub fn handle_recursion_ir(
tipo,
}
}
_ => unreachable!(),
_ => unreachable!("Will support not using call right away later."),
}
}
}

View File

@ -5,7 +5,7 @@ use indexmap::IndexSet;
use uplc::{builder::EXPECT_ON_LIST, builtins::DefaultFunction};
use crate::{
ast::{Arg, Span},
ast::Span,
builtins::{data, list, void},
tipo::{Type, ValueConstructor, ValueConstructorVariant},
IdGenerator,
@ -701,12 +701,11 @@ impl AirStack {
self.merge_child(body_stack);
}
pub fn validator(&mut self, params: Vec<Arg<Arc<Type>>>) {
pub fn noop(&mut self) {
self.new_scope();
self.air.push(Air::Validator {
self.air.push(Air::NoOp {
scope: self.scope.clone(),
params,
});
}

View File

@ -56,6 +56,15 @@ impl Program<Name> {
term,
}
}
pub fn force_delay_reduce(self) -> Program<Name> {
let mut term = self.term.clone();
force_delay_reduce(&mut term);
Program {
version: self.version,
term,
}
}
}
fn builtin_force_reduce(term: &mut Term<Name>, builtin_map: &mut IndexMap<u8, ()>) {
@ -112,6 +121,36 @@ fn builtin_force_reduce(term: &mut Term<Name>, builtin_map: &mut IndexMap<u8, ()
}
}
fn force_delay_reduce(term: &mut Term<Name>) {
match term {
Term::Force(f) => {
let f = Rc::make_mut(f);
if let Term::Delay(body) = f {
*term = body.as_ref().clone();
} else {
force_delay_reduce(f);
}
}
Term::Delay(d) => {
let d = Rc::make_mut(d);
force_delay_reduce(d);
}
Term::Lambda { body, .. } => {
let body = Rc::make_mut(body);
force_delay_reduce(body);
}
Term::Apply { function, argument } => {
let func = Rc::make_mut(function);
force_delay_reduce(func);
let arg = Rc::make_mut(argument);
force_delay_reduce(arg);
}
_ => {}
}
}
fn inline_basic_reduce(term: &mut Term<Name>) {
match term {
Term::Delay(d) => {
@ -222,7 +261,7 @@ fn lambda_reduce(term: &mut Term<Name>) {
body,
} = func
{
if let replace_term @ (Term::Var(_) | Term::Constant(_)) = arg{
if let replace_term @ (Term::Var(_) | Term::Constant(_)) = arg {
let body = Rc::make_mut(body);
*term = substitute_term(body, parameter_name.clone(), replace_term);
}