fix list clauses with guards and add more tests

This commit is contained in:
microproofs 2023-08-16 16:31:44 -04:00 committed by Kasey
parent f4d0f231d7
commit 2456801b17
5 changed files with 463 additions and 22 deletions

View File

@ -24,9 +24,9 @@ use crate::{
expr::TypedExpr,
gen_uplc::builder::{
convert_opaque_type, erase_opaque_type_operations, find_and_replace_generics,
get_arg_type_name, get_generic_id_and_type, get_variant_name, monomorphize,
pattern_has_conditions, wrap_as_multi_validator, wrap_validator_condition, CodeGenFunction,
SpecificClause,
find_list_clause_or_default_first, get_arg_type_name, get_generic_id_and_type,
get_variant_name, monomorphize, pattern_has_conditions, wrap_as_multi_validator,
wrap_validator_condition, CodeGenFunction, SpecificClause,
},
tipo::{
ModuleValueConstructor, PatternConstructor, Type, TypeInfo, ValueConstructor,
@ -1703,8 +1703,12 @@ impl<'a> CodeGenerator<'a> {
);
};
assert!(!elements.is_empty() || tail.is_none());
let elements_len = elements.len() + usize::from(tail.is_none()) - 1;
let current_checked_index = *checked_index;
let tail_name = defined_tails
.last()
.get(elements_len)
.cloned()
.unwrap_or(props.original_subject_name.clone());
@ -1712,8 +1716,8 @@ impl<'a> CodeGenerator<'a> {
if rest_clauses.is_empty() {
None
} else {
let next_clause = &rest_clauses[0];
let mut next_clause_pattern = &rest_clauses[0].pattern;
let next_clause = find_list_clause_or_default_first(rest_clauses);
let mut next_clause_pattern = &next_clause.pattern;
if let Pattern::Assign { pattern, .. } = next_clause_pattern {
next_clause_pattern = pattern;
@ -1729,6 +1733,7 @@ impl<'a> CodeGenerator<'a> {
if (*defined_tails_index as usize) < next_elements_len {
*defined_tails_index += 1;
let current_defined_tail = defined_tails.last().unwrap().clone();
defined_tails.push(format!(
"tail_index_{}_span_{}_{}",
@ -1737,11 +1742,14 @@ impl<'a> CodeGenerator<'a> {
next_clause.pattern.location().end
));
Some(format!(
"tail_index_{}_span_{}_{}",
*defined_tails_index,
next_clause.pattern.location().start,
next_clause.pattern.location().end
Some((
current_defined_tail,
format!(
"tail_index_{}_span_{}_{}",
*defined_tails_index,
next_clause.pattern.location().start,
next_clause.pattern.location().end
),
))
} else {
None
@ -1754,9 +1762,6 @@ impl<'a> CodeGenerator<'a> {
is_wild_card_elems_clause =
is_wild_card_elems_clause && !pattern_has_conditions(element);
}
assert!(!elements.is_empty() || tail.is_none());
let elements_len = elements.len() + usize::from(tail.is_none()) - 1;
let current_checked_index = *checked_index;
if *checked_index < elements_len.try_into().unwrap()
&& is_wild_card_elems_clause
@ -1784,7 +1789,9 @@ impl<'a> CodeGenerator<'a> {
let complex_clause = props.complex_clause;
if current_checked_index < elements_len.try_into().unwrap() {
if current_checked_index < elements_len.try_into().unwrap()
|| next_tail_name.is_some()
{
AirTree::list_clause(
tail_name,
subject_tipo.clone(),
@ -3996,9 +4003,9 @@ impl<'a> CodeGenerator<'a> {
let body = arg_stack.pop().unwrap();
let mut term = arg_stack.pop().unwrap();
let arg = if let Some(next_tail_name) = next_tail_name {
let arg = if let Some((current_tail, next_tail_name)) = next_tail_name {
term.lambda(next_tail_name)
.apply(Term::tail_list().apply(Term::var(tail_name.clone())))
.apply(Term::tail_list().apply(Term::var(current_tail.clone())))
} else {
term
};

View File

@ -97,7 +97,7 @@ pub enum Air {
ListClause {
subject_tipo: Arc<Type>,
tail_name: String,
next_tail_name: Option<String>,
next_tail_name: Option<(String, String)>,
complex_clause: bool,
},
WrapClause,

View File

@ -810,7 +810,18 @@ pub fn rearrange_list_clauses(clauses: Vec<TypedClause>) -> Vec<TypedClause> {
.into_iter()
.enumerate()
.sorted_by(|(index1, clause1), (index2, clause2)| {
let clause1_len = match &clause1.pattern {
let mut clause_pattern1 = &clause1.pattern;
let mut clause_pattern2 = &clause2.pattern;
if let Pattern::Assign { pattern, .. } = clause_pattern1 {
clause_pattern1 = pattern;
}
if let Pattern::Assign { pattern, .. } = clause_pattern2 {
clause_pattern2 = pattern;
}
let clause1_len = match clause_pattern1 {
Pattern::List { elements, tail, .. } => {
Some(elements.len() + usize::from(tail.is_some() && clause1.guard.is_none()))
}
@ -818,7 +829,7 @@ pub fn rearrange_list_clauses(clauses: Vec<TypedClause>) -> Vec<TypedClause> {
_ => None,
};
let clause2_len = match &clause2.pattern {
let clause2_len = match clause_pattern2 {
Pattern::List { elements, tail, .. } => {
Some(elements.len() + usize::from(tail.is_some() && clause2.guard.is_none()))
}
@ -1013,6 +1024,19 @@ pub fn rearrange_list_clauses(clauses: Vec<TypedClause>) -> Vec<TypedClause> {
final_clauses
}
pub fn find_list_clause_or_default_first(clauses: &[TypedClause]) -> &TypedClause {
assert!(!clauses.is_empty());
clauses
.iter()
.find(|clause| match &clause.pattern {
Pattern::List { .. } => true,
Pattern::Assign { pattern, .. } if matches!(&**pattern, Pattern::List { .. }) => true,
_ => false,
})
.unwrap_or(&clauses[0])
}
pub fn convert_data_to_type(term: Term<Name>, field_type: &Arc<Type>) -> Term<Name> {
if field_type.is_int() {
Term::un_i_data().apply(term)

View File

@ -282,7 +282,7 @@ pub enum AirExpression {
ListClause {
subject_tipo: Arc<Type>,
tail_name: String,
next_tail_name: Option<String>,
next_tail_name: Option<(String, String)>,
complex_clause: bool,
then: Box<AirTree>,
otherwise: Box<AirTree>,
@ -546,7 +546,7 @@ impl AirTree {
subject_tipo: Arc<Type>,
then: AirTree,
otherwise: AirTree,
next_tail_name: Option<String>,
next_tail_name: Option<(String, String)>,
complex_clause: bool,
) -> AirTree {
AirTree::Expression(AirExpression::ListClause {

View File

@ -4523,3 +4523,413 @@ fn list_clause_with_guard() {
false,
);
}
#[test]
fn list_clause_with_guard2() {
let src = r#"
fn do_init(self: List<Int>) -> List<Int> {
when self is {
[] -> fail @"unreachable"
[_] ->
[]
[a, x] -> {
[a]
}
[a] if a > 10 -> []
[a, b, ..c] -> {
c
}
}
}
test init_3() {
do_init([1, 3]) == [1]
}
"#;
assert_uplc(
src,
Term::equals_data()
.apply(
Term::list_data().apply(
Term::var("do_init")
.lambda("do_init")
.apply(
Term::var("self")
.delayed_choose_list(
Term::Error.trace(Term::string("unreachable")),
Term::var("tail_1")
.delayed_choose_list(
Term::empty_list(),
Term::var("tail_1")
.choose_list(
Term::var("clause_guard")
.if_else(
Term::empty_list().delay(),
Term::var("clauses_delayed"),
)
.force()
.lambda("clause_guard")
.apply(
Term::less_than_integer()
.apply(Term::integer(10.into()))
.apply(Term::var("a")),
)
.lambda("a")
.apply(
Term::un_i_data().apply(
Term::head_list()
.apply(Term::var("self")),
),
)
.delay(),
Term::var("clauses_delayed"),
)
.force()
.lambda("clauses_delayed")
.apply(
Term::var("tail_2")
.delayed_choose_list(
Term::mk_cons()
.apply(
Term::i_data()
.apply(Term::var("a")),
)
.apply(Term::empty_list())
.lambda("x")
.apply(
Term::un_i_data().apply(
Term::head_list().apply(
Term::var("tail_1"),
),
),
)
.lambda("a")
.apply(
Term::un_i_data().apply(
Term::head_list().apply(
Term::var("self"),
),
),
),
Term::var("c").lambda("c").apply(
Term::tail_list()
.apply(Term::var("tail_1"))
.lambda("b")
.apply(Term::un_i_data().apply(
Term::head_list().apply(
Term::var("tail_1"),
),
))
.lambda("a")
.apply(
Term::un_i_data().apply(
Term::head_list()
.apply(Term::var(
"self",
)),
),
),
),
)
.lambda("tail_2")
.apply(
Term::tail_list()
.apply(Term::var("tail_1")),
)
.delay(),
),
)
.lambda("tail_1")
.apply(Term::tail_list().apply(Term::var("self"))),
)
.lambda("self"),
)
.apply(Term::list_values(vec![
Constant::Data(Data::integer(1.into())),
Constant::Data(Data::integer(3.into())),
])),
),
)
.apply(Term::data(Data::list(vec![Data::integer(1.into())]))),
false,
);
}
#[test]
fn list_clause_with_guard3() {
let src = r#"
fn do_init(self: List<Int>) -> List<Int> {
when self is {
[] -> fail @"unreachable"
[_] ->
[]
[a, x] -> {
[a]
}
[a, ..g] if a > 10 -> g
[a, b, ..c] -> {
c
}
}
}
test init_3() {
do_init([1, 3]) == [1]
}
"#;
assert_uplc(
src,
Term::equals_data()
.apply(
Term::list_data().apply(
Term::var("do_init")
.lambda("do_init")
.apply(
Term::var("self")
.delayed_choose_list(
Term::Error.trace(Term::string("unreachable")),
Term::var("tail_1")
.delayed_choose_list(
Term::empty_list(),
Term::var("self")
.choose_list(
Term::var("clause_guard")
.if_else(
Term::var("g").delay(),
Term::var("clauses_delayed"),
)
.force()
.lambda("clause_guard")
.apply(
Term::less_than_integer()
.apply(Term::integer(10.into()))
.apply(Term::var("a")),
)
.lambda("g")
.apply(
Term::tail_list()
.apply(Term::var("self")),
)
.lambda("a")
.apply(
Term::un_i_data().apply(
Term::head_list()
.apply(Term::var("self")),
),
)
.delay(),
Term::var("clauses_delayed"),
)
.force()
.lambda("clauses_delayed")
.apply(
Term::var("tail_2")
.delayed_choose_list(
Term::mk_cons()
.apply(
Term::i_data()
.apply(Term::var("a")),
)
.apply(Term::empty_list())
.lambda("x")
.apply(
Term::un_i_data().apply(
Term::head_list().apply(
Term::var("tail_1"),
),
),
)
.lambda("a")
.apply(
Term::un_i_data().apply(
Term::head_list().apply(
Term::var("self"),
),
),
),
Term::var("c").lambda("c").apply(
Term::tail_list()
.apply(Term::var("tail_1"))
.lambda("b")
.apply(Term::un_i_data().apply(
Term::head_list().apply(
Term::var("tail_1"),
),
))
.lambda("a")
.apply(
Term::un_i_data().apply(
Term::head_list()
.apply(Term::var(
"self",
)),
),
),
),
)
.lambda("tail_2")
.apply(
Term::tail_list()
.apply(Term::var("tail_1")),
)
.delay(),
),
)
.lambda("tail_1")
.apply(Term::tail_list().apply(Term::var("self"))),
)
.lambda("self"),
)
.apply(Term::list_values(vec![
Constant::Data(Data::integer(1.into())),
Constant::Data(Data::integer(3.into())),
])),
),
)
.apply(Term::data(Data::list(vec![Data::integer(1.into())]))),
false,
);
}
#[test]
fn list_clause_with_assign() {
let src = r#"
fn do_init(self: List<Int>) -> List<Int> {
when self is {
[] -> fail @"unreachable"
[_] as a ->
a
[a, x] if x > 2 -> {
[a]
}
[a, x] -> {
[a]
}
[a, b, ..c] -> {
c
}
}
}
test init_3() {
do_init([1, 3]) == [1]
}
"#;
assert_uplc(
src,
Term::equals_data()
.apply(
Term::list_data().apply(
Term::var("do_init")
.lambda("do_init")
.apply(
Term::var("self")
.delayed_choose_list(
Term::Error.trace(Term::string("unreachable")),
Term::var("tail_1")
.delayed_choose_list(
Term::empty_list(),
Term::var("tail_2")
.choose_list(
Term::var("clause_guard")
.if_else(
Term::mk_cons()
.apply(
Term::i_data()
.apply(Term::var("a")),
)
.apply(Term::empty_list())
.delay(),
Term::var("clauses_delayed"),
)
.force()
.lambda("clause_guard")
.apply(
Term::less_than_integer()
.apply(Term::integer(2.into()))
.apply(Term::var("x")),
)
.lambda("x")
.apply(
Term::un_i_data().apply(
Term::head_list()
.apply(Term::var("tail_1")),
),
)
.lambda("a")
.apply(
Term::un_i_data().apply(
Term::head_list()
.apply(Term::var("self")),
),
)
.delay(),
Term::var("clauses_delayed"),
)
.force()
.lambda("clauses_delayed")
.apply(
Term::var("tail_2")
.delayed_choose_list(
Term::empty_list()
.lambda("b")
.apply(
Term::un_i_data().apply(
Term::head_list().apply(
Term::var("tail_1"),
),
),
)
.lambda("a")
.apply(
Term::un_i_data().apply(
Term::head_list().apply(
Term::var("self"),
),
),
),
Term::var("c").lambda("c").apply(
Term::tail_list()
.apply(Term::var("tail_1"))
.lambda("b")
.apply(Term::un_i_data().apply(
Term::head_list().apply(
Term::var("tail_1"),
),
))
.lambda("a")
.apply(
Term::un_i_data().apply(
Term::head_list()
.apply(Term::var(
"self",
)),
),
),
),
)
.delay(),
)
.lambda("tail_2")
.apply(
Term::tail_list().apply(Term::var("tail_1")),
),
)
.lambda("tail_1")
.apply(Term::tail_list().apply(Term::var("self"))),
)
.lambda("self"),
)
.apply(Term::list_values(vec![
Constant::Data(Data::integer(1.into())),
Constant::Data(Data::integer(3.into())),
])),
),
)
.apply(Term::data(Data::list(vec![Data::integer(1.into())]))),
false,
);
}