Fix delay of arguments to be exactly the same as codegen tests

This commit is contained in:
microproofs 2024-06-25 09:21:44 -04:00 committed by Lucas
parent f695276bf7
commit 4bd9125b86
5 changed files with 173 additions and 137 deletions

View File

@ -260,38 +260,31 @@ impl<'a> CodeGenerator<'a> {
let air_value = self.build(value, module_build_name, &[]);
let otherwise_delayed = match (self.tracing, kind) {
(
TraceLevel::Silent,
AssignmentKind::Let { .. } | AssignmentKind::Expect { .. },
) => AirTree::error(void(), false),
let otherwise_delayed = {
let msg = match (self.tracing, kind) {
(TraceLevel::Silent, _) | (_, AssignmentKind::Let { .. }) => "".to_string(),
(TraceLevel::Compact, _) => {
get_line_columns_by_span(module_build_name, location, &self.module_src)
.to_string()
}
(TraceLevel::Verbose, _) => {
get_src_code_by_span(module_build_name, location, &self.module_src)
}
};
(TraceLevel::Compact | TraceLevel::Verbose, AssignmentKind::Let { .. }) => {
AirTree::error(void(), false)
}
let msg_func_name = msg.split_whitespace().join("");
(TraceLevel::Verbose | TraceLevel::Compact, AssignmentKind::Expect { .. }) => {
let msg = match self.tracing {
TraceLevel::Silent => unreachable!("excluded from pattern guards"),
TraceLevel::Compact => {
get_line_columns_by_span(module_build_name, location, &self.module_src)
.to_string()
}
TraceLevel::Verbose => {
get_src_code_by_span(module_build_name, location, &self.module_src)
}
};
self.special_functions.insert_new_function(
msg_func_name.clone(),
if msg.is_empty() {
Term::Error.delay()
} else {
Term::Error.delayed_trace(Term::string(msg)).delay()
},
void(),
);
let msg_func_name = msg.split_whitespace().join("");
self.special_functions.insert_new_function(
msg_func_name.clone(),
Term::Error.delayed_trace(Term::string(msg)).delay(),
void(),
);
self.special_functions.use_function_tree(msg_func_name)
}
self.special_functions.use_function_tree(msg_func_name)
};
let (then, context) = context.split_first().unwrap();
@ -341,6 +334,7 @@ impl<'a> CodeGenerator<'a> {
.map(|arg| arg.arg_name.get_variable_name().unwrap_or("_").to_string())
.collect_vec(),
self.build(body, module_build_name, &[]),
false,
),
TypedExpr::List {
@ -649,7 +643,7 @@ impl<'a> CodeGenerator<'a> {
Some(pattern) => AirTree::let_assignment(
"acc_var",
// use anon function as a delay to avoid evaluating the acc
AirTree::anon_func(vec![], acc),
AirTree::anon_func(vec![], acc, true),
self.assignment(
pattern,
condition,
@ -1350,6 +1344,13 @@ impl<'a> CodeGenerator<'a> {
pattern.location().end
);
let subject_name = format!(
"__subject_{}_span_{}_{}",
name,
pattern.location().start,
pattern.location().end
);
let local_value = AirTree::local_var(&constructor_name, tipo.clone());
let then = if check_replaceable_opaque_type(tipo, &self.data_types) {
@ -1379,11 +1380,17 @@ impl<'a> CodeGenerator<'a> {
panic!("Found constructor type {} with 0 constructors", name)
});
AirTree::assert_constr_index(
index,
AirTree::when(
&subject_name,
void(),
tipo.clone(),
AirTree::local_var(&constructor_name, tipo.clone()),
then,
props.otherwise.clone(),
AirTree::assert_constr_index(
index,
AirTree::local_var(&subject_name, tipo.clone()),
then,
props.otherwise.clone(),
),
)
} else {
assert!(data_type.constructors.len() == 1);
@ -1543,7 +1550,7 @@ impl<'a> CodeGenerator<'a> {
otherwise.clone(),
);
let unwrap_function = AirTree::anon_func(vec![pair_name], anon_func_body);
let unwrap_function = AirTree::anon_func(vec![pair_name], anon_func_body, false);
let function = self.code_gen_functions.get(EXPECT_ON_LIST);
@ -1658,7 +1665,8 @@ impl<'a> CodeGenerator<'a> {
let anon_func_body = expect_item;
let unwrap_function = AirTree::anon_func(vec![item_name], anon_func_body);
let unwrap_function =
AirTree::anon_func(vec![item_name], anon_func_body, false);
let function = self.code_gen_functions.get(EXPECT_ON_LIST);
@ -1844,6 +1852,7 @@ impl<'a> CodeGenerator<'a> {
)
};
// Special case here for future refactoring
AirTree::anon_func(
vec![],
AirTree::assert_constr_index(
@ -1858,6 +1867,7 @@ impl<'a> CodeGenerator<'a> {
then,
acc,
),
true,
)
},
);
@ -1870,7 +1880,7 @@ impl<'a> CodeGenerator<'a> {
format!("__constr_var_span_{}_{}", location.start, location.end),
tipo.clone(),
),
constr_clauses,
AirTree::call(constr_clauses, void(), vec![]),
);
let func_body = AirTree::let_assignment(
@ -3030,7 +3040,7 @@ impl<'a> CodeGenerator<'a> {
itertools::Position::First(arg) if has_context => {
let arg_name = arg.arg_name.get_variable_name().unwrap_or("_").to_string();
AirTree::anon_func(vec![arg_name], inner_then)
AirTree::anon_func(vec![arg_name], inner_then, true)
}
itertools::Position::First(arg)
| itertools::Position::Middle(arg)
@ -3044,33 +3054,32 @@ impl<'a> CodeGenerator<'a> {
let actual_type = convert_opaque_type(&arg.tipo, &self.data_types, true);
let otherwise_delayed = match self.tracing {
TraceLevel::Silent => AirTree::error(void(), false),
TraceLevel::Compact | TraceLevel::Verbose => {
let msg = match self.tracing {
TraceLevel::Silent => {
unreachable!("excluded from pattern guards")
}
TraceLevel::Compact => lines
.line_and_column_number(arg_span.start)
.expect("Out of bounds span")
.to_string(),
TraceLevel::Verbose => src_code
.get(arg_span.start..arg_span.end)
.expect("Out of bounds span")
.to_string(),
};
let otherwise_delayed = {
let msg = match self.tracing {
TraceLevel::Silent => "".to_string(),
TraceLevel::Compact => lines
.line_and_column_number(arg_span.start)
.expect("Out of bounds span")
.to_string(),
TraceLevel::Verbose => src_code
.get(arg_span.start..arg_span.end)
.expect("Out of bounds span")
.to_string(),
};
let msg_func_name = msg.split_whitespace().join("");
let msg_func_name = msg.split_whitespace().join("");
self.special_functions.insert_new_function(
msg_func_name.to_string(),
Term::Error.delayed_trace(Term::string(msg)).delay(),
void(),
);
self.special_functions.insert_new_function(
msg_func_name.clone(),
if msg.is_empty() {
Term::Error.delay()
} else {
Term::Error.delayed_trace(Term::string(msg)).delay()
},
void(),
);
self.special_functions.use_function_tree(msg_func_name)
}
self.special_functions.use_function_tree(msg_func_name)
};
let inner_then = self.assignment(
@ -3090,7 +3099,7 @@ impl<'a> CodeGenerator<'a> {
},
);
AirTree::anon_func(vec![arg_name], inner_then)
AirTree::anon_func(vec![arg_name], inner_then, true)
}
itertools::Position::Only(_) => unreachable!(),
})
@ -3860,7 +3869,7 @@ impl<'a> CodeGenerator<'a> {
let mut function_variant_path = IndexMap::new();
let mut body = body.clone();
let mut body = AirTree::no_op(body.clone());
body.traverse_tree_with(
&mut |air_tree, _| {
@ -4031,7 +4040,7 @@ impl<'a> CodeGenerator<'a> {
fn gen_uplc(&mut self, ir: Air, arg_stack: &mut Vec<Term<Name>>) -> Option<Term<Name>> {
let convert_data_to_type = |term, tipo, otherwise| {
if otherwise == Term::Error {
if otherwise == Term::Error.delay() {
builder::unknown_data_to_type(term, tipo)
} else {
builder::unknown_data_to_type_otherwise(term, tipo, otherwise)
@ -4340,7 +4349,7 @@ impl<'a> CodeGenerator<'a> {
let otherwise = if matches!(expect_level, ExpectLevel::Full | ExpectLevel::Items) {
arg_stack.pop().unwrap()
} else {
Term::Error
Term::Error.delay()
};
let list_id = self.id_gen.next();
@ -4403,17 +4412,25 @@ impl<'a> CodeGenerator<'a> {
Some(term)
}
Air::Fn { params } => {
Air::Fn {
params,
allow_inline,
} => {
let mut term = arg_stack.pop().unwrap();
for param in params.iter().rev() {
term = term.lambda(param);
}
term = if allow_inline {
term
} else {
term.lambda(NO_INLINE)
};
if params.is_empty() {
Some(term.lambda(NO_INLINE).delay())
Some(term.delay())
} else {
Some(term.lambda(NO_INLINE))
Some(term)
}
}
Air::Call { count, .. } => {
@ -4431,61 +4448,62 @@ impl<'a> CodeGenerator<'a> {
// How we handle zero arg anon functions has changed
// We now delay zero arg anon functions and force them on a call operation
if let Term::Var(name) = &term {
let zero_arg_functions = self.zero_arg_functions.clone();
let text = &name.text;
match &term {
Term::Var(name) => {
let zero_arg_functions = self.zero_arg_functions.clone();
let text = &name.text;
if let Some((_, air_vec)) = zero_arg_functions.iter().find(
|(
(
FunctionAccessKey {
module_name,
function_name,
},
variant,
),
_,
)| {
let name_module = format!("{module_name}_{function_name}{variant}");
let name = format!("{function_name}{variant}");
if let Some((_, air_vec)) = zero_arg_functions.iter().find(
|(
(
FunctionAccessKey {
module_name,
function_name,
},
variant,
),
_,
)| {
let name_module =
format!("{module_name}_{function_name}{variant}");
let name = format!("{function_name}{variant}");
text == &name || text == &name_module
},
) {
let mut term = self.uplc_code_gen(air_vec.clone());
text == &name || text == &name_module
},
) {
let mut term = self.uplc_code_gen(air_vec.clone());
term = term.constr_fields_exposer().constr_index_exposer();
term = term.constr_fields_exposer().constr_index_exposer();
let mut program: Program<Name> = Program {
version: (1, 0, 0),
term: self.special_functions.apply_used_functions(term),
};
let mut program: Program<Name> = Program {
version: (1, 0, 0),
term: self.special_functions.apply_used_functions(term),
};
let mut interner = CodeGenInterner::new();
let mut interner = CodeGenInterner::new();
interner.program(&mut program);
interner.program(&mut program);
let eval_program: Program<NamedDeBruijn> =
program.remove_no_inlines().try_into().unwrap();
let eval_program: Program<NamedDeBruijn> =
program.remove_no_inlines().try_into().unwrap();
let result = eval_program.eval(ExBudget::max()).result();
let result = eval_program.eval(ExBudget::max()).result();
let evaluated_term: Term<NamedDeBruijn> = result.unwrap_or_else(|e| {
panic!("Evaluated a zero argument function and received this error: {e:#?}")
});
let evaluated_term: Term<NamedDeBruijn> = result.unwrap_or_else(|e| {
panic!("Evaluated a zero argument function and received this error: {e:#?}")
});
Some(evaluated_term.try_into().unwrap())
} else {
Some(term.force())
Some(evaluated_term.try_into().unwrap())
} else {
Some(term.force())
}
}
} else if let Term::Apply { .. } = &term {
// Case for mutually recursive zero arg functions
Some(term.force())
} else {
unreachable!(
"Shouldn't call anything other than var or apply {:#?}",
Term::Delay(inner_term) => Some(inner_term.as_ref().clone()),
Term::Apply { .. } => Some(term.force()),
_ => unreachable!(
"Shouldn't call anything other than var or apply\n{:#?}",
term
)
),
}
}
}
@ -4784,7 +4802,7 @@ impl<'a> CodeGenerator<'a> {
let otherwise = if full_cast {
arg_stack.pop().unwrap()
} else {
Term::Error
Term::Error.delay()
};
term = if full_cast {
@ -4854,13 +4872,7 @@ impl<'a> CodeGenerator<'a> {
term = Term::equals_integer()
.apply(Term::integer(constr_index.into()))
.apply(
Term::var(
self.special_functions
.use_function_uplc(CONSTR_INDEX_EXPOSER.to_string()),
)
.apply(constr),
)
.apply(constr)
.if_then_else(term.delay(), otherwise)
.force();
@ -5309,7 +5321,7 @@ impl<'a> CodeGenerator<'a> {
let otherwise = if is_expect {
arg_stack.pop().unwrap()
} else {
Term::Error
Term::Error.delay()
};
let list_id = self.id_gen.next();
@ -5577,7 +5589,7 @@ impl<'a> CodeGenerator<'a> {
let otherwise = if is_expect {
arg_stack.pop().unwrap()
} else {
Term::Error
Term::Error.delay()
};
let list_id = self.id_gen.next();
@ -5620,7 +5632,7 @@ impl<'a> CodeGenerator<'a> {
let otherwise = if is_expect {
arg_stack.pop().unwrap()
} else {
Term::Error
Term::Error.delay()
};
let list_id = self.id_gen.next();

View File

@ -81,6 +81,7 @@ pub enum Air {
},
Fn {
params: Vec<String>,
allow_inline: bool,
},
Builtin {
count: usize,

View File

@ -659,6 +659,7 @@ pub fn modify_cyclic_calls(
AirTree::anon_func(
names.clone(),
AirTree::local_var(index_name, tipo),
false,
),
],
);
@ -1172,7 +1173,7 @@ pub fn unknown_data_to_type_otherwise(
.choose_data(
Term::snd_pair()
.apply(Term::var("__pair__"))
.delayed_choose_list(
.choose_list(
Term::equals_integer()
.apply(Term::integer(1.into()))
.apply(Term::fst_pair().apply(Term::var("__pair__")))
@ -1181,13 +1182,16 @@ pub fn unknown_data_to_type_otherwise(
Term::equals_integer()
.apply(Term::integer(0.into()))
.apply(Term::fst_pair().apply(Term::var("__pair__")))
.delayed_if_then_else(
Term::bool(false),
.if_then_else(
Term::bool(false).delay(),
otherwise_delayed.clone(),
),
),
)
.force(),
)
.delay(),
otherwise_delayed.clone(),
)
.force()
.lambda("__pair__")
.apply(Term::unconstr_data().apply(Term::var("__val")))
.delay(),
@ -1204,12 +1208,15 @@ pub fn unknown_data_to_type_otherwise(
Term::equals_integer()
.apply(Term::integer(0.into()))
.apply(Term::fst_pair().apply(Term::unconstr_data().apply(Term::var("__val"))))
.delayed_if_then_else(
.if_then_else(
Term::snd_pair()
.apply(Term::unconstr_data().apply(Term::var("__val")))
.delayed_choose_list(Term::unit(), otherwise_delayed.clone()),
.choose_list(Term::unit().delay(), otherwise_delayed.clone())
.force()
.delay(),
otherwise_delayed.clone(),
)
.force()
.delay(),
otherwise_delayed.clone(),
otherwise_delayed.clone(),
@ -1442,7 +1449,7 @@ pub fn list_access_to_uplc(
Term::head_list().apply(Term::var(tail_name.to_string()))
} else if matches!(expect_level, ExpectLevel::Full) {
// Expect level is full so we have an unknown piece of data to cast
if otherwise_delayed == Term::Error {
if otherwise_delayed == Term::Error.delay() {
unknown_data_to_type(
Term::head_list().apply(Term::var(tail_name.to_string())),
&tipo.to_owned(),
@ -1486,7 +1493,7 @@ pub fn list_access_to_uplc(
ExpectLevel::None => acc.lambda(name).apply(head_item).lambda(tail_name),
ExpectLevel::Full | ExpectLevel::Items => {
if otherwise_delayed == Term::Error && tail_present {
if otherwise_delayed == Term::Error.delay() && tail_present {
// No need to check last item if tail was present
acc.lambda(name).apply(head_item).lambda(tail_name)
} else if tail_present {
@ -1498,11 +1505,11 @@ pub fn list_access_to_uplc(
)
.force()
.lambda(tail_name)
} else if otherwise_delayed == Term::Error {
} else if otherwise_delayed == Term::Error.delay() {
// Check head is last item in this list
Term::tail_list()
.apply(Term::var(tail_name.to_string()))
.choose_list(acc.delay(), otherwise_delayed.clone())
.choose_list(acc.delay(), Term::Error.delay())
.force()
.lambda(name)
.apply(head_item)
@ -1533,7 +1540,8 @@ pub fn list_access_to_uplc(
let head_item = head_item(name, tipo, &tail_name);
if matches!(expect_level, ExpectLevel::None) || otherwise_delayed == Term::Error
if matches!(expect_level, ExpectLevel::None)
|| otherwise_delayed == Term::Error.delay()
{
acc.apply(Term::tail_list().apply(Term::var(tail_name.to_string())))
.lambda(name)

View File

@ -288,6 +288,7 @@ pub enum AirTree {
Fn {
params: Vec<String>,
func_body: Box<AirTree>,
allow_inline: bool,
},
Builtin {
func: DefaultFunction,
@ -538,10 +539,11 @@ impl AirTree {
}
}
pub fn anon_func(params: Vec<String>, func_body: AirTree) -> AirTree {
pub fn anon_func(params: Vec<String>, func_body: AirTree, allow_inline: bool) -> AirTree {
AirTree::Fn {
params,
func_body: func_body.into(),
allow_inline,
}
}
@ -1388,9 +1390,14 @@ impl AirTree {
arg.create_air_vec(air_vec);
}
}
AirTree::Fn { params, func_body } => {
AirTree::Fn {
params,
func_body,
allow_inline,
} => {
air_vec.push(Air::Fn {
params: params.clone(),
allow_inline: *allow_inline,
});
func_body.create_air_vec(air_vec);
}
@ -2184,6 +2191,7 @@ impl AirTree {
AirTree::Fn {
params: _,
func_body,
allow_inline: _,
} => {
func_body.do_traverse_tree_with(
tree_path,
@ -2920,6 +2928,7 @@ impl AirTree {
AirTree::Fn {
params: _,
func_body,
allow_inline: _,
} => match field {
Fields::SecondField => func_body.as_mut().do_find_air_tree_node(tree_path_iter),
_ => panic!("Tree Path index outside tree children nodes"),

View File

@ -1029,6 +1029,12 @@ impl Program<Name> {
if let Some((arg_id, arg_term)) = arg_stack.pop() {
match &arg_term {
Term::Constant(c) if matches!(c.as_ref(), Constant::String(_)) => {}
Term::Delay(e) if matches!(e.as_ref(), Term::Error) => {
let body = Rc::make_mut(body);
lambda_applied_ids.push(arg_id);
// creates new body that replaces all var occurrences with the arg
*term = substitute_var(body, parameter_name.clone(), &arg_term);
}
Term::Constant(_) | Term::Var(_) | Term::Builtin(_) => {
let body = Rc::make_mut(body);
lambda_applied_ids.push(arg_id);