From 2456801b170d82660dbde3f38f07ec987e03bc04 Mon Sep 17 00:00:00 2001 From: microproofs Date: Wed, 16 Aug 2023 16:31:44 -0400 Subject: [PATCH] fix list clauses with guards and add more tests --- crates/aiken-lang/src/gen_uplc.rs | 41 ++- crates/aiken-lang/src/gen_uplc/air.rs | 2 +- crates/aiken-lang/src/gen_uplc/builder.rs | 28 +- crates/aiken-lang/src/gen_uplc/tree.rs | 4 +- crates/aiken-project/src/tests/gen_uplc.rs | 410 +++++++++++++++++++++ 5 files changed, 463 insertions(+), 22 deletions(-) diff --git a/crates/aiken-lang/src/gen_uplc.rs b/crates/aiken-lang/src/gen_uplc.rs index 93e2fb96..4dda246f 100644 --- a/crates/aiken-lang/src/gen_uplc.rs +++ b/crates/aiken-lang/src/gen_uplc.rs @@ -24,9 +24,9 @@ use crate::{ expr::TypedExpr, gen_uplc::builder::{ convert_opaque_type, erase_opaque_type_operations, find_and_replace_generics, - get_arg_type_name, get_generic_id_and_type, get_variant_name, monomorphize, - pattern_has_conditions, wrap_as_multi_validator, wrap_validator_condition, CodeGenFunction, - SpecificClause, + find_list_clause_or_default_first, get_arg_type_name, get_generic_id_and_type, + get_variant_name, monomorphize, pattern_has_conditions, wrap_as_multi_validator, + wrap_validator_condition, CodeGenFunction, SpecificClause, }, tipo::{ ModuleValueConstructor, PatternConstructor, Type, TypeInfo, ValueConstructor, @@ -1703,8 +1703,12 @@ impl<'a> CodeGenerator<'a> { ); }; + assert!(!elements.is_empty() || tail.is_none()); + let elements_len = elements.len() + usize::from(tail.is_none()) - 1; + let current_checked_index = *checked_index; + let tail_name = defined_tails - .last() + .get(elements_len) .cloned() .unwrap_or(props.original_subject_name.clone()); @@ -1712,8 +1716,8 @@ impl<'a> CodeGenerator<'a> { if rest_clauses.is_empty() { None } else { - let next_clause = &rest_clauses[0]; - let mut next_clause_pattern = &rest_clauses[0].pattern; + let next_clause = find_list_clause_or_default_first(rest_clauses); + let mut next_clause_pattern = &next_clause.pattern; if let Pattern::Assign { pattern, .. } = next_clause_pattern { next_clause_pattern = pattern; @@ -1729,6 +1733,7 @@ impl<'a> CodeGenerator<'a> { if (*defined_tails_index as usize) < next_elements_len { *defined_tails_index += 1; + let current_defined_tail = defined_tails.last().unwrap().clone(); defined_tails.push(format!( "tail_index_{}_span_{}_{}", @@ -1737,11 +1742,14 @@ impl<'a> CodeGenerator<'a> { next_clause.pattern.location().end )); - Some(format!( - "tail_index_{}_span_{}_{}", - *defined_tails_index, - next_clause.pattern.location().start, - next_clause.pattern.location().end + Some(( + current_defined_tail, + format!( + "tail_index_{}_span_{}_{}", + *defined_tails_index, + next_clause.pattern.location().start, + next_clause.pattern.location().end + ), )) } else { None @@ -1754,9 +1762,6 @@ impl<'a> CodeGenerator<'a> { is_wild_card_elems_clause = is_wild_card_elems_clause && !pattern_has_conditions(element); } - assert!(!elements.is_empty() || tail.is_none()); - let elements_len = elements.len() + usize::from(tail.is_none()) - 1; - let current_checked_index = *checked_index; if *checked_index < elements_len.try_into().unwrap() && is_wild_card_elems_clause @@ -1784,7 +1789,9 @@ impl<'a> CodeGenerator<'a> { let complex_clause = props.complex_clause; - if current_checked_index < elements_len.try_into().unwrap() { + if current_checked_index < elements_len.try_into().unwrap() + || next_tail_name.is_some() + { AirTree::list_clause( tail_name, subject_tipo.clone(), @@ -3996,9 +4003,9 @@ impl<'a> CodeGenerator<'a> { let body = arg_stack.pop().unwrap(); let mut term = arg_stack.pop().unwrap(); - let arg = if let Some(next_tail_name) = next_tail_name { + let arg = if let Some((current_tail, next_tail_name)) = next_tail_name { term.lambda(next_tail_name) - .apply(Term::tail_list().apply(Term::var(tail_name.clone()))) + .apply(Term::tail_list().apply(Term::var(current_tail.clone()))) } else { term }; diff --git a/crates/aiken-lang/src/gen_uplc/air.rs b/crates/aiken-lang/src/gen_uplc/air.rs index 5853f170..62d1485a 100644 --- a/crates/aiken-lang/src/gen_uplc/air.rs +++ b/crates/aiken-lang/src/gen_uplc/air.rs @@ -97,7 +97,7 @@ pub enum Air { ListClause { subject_tipo: Arc, tail_name: String, - next_tail_name: Option, + next_tail_name: Option<(String, String)>, complex_clause: bool, }, WrapClause, diff --git a/crates/aiken-lang/src/gen_uplc/builder.rs b/crates/aiken-lang/src/gen_uplc/builder.rs index c30e9008..32c8ba5e 100644 --- a/crates/aiken-lang/src/gen_uplc/builder.rs +++ b/crates/aiken-lang/src/gen_uplc/builder.rs @@ -810,7 +810,18 @@ pub fn rearrange_list_clauses(clauses: Vec) -> Vec { .into_iter() .enumerate() .sorted_by(|(index1, clause1), (index2, clause2)| { - let clause1_len = match &clause1.pattern { + let mut clause_pattern1 = &clause1.pattern; + let mut clause_pattern2 = &clause2.pattern; + + if let Pattern::Assign { pattern, .. } = clause_pattern1 { + clause_pattern1 = pattern; + } + + if let Pattern::Assign { pattern, .. } = clause_pattern2 { + clause_pattern2 = pattern; + } + + let clause1_len = match clause_pattern1 { Pattern::List { elements, tail, .. } => { Some(elements.len() + usize::from(tail.is_some() && clause1.guard.is_none())) } @@ -818,7 +829,7 @@ pub fn rearrange_list_clauses(clauses: Vec) -> Vec { _ => None, }; - let clause2_len = match &clause2.pattern { + let clause2_len = match clause_pattern2 { Pattern::List { elements, tail, .. } => { Some(elements.len() + usize::from(tail.is_some() && clause2.guard.is_none())) } @@ -1013,6 +1024,19 @@ pub fn rearrange_list_clauses(clauses: Vec) -> Vec { final_clauses } +pub fn find_list_clause_or_default_first(clauses: &[TypedClause]) -> &TypedClause { + assert!(!clauses.is_empty()); + + clauses + .iter() + .find(|clause| match &clause.pattern { + Pattern::List { .. } => true, + Pattern::Assign { pattern, .. } if matches!(&**pattern, Pattern::List { .. }) => true, + _ => false, + }) + .unwrap_or(&clauses[0]) +} + pub fn convert_data_to_type(term: Term, field_type: &Arc) -> Term { if field_type.is_int() { Term::un_i_data().apply(term) diff --git a/crates/aiken-lang/src/gen_uplc/tree.rs b/crates/aiken-lang/src/gen_uplc/tree.rs index d0912f49..2b7aa06a 100644 --- a/crates/aiken-lang/src/gen_uplc/tree.rs +++ b/crates/aiken-lang/src/gen_uplc/tree.rs @@ -282,7 +282,7 @@ pub enum AirExpression { ListClause { subject_tipo: Arc, tail_name: String, - next_tail_name: Option, + next_tail_name: Option<(String, String)>, complex_clause: bool, then: Box, otherwise: Box, @@ -546,7 +546,7 @@ impl AirTree { subject_tipo: Arc, then: AirTree, otherwise: AirTree, - next_tail_name: Option, + next_tail_name: Option<(String, String)>, complex_clause: bool, ) -> AirTree { AirTree::Expression(AirExpression::ListClause { diff --git a/crates/aiken-project/src/tests/gen_uplc.rs b/crates/aiken-project/src/tests/gen_uplc.rs index 030fb244..367ef87f 100644 --- a/crates/aiken-project/src/tests/gen_uplc.rs +++ b/crates/aiken-project/src/tests/gen_uplc.rs @@ -4523,3 +4523,413 @@ fn list_clause_with_guard() { false, ); } + +#[test] +fn list_clause_with_guard2() { + let src = r#" + fn do_init(self: List) -> List { + when self is { + [] -> fail @"unreachable" + [_] -> + [] + [a, x] -> { + [a] + } + [a] if a > 10 -> [] + [a, b, ..c] -> { + c + } + } + } + + test init_3() { + do_init([1, 3]) == [1] + } + "#; + + assert_uplc( + src, + Term::equals_data() + .apply( + Term::list_data().apply( + Term::var("do_init") + .lambda("do_init") + .apply( + Term::var("self") + .delayed_choose_list( + Term::Error.trace(Term::string("unreachable")), + Term::var("tail_1") + .delayed_choose_list( + Term::empty_list(), + Term::var("tail_1") + .choose_list( + Term::var("clause_guard") + .if_else( + Term::empty_list().delay(), + Term::var("clauses_delayed"), + ) + .force() + .lambda("clause_guard") + .apply( + Term::less_than_integer() + .apply(Term::integer(10.into())) + .apply(Term::var("a")), + ) + .lambda("a") + .apply( + Term::un_i_data().apply( + Term::head_list() + .apply(Term::var("self")), + ), + ) + .delay(), + Term::var("clauses_delayed"), + ) + .force() + .lambda("clauses_delayed") + .apply( + Term::var("tail_2") + .delayed_choose_list( + Term::mk_cons() + .apply( + Term::i_data() + .apply(Term::var("a")), + ) + .apply(Term::empty_list()) + .lambda("x") + .apply( + Term::un_i_data().apply( + Term::head_list().apply( + Term::var("tail_1"), + ), + ), + ) + .lambda("a") + .apply( + Term::un_i_data().apply( + Term::head_list().apply( + Term::var("self"), + ), + ), + ), + Term::var("c").lambda("c").apply( + Term::tail_list() + .apply(Term::var("tail_1")) + .lambda("b") + .apply(Term::un_i_data().apply( + Term::head_list().apply( + Term::var("tail_1"), + ), + )) + .lambda("a") + .apply( + Term::un_i_data().apply( + Term::head_list() + .apply(Term::var( + "self", + )), + ), + ), + ), + ) + .lambda("tail_2") + .apply( + Term::tail_list() + .apply(Term::var("tail_1")), + ) + .delay(), + ), + ) + .lambda("tail_1") + .apply(Term::tail_list().apply(Term::var("self"))), + ) + .lambda("self"), + ) + .apply(Term::list_values(vec![ + Constant::Data(Data::integer(1.into())), + Constant::Data(Data::integer(3.into())), + ])), + ), + ) + .apply(Term::data(Data::list(vec![Data::integer(1.into())]))), + false, + ); +} + +#[test] +fn list_clause_with_guard3() { + let src = r#" + fn do_init(self: List) -> List { + when self is { + [] -> fail @"unreachable" + [_] -> + [] + [a, x] -> { + [a] + } + [a, ..g] if a > 10 -> g + [a, b, ..c] -> { + c + } + } + } + + test init_3() { + do_init([1, 3]) == [1] + } + "#; + + assert_uplc( + src, + Term::equals_data() + .apply( + Term::list_data().apply( + Term::var("do_init") + .lambda("do_init") + .apply( + Term::var("self") + .delayed_choose_list( + Term::Error.trace(Term::string("unreachable")), + Term::var("tail_1") + .delayed_choose_list( + Term::empty_list(), + Term::var("self") + .choose_list( + Term::var("clause_guard") + .if_else( + Term::var("g").delay(), + Term::var("clauses_delayed"), + ) + .force() + .lambda("clause_guard") + .apply( + Term::less_than_integer() + .apply(Term::integer(10.into())) + .apply(Term::var("a")), + ) + .lambda("g") + .apply( + Term::tail_list() + .apply(Term::var("self")), + ) + .lambda("a") + .apply( + Term::un_i_data().apply( + Term::head_list() + .apply(Term::var("self")), + ), + ) + .delay(), + Term::var("clauses_delayed"), + ) + .force() + .lambda("clauses_delayed") + .apply( + Term::var("tail_2") + .delayed_choose_list( + Term::mk_cons() + .apply( + Term::i_data() + .apply(Term::var("a")), + ) + .apply(Term::empty_list()) + .lambda("x") + .apply( + Term::un_i_data().apply( + Term::head_list().apply( + Term::var("tail_1"), + ), + ), + ) + .lambda("a") + .apply( + Term::un_i_data().apply( + Term::head_list().apply( + Term::var("self"), + ), + ), + ), + Term::var("c").lambda("c").apply( + Term::tail_list() + .apply(Term::var("tail_1")) + .lambda("b") + .apply(Term::un_i_data().apply( + Term::head_list().apply( + Term::var("tail_1"), + ), + )) + .lambda("a") + .apply( + Term::un_i_data().apply( + Term::head_list() + .apply(Term::var( + "self", + )), + ), + ), + ), + ) + .lambda("tail_2") + .apply( + Term::tail_list() + .apply(Term::var("tail_1")), + ) + .delay(), + ), + ) + .lambda("tail_1") + .apply(Term::tail_list().apply(Term::var("self"))), + ) + .lambda("self"), + ) + .apply(Term::list_values(vec![ + Constant::Data(Data::integer(1.into())), + Constant::Data(Data::integer(3.into())), + ])), + ), + ) + .apply(Term::data(Data::list(vec![Data::integer(1.into())]))), + false, + ); +} + +#[test] +fn list_clause_with_assign() { + let src = r#" + fn do_init(self: List) -> List { + when self is { + [] -> fail @"unreachable" + [_] as a -> + a + [a, x] if x > 2 -> { + [a] + } + [a, x] -> { + [a] + } + [a, b, ..c] -> { + c + } + } + } + + test init_3() { + do_init([1, 3]) == [1] + } + "#; + + assert_uplc( + src, + Term::equals_data() + .apply( + Term::list_data().apply( + Term::var("do_init") + .lambda("do_init") + .apply( + Term::var("self") + .delayed_choose_list( + Term::Error.trace(Term::string("unreachable")), + Term::var("tail_1") + .delayed_choose_list( + Term::empty_list(), + Term::var("tail_2") + .choose_list( + Term::var("clause_guard") + .if_else( + Term::mk_cons() + .apply( + Term::i_data() + .apply(Term::var("a")), + ) + .apply(Term::empty_list()) + .delay(), + Term::var("clauses_delayed"), + ) + .force() + .lambda("clause_guard") + .apply( + Term::less_than_integer() + .apply(Term::integer(2.into())) + .apply(Term::var("x")), + ) + .lambda("x") + .apply( + Term::un_i_data().apply( + Term::head_list() + .apply(Term::var("tail_1")), + ), + ) + .lambda("a") + .apply( + Term::un_i_data().apply( + Term::head_list() + .apply(Term::var("self")), + ), + ) + .delay(), + Term::var("clauses_delayed"), + ) + .force() + .lambda("clauses_delayed") + .apply( + Term::var("tail_2") + .delayed_choose_list( + Term::empty_list() + .lambda("b") + .apply( + Term::un_i_data().apply( + Term::head_list().apply( + Term::var("tail_1"), + ), + ), + ) + .lambda("a") + .apply( + Term::un_i_data().apply( + Term::head_list().apply( + Term::var("self"), + ), + ), + ), + Term::var("c").lambda("c").apply( + Term::tail_list() + .apply(Term::var("tail_1")) + .lambda("b") + .apply(Term::un_i_data().apply( + Term::head_list().apply( + Term::var("tail_1"), + ), + )) + .lambda("a") + .apply( + Term::un_i_data().apply( + Term::head_list() + .apply(Term::var( + "self", + )), + ), + ), + ), + ) + .delay(), + ) + .lambda("tail_2") + .apply( + Term::tail_list().apply(Term::var("tail_1")), + ), + ) + .lambda("tail_1") + .apply(Term::tail_list().apply(Term::var("self"))), + ) + .lambda("self"), + ) + .apply(Term::list_values(vec![ + Constant::Data(Data::integer(1.into())), + Constant::Data(Data::integer(3.into())), + ])), + ), + ) + .apply(Term::data(Data::list(vec![Data::integer(1.into())]))), + false, + ); +}