From a3f7b48ec39c5cffd3c8fae92d43055362cf068e Mon Sep 17 00:00:00 2001 From: KtorZ Date: Sat, 23 Mar 2024 11:56:38 +0100 Subject: [PATCH] Allow downcasting to data in piped function calls. We have been a bit too strict on disallowing 'allow_cast' propagations. This is really only problematic for nested elements like Tuple's elements or App's args. However, for linked and unbound var it is probably okay, and it certainly is as well for function arguments. --- crates/aiken-lang/src/tests/check.rs | 34 +++++++++++++++++++++++ crates/aiken-lang/src/tipo/environment.rs | 6 ++-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/crates/aiken-lang/src/tests/check.rs b/crates/aiken-lang/src/tests/check.rs index ff599dc8..b5265c1b 100644 --- a/crates/aiken-lang/src/tests/check.rs +++ b/crates/aiken-lang/src/tests/check.rs @@ -2131,6 +2131,40 @@ fn can_down_cast_to_data_always() { assert!(check(parse(source_code)).is_ok()); } +#[test] +fn can_down_cast_to_data_on_fn_call() { + let source_code = r#" + pub type Foo { Foo } + + pub fn serialise(data: Data) -> ByteArray { + "" + } + + test foo() { + serialise(Foo) == "" + } + "#; + + assert!(check(parse(source_code)).is_ok()); +} + +#[test] +fn can_down_cast_to_data_on_pipe() { + let source_code = r#" + pub type Foo { Foo } + + pub fn serialise(data: Data) -> ByteArray { + "" + } + + test foo() { + (Foo |> serialise) == "" + } + "#; + + assert!(check(parse(source_code)).is_ok()); +} + #[test] fn correct_span_for_backpassing_args() { let source_code = r#" diff --git a/crates/aiken-lang/src/tipo/environment.rs b/crates/aiken-lang/src/tipo/environment.rs index 430b6754..49f46a52 100644 --- a/crates/aiken-lang/src/tipo/environment.rs +++ b/crates/aiken-lang/src/tipo/environment.rs @@ -1431,7 +1431,7 @@ impl<'a> Environment<'a> { lhs, Type::with_alias(tipo.clone(), alias.clone()), location, - false, + allow_cast, ); } } @@ -1470,7 +1470,7 @@ impl<'a> Environment<'a> { Ok(()) } - Action::Unify(t) => self.unify(t, rhs, location, false), + Action::Unify(t) => self.unify(t, rhs, location, allow_cast), Action::CouldNotUnify => Err(Error::CouldNotUnify { location, @@ -1550,7 +1550,7 @@ impl<'a> Environment<'a> { }, ) if args1.len() == args2.len() => { for (a, b) in args1.iter().zip(args2) { - self.unify(a.clone(), b.clone(), location, false) + self.unify(a.clone(), b.clone(), location, allow_cast) .map_err(|_| Error::CouldNotUnify { location, expected: lhs.clone(),