From d7e4aef4c530063f32d72e7998280846d14a5992 Mon Sep 17 00:00:00 2001 From: Kasey White Date: Tue, 17 Jan 2023 03:43:47 -0500 Subject: [PATCH] feat: Add boolean conditions to when statements --- crates/aiken-lang/src/air.rs | 6 + crates/aiken-lang/src/uplc.rs | 367 ++++++++++++++++++++-------------- 2 files changed, 218 insertions(+), 155 deletions(-) diff --git a/crates/aiken-lang/src/air.rs b/crates/aiken-lang/src/air.rs index 3ecb4648..52791762 100644 --- a/crates/aiken-lang/src/air.rs +++ b/crates/aiken-lang/src/air.rs @@ -24,6 +24,11 @@ pub enum Air { bytes: Vec, }, + Bool { + scope: Vec, + value: bool, + }, + Var { scope: Vec, constructor: ValueConstructor, @@ -228,6 +233,7 @@ impl Air { Air::Int { scope, .. } | Air::String { scope, .. } | Air::ByteArray { scope, .. } + | Air::Bool { scope, .. } | Air::Var { scope, .. } | Air::List { scope, .. } | Air::ListAccessor { scope, .. } diff --git a/crates/aiken-lang/src/uplc.rs b/crates/aiken-lang/src/uplc.rs index aa09ae4e..d9fa9a43 100644 --- a/crates/aiken-lang/src/uplc.rs +++ b/crates/aiken-lang/src/uplc.rs @@ -838,61 +838,68 @@ impl<'a> CodeGenerator<'a> { let mut temp_clause_properties = clause_properties.clone(); *temp_clause_properties.needs_constr_var() = false; - for arg in arguments { - check_when_pattern_needs(&arg.value, &mut temp_clause_properties); - } - - // find data type definition - let data_type = lookup_data_type_by_tipo(self.data_types.clone(), tipo).unwrap(); - - let (index, _) = data_type - .constructors - .iter() - .enumerate() - .find(|(_, dt)| &dt.name == constr_name) - .unwrap(); - - let mut new_vec = vec![Air::Var { - constructor: ValueConstructor::public( - tipo.clone().into(), - ValueConstructorVariant::LocalVariable { - location: Span::empty(), - }, - ), - name: temp_clause_properties.clause_var_name().clone(), - scope: scope.clone(), - variant_name: String::new(), - }]; - - // if only one constructor, no need to check - if data_type.constructors.len() > 1 { - // push constructor Index - pattern_vec.push(Air::Int { - value: index.to_string(), - scope: scope.clone(), + if tipo.is_bool() { + pattern_vec.push(Air::Bool { + scope, + value: constr_name == "True", }); - } - - if *temp_clause_properties.needs_constr_var() { - self.when_recursive_ir( - pattern, - pattern_vec, - &mut new_vec, - clause_properties, - tipo, - scope, - ); } else { - self.when_recursive_ir( - pattern, - pattern_vec, - &mut vec![], - clause_properties, - tipo, - scope, - ); - } + for arg in arguments { + check_when_pattern_needs(&arg.value, &mut temp_clause_properties); + } + // find data type definition + let data_type = + lookup_data_type_by_tipo(self.data_types.clone(), tipo).unwrap(); + + let (index, _) = data_type + .constructors + .iter() + .enumerate() + .find(|(_, dt)| &dt.name == constr_name) + .unwrap(); + + let mut new_vec = vec![Air::Var { + constructor: ValueConstructor::public( + tipo.clone().into(), + ValueConstructorVariant::LocalVariable { + location: Span::empty(), + }, + ), + name: temp_clause_properties.clause_var_name().clone(), + scope: scope.clone(), + variant_name: String::new(), + }]; + + // if only one constructor, no need to check + if data_type.constructors.len() > 1 { + // push constructor Index + pattern_vec.push(Air::Int { + value: index.to_string(), + scope: scope.clone(), + }); + } + + if *temp_clause_properties.needs_constr_var() { + self.when_recursive_ir( + pattern, + pattern_vec, + &mut new_vec, + clause_properties, + tipo, + scope, + ); + } else { + self.when_recursive_ir( + pattern, + pattern_vec, + &mut vec![], + clause_properties, + tipo, + scope, + ); + } + } pattern_vec.append(values); // unify clause properties @@ -2684,6 +2691,10 @@ impl<'a> CodeGenerator<'a> { let term = Term::Constant(UplcConstant::ByteString(bytes)); arg_stack.push(term); } + Air::Bool { value, .. } => { + let term = Term::Constant(UplcConstant::Bool(value)); + arg_stack.push(term); + } Air::Var { name, constructor, @@ -3626,66 +3637,86 @@ impl<'a> CodeGenerator<'a> { // the next branch in the when expression let mut term = arg_stack.pop().unwrap(); - let checker = if tipo.is_int() { - apply_wrap( - DefaultFunction::EqualsInteger.into(), - Term::Var(Name { - text: subject_name, - unique: 0.into(), - }), - ) - } else if tipo.is_bytearray() { - apply_wrap( - DefaultFunction::EqualsByteString.into(), - Term::Var(Name { - text: subject_name, - unique: 0.into(), - }), - ) - } else if tipo.is_bool() { - todo!("Bool in when statements not done yet") - } else if tipo.is_string() { - apply_wrap( - DefaultFunction::EqualsString.into(), - Term::Var(Name { - text: subject_name, - unique: 0.into(), - }), - ) - } else if tipo.is_list() || tipo.is_tuple() { - unreachable!() - } else { - apply_wrap( - DefaultFunction::EqualsInteger.into(), - Term::Var(Name { - text: subject_name, - unique: 0.into(), - }), - ) - }; - - if complex_clause { - term = apply_wrap( - Term::Lambda { - parameter_name: Name { - text: "__other_clauses_delayed".to_string(), + if tipo.is_bool() { + if matches!(clause, Term::Constant(UplcConstant::Bool(true))) { + term = delayed_if_else( + Term::Var(Name { + text: subject_name, unique: 0.into(), - }, - body: if_else( - apply_wrap(checker, clause), - Term::Delay(body.into()), - Term::Var(Name { + }), + body, + term, + ); + } else { + term = delayed_if_else( + Term::Var(Name { + text: subject_name, + unique: 0.into(), + }), + term, + body, + ); + } + } else { + let checker = if tipo.is_int() { + apply_wrap( + DefaultFunction::EqualsInteger.into(), + Term::Var(Name { + text: subject_name, + unique: 0.into(), + }), + ) + } else if tipo.is_bytearray() { + apply_wrap( + DefaultFunction::EqualsByteString.into(), + Term::Var(Name { + text: subject_name, + unique: 0.into(), + }), + ) + } else if tipo.is_string() { + apply_wrap( + DefaultFunction::EqualsString.into(), + Term::Var(Name { + text: subject_name, + unique: 0.into(), + }), + ) + } else if tipo.is_list() || tipo.is_tuple() { + unreachable!() + } else { + apply_wrap( + DefaultFunction::EqualsInteger.into(), + Term::Var(Name { + text: subject_name, + unique: 0.into(), + }), + ) + }; + + if complex_clause { + term = apply_wrap( + Term::Lambda { + parameter_name: Name { text: "__other_clauses_delayed".to_string(), unique: 0.into(), - }), - ) - .force_wrap() - .into(), - }, - Term::Delay(term.into()), - ); - } else { - term = delayed_if_else(apply_wrap(checker, clause), body, term); + }, + body: if_else( + apply_wrap(checker, clause), + Term::Delay(body.into()), + Term::Var(Name { + text: "__other_clauses_delayed".to_string(), + unique: 0.into(), + }), + ) + .force_wrap() + .into(), + }, + Term::Delay(term.into()), + ); + } else { + term = delayed_if_else(apply_wrap(checker, clause), body, term); + } } arg_stack.push(term); @@ -3767,55 +3798,81 @@ impl<'a> CodeGenerator<'a> { let then = arg_stack.pop().unwrap(); - let checker = if tipo.is_int() { - apply_wrap( - DefaultFunction::EqualsInteger.into(), - Term::Var(Name { - text: subject_name, - unique: 0.into(), - }), - ) - } else if tipo.is_bytearray() { - apply_wrap( - DefaultFunction::EqualsByteString.into(), - Term::Var(Name { - text: subject_name, - unique: 0.into(), - }), - ) - } else if tipo.is_bool() { - todo!("Nested bool usage in when statements not yet implemented") - } else if tipo.is_string() { - apply_wrap( - DefaultFunction::EqualsString.into(), - Term::Var(Name { - text: subject_name, - unique: 0.into(), - }), - ) - } else if tipo.is_list() || tipo.is_tuple() { - unreachable!() - } else { - apply_wrap( - DefaultFunction::EqualsInteger.into(), - constr_index_exposer(Term::Var(Name { - text: subject_name, - unique: 0.into(), - })), - ) - }; - - let term = if_else( - apply_wrap(checker, condition), - Term::Delay(then.into()), - Term::Var(Name { + if tipo.is_bool() { + let mut term = Term::Var(Name { text: "__other_clauses_delayed".to_string(), unique: 0.into(), - }), - ) - .force_wrap(); + }); + if matches!(condition, Term::Constant(UplcConstant::Bool(true))) { + term = if_else( + Term::Var(Name { + text: subject_name, + unique: 0.into(), + }), + Term::Delay(then.into()), + term, + ) + .force_wrap(); + } else { + term = if_else( + Term::Var(Name { + text: subject_name, + unique: 0.into(), + }), + term, + Term::Delay(then.into()), + ) + .force_wrap(); + } + arg_stack.push(term); + } else { + let checker = if tipo.is_int() { + apply_wrap( + DefaultFunction::EqualsInteger.into(), + Term::Var(Name { + text: subject_name, + unique: 0.into(), + }), + ) + } else if tipo.is_bytearray() { + apply_wrap( + DefaultFunction::EqualsByteString.into(), + Term::Var(Name { + text: subject_name, + unique: 0.into(), + }), + ) + } else if tipo.is_string() { + apply_wrap( + DefaultFunction::EqualsString.into(), + Term::Var(Name { + text: subject_name, + unique: 0.into(), + }), + ) + } else if tipo.is_list() || tipo.is_tuple() { + unreachable!() + } else { + apply_wrap( + DefaultFunction::EqualsInteger.into(), + constr_index_exposer(Term::Var(Name { + text: subject_name, + unique: 0.into(), + })), + ) + }; - arg_stack.push(term); + let term = if_else( + apply_wrap(checker, condition), + Term::Delay(then.into()), + Term::Var(Name { + text: "__other_clauses_delayed".to_string(), + unique: 0.into(), + }), + ) + .force_wrap(); + arg_stack.push(term); + } } Air::ListClauseGuard { tail_name,