fix: error when using nested boolean checks in when conditions

This commit is contained in:
microproofs 2023-06-06 03:04:12 -04:00
parent bfd2a50a6b
commit 5faa925aea
2 changed files with 185 additions and 20 deletions

View File

@ -1078,7 +1078,7 @@ impl<'a> CodeGenerator<'a> {
); );
// if only one constructor, no need to check // if only one constructor, no need to check
if data_type.constructors.len() > 1 { if data_type.constructors.len() > 1 || *clause_properties.is_final_clause() {
// push constructor Index // push constructor Index
let mut tag_stack = pattern_stack.empty_with_scope(); let mut tag_stack = pattern_stack.empty_with_scope();
tag_stack.integer(index.to_string()); tag_stack.integer(index.to_string());
@ -1527,7 +1527,6 @@ impl<'a> CodeGenerator<'a> {
} => { } => {
let id = self.id_gen.next(); let id = self.id_gen.next();
let constr_var_name = format!("{constr_name}_{id}"); let constr_var_name = format!("{constr_name}_{id}");
let data_type = builder::lookup_data_type_by_tipo(&self.data_types, tipo).unwrap();
let mut when_stack = pattern_stack.empty_with_scope(); let mut when_stack = pattern_stack.empty_with_scope();
@ -1547,10 +1546,11 @@ impl<'a> CodeGenerator<'a> {
&mut clause_properties, &mut clause_properties,
); );
if data_type.constructors.len() > 1 { let data_type = builder::lookup_data_type_by_tipo(&self.data_types, tipo);
if final_clause { if final_clause {
pattern_stack.finally(when_stack); pattern_stack.finally(when_stack);
} else { } else if let Some(data_type) = data_type {
if data_type.constructors.len() > 1 {
let empty_stack = pattern_stack.empty_with_scope(); let empty_stack = pattern_stack.empty_with_scope();
pattern_stack.clause_guard( pattern_stack.clause_guard(
constr_var_name.clone(), constr_var_name.clone(),
@ -1558,10 +1558,18 @@ impl<'a> CodeGenerator<'a> {
when_stack, when_stack,
empty_stack, empty_stack,
); );
}
} else { } else {
pattern_stack.merge_child(when_stack); pattern_stack.merge_child(when_stack);
} }
} else {
let empty_stack = pattern_stack.empty_with_scope();
pattern_stack.clause_guard(
constr_var_name.clone(),
tipo.clone(),
when_stack,
empty_stack,
)
}
Some(constr_var_name) Some(constr_var_name)
} }

View File

@ -1,6 +1,24 @@
use aiken/interval.{Finite, Interval, IntervalBound}
use aiken/list use aiken/list
use aiken/time.{PosixTime}
use aiken/transaction.{ScriptContext, ValidityRange}
use aiken/transaction/value.{Value} use aiken/transaction/value.{Value}
// TODO added to the stdlib in #40
pub fn count(self: List<a>, predicate: fn(a) -> Bool) -> Int {
list.foldl(
self,
fn(item, total) {
if predicate(item) {
total + 1
} else {
total
}
},
0,
)
}
test foldl_value_test1() { test foldl_value_test1() {
let val1 = value.from_lovelace(1000000) let val1 = value.from_lovelace(1000000)
let val2 = value.from_lovelace(2000000) let val2 = value.from_lovelace(2000000)
@ -14,16 +32,155 @@ test foldl_value_test1() {
2, 2,
) )
} }
// test foldl_value_test2() {
// let val1 = value.from_lovelace(1000000) test foldl_value_test2() {
// let val2 = value.from_lovelace(2000000) let val1 = value.from_lovelace(1000000)
// let foo = let val2 = value.from_lovelace(2000000)
// fn(i: Value, acc: (Value, Int)) { let foo =
// let (v, int) = acc fn(i: Value, acc: (Value, Int)) {
// (value.add(i, v), int + 1) let (v, int) = acc
// } (value.add(i, v), int + 1)
// list.foldl([val1, val2], foo, (value.from_lovelace(0), 0)) == ( }
// value.from_lovelace(3000000), list.foldl([val1, val2], foo, (value.from_lovelace(0), 0)) == (
// 2, value.from_lovelace(3000000),
// ) 2,
// } )
}
pub type NativeScript {
Signature { keyHash: ByteArray }
AllOf { scripts: List<NativeScript> }
AnyOf { scripts: List<NativeScript> }
AtLeast { required: Int, scripts: List<NativeScript> }
Before { time: PosixTime }
After { time: PosixTime }
}
pub fn satisfied(
script: NativeScript,
signatories: List<ByteArray>,
validRange: ValidityRange,
) -> Bool {
when script is {
Signature { keyHash } -> list.has(signatories, keyHash)
AllOf { scripts } ->
list.all(scripts, fn(s) { satisfied(s, signatories, validRange) })
AnyOf { scripts } ->
list.any(scripts, fn(s) { satisfied(s, signatories, validRange) })
AtLeast { required, scripts } ->
required <= count(
scripts,
fn(s) { satisfied(s, signatories, validRange) },
)
Before { time } ->
when validRange.upper_bound.bound_type is {
Finite(hi) ->
if validRange.upper_bound.is_inclusive {
hi <= time
} else {
hi < time
}
_ -> False
}
After { time } ->
when validRange.lower_bound is {
IntervalBound { bound_type: b, is_inclusive: False } -> {
expect Finite(lo) = b
time < lo
}
IntervalBound { bound_type: b, is_inclusive: True } -> {
expect Finite(lo) = b
time <= lo
}
}
}
// After { time } ->
// when validRange.lower_bound.bound_type is {
// Finite(lo) ->
// if validRange.lower_bound.is_inclusive {
// time <= lo
// } else {
// time < lo
// }
// _ -> False
// }
}
test satisfying() {
let keyHash1 = "key1"
let keyHash2 = "key2"
let keyHash3 = "key3"
let sig1 = Signature { keyHash: keyHash1 }
let sig2 = Signature { keyHash: keyHash2 }
let sig3 = Signature { keyHash: keyHash3 }
let allOf = AllOf { scripts: [sig1, sig2] }
let anyOf = AnyOf { scripts: [sig1, sig2] }
let atLeast = AtLeast { required: 2, scripts: [sig1, sig2, sig3] }
let before = Before { time: 10 }
let after = After { time: 10 }
let between = AllOf { scripts: [After { time: 10 }, Before { time: 15 }] }
let vesting =
AnyOf {
scripts: [
AllOf { scripts: [before, sig1] },
// clawback
AllOf { scripts: [after, sig2] },
],
}
// vested
let validRange =
fn(lo: Int, hi: Int) -> ValidityRange {
Interval {
lower_bound: IntervalBound {
bound_type: Finite(lo),
is_inclusive: True,
},
upper_bound: IntervalBound {
bound_type: Finite(hi),
is_inclusive: False,
},
}
}
// Helper method because ? binds more tightly than !
let unsatisfied =
fn(n: NativeScript, s: List<ByteArray>, v: ValidityRange) {
!satisfied(n, s, v)
}
list.and(
[
satisfied(sig1, [keyHash1], validRange(0, 1))?,
satisfied(sig2, [keyHash1, keyHash2], validRange(0, 1))?,
satisfied(allOf, [keyHash1, keyHash2], validRange(0, 1))?,
satisfied(anyOf, [keyHash2], validRange(0, 1))?,
satisfied(atLeast, [keyHash2, keyHash3], validRange(0, 1))?,
satisfied(before, [], validRange(0, 5))?,
satisfied(after, [], validRange(15, 20))?,
satisfied(after, [], validRange(10, 15))?,
satisfied(between, [], validRange(12, 13))?,
satisfied(vesting, [keyHash1], validRange(0, 5))?,
satisfied(vesting, [keyHash2], validRange(15, 20))?,
unsatisfied(sig1, [keyHash2], validRange(0, 1))?,
unsatisfied(sig3, [keyHash1, keyHash2], validRange(0, 1))?,
unsatisfied(allOf, [keyHash1, keyHash3], validRange(0, 1))?,
unsatisfied(anyOf, [keyHash3], validRange(0, 1))?,
unsatisfied(atLeast, [keyHash2], validRange(0, 1))?,
unsatisfied(before, [], validRange(5, 15))?,
unsatisfied(before, [], validRange(5, 10))?,
unsatisfied(before, [], validRange(10, 10))?,
unsatisfied(after, [], validRange(5, 15))?,
unsatisfied(between, [], validRange(0, 5))?,
unsatisfied(between, [], validRange(0, 13))?,
unsatisfied(between, [], validRange(0, 20))?,
unsatisfied(between, [], validRange(13, 20))?,
unsatisfied(between, [], validRange(13, 15))?,
unsatisfied(between, [], validRange(15, 20))?,
unsatisfied(vesting, [keyHash2], validRange(0, 5))?,
unsatisfied(vesting, [keyHash1], validRange(15, 20))?,
unsatisfied(vesting, [keyHash3], validRange(10, 10))?,
unsatisfied(vesting, [keyHash3], validRange(0, 5))?,
unsatisfied(vesting, [keyHash3], validRange(15, 20))?,
],
)
}