diff --git a/crates/aiken-lang/src/ast.rs b/crates/aiken-lang/src/ast.rs index b326d123..1760ec53 100644 --- a/crates/aiken-lang/src/ast.rs +++ b/crates/aiken-lang/src/ast.rs @@ -1479,13 +1479,17 @@ impl AssignmentKind { } } -impl AssignmentKind<()> { +impl AssignmentKind { pub fn let_() -> Self { - AssignmentKind::Let { backpassing: () } + AssignmentKind::Let { + backpassing: Default::default(), + } } pub fn expect() -> Self { - AssignmentKind::Expect { backpassing: () } + AssignmentKind::Expect { + backpassing: Default::default(), + } } } diff --git a/crates/aiken-lang/src/tests/check.rs b/crates/aiken-lang/src/tests/check.rs index 89184c28..11a5c861 100644 --- a/crates/aiken-lang/src/tests/check.rs +++ b/crates/aiken-lang/src/tests/check.rs @@ -1205,6 +1205,46 @@ fn backpassing_basic() { assert!(check(parse(source_code)).is_ok()) } +#[test] +fn backpassing_expect_simple() { + let source_code = r#" + fn and_then(opt: Option, then: fn(a) -> Option) -> Option { + when opt is { + None -> None + Some(a) -> then(a) + } + } + + fn backpassing(opt_i: Option, opt_j: Option) -> Option { + expect 42 <- and_then(opt_i) + let j <- and_then(opt_j) + Some(j + 42) + } + "#; + + assert!(check(parse(source_code)).is_ok()) +} + +#[test] +fn backpassing_expect_nested() { + let source_code = r#" + fn and_then(opt: Option, then: fn(Option) -> Option) -> Option { + when opt is { + None -> None + Some(a) -> then(Some(a)) + } + } + + fn backpassing(opt_i: Option, opt_j: Option) -> Option { + expect Some(i) <- and_then(opt_i) + expect Some(j) <- and_then(opt_j) + Some(i + j) + } + "#; + + assert!(check(parse(source_code)).is_ok()) +} + #[test] fn backpassing_interleaved_capture() { let source_code = r#" @@ -1320,6 +1360,29 @@ fn backpassing_unsaturated_fn() { )) } +#[test] +fn backpassing_expect_type_mismatch() { + let source_code = r#" + fn and_then(opt: Option, then: fn(a) -> Option) -> Option { + when opt is { + None -> None + Some(a) -> then(a) + } + } + + fn backpassing(opt_i: Option, opt_j: Option) -> Option { + expect Some(i) <- and_then(opt_i) + let j <- and_then(opt_j) + Some(i + j) + } + "#; + + assert!(matches!( + check(parse(source_code)), + Err((_, Error::CouldNotUnify { .. })) + )) +} + #[test] fn trace_if_false_ko() { let source_code = r#" diff --git a/crates/aiken-lang/src/tipo/expr.rs b/crates/aiken-lang/src/tipo/expr.rs index 166e1b54..5fa3dcae 100644 --- a/crates/aiken-lang/src/tipo/expr.rs +++ b/crates/aiken-lang/src/tipo/expr.rs @@ -1708,14 +1708,15 @@ impl<'a, 'b> ExprTyper<'a, 'b> { } fn backpass(&mut self, breakpoint: UntypedExpr, continuation: Vec) -> UntypedExpr { - let (assign_location, value, pattern, annotation) = match breakpoint { + let (value, pattern, annotation, kind, assign_location) = match breakpoint { UntypedExpr::Assignment { location, value, pattern, annotation, + kind, .. - } => (location, value, pattern, annotation), + } => (value, pattern, annotation, kind, location), _ => unreachable!("backpass misuse: breakpoint isn't an Assignment ?!"), }; @@ -1723,19 +1724,23 @@ impl<'a, 'b> ExprTyper<'a, 'b> { // in front of the continuation sequence. This is because we do not support patterns in function argument // (which is perhaps something we should support?). let (name, continuation) = match pattern { - Pattern::Var { name, .. } | Pattern::Discard { name, .. } => { + Pattern::Var { name, .. } | Pattern::Discard { name, .. } if kind.is_let() => { (name.clone(), continuation) } _ => { let mut with_assignment = vec![UntypedExpr::Assignment { location: assign_location, value: UntypedExpr::Var { - location: assign_location, + location: value_location, name: ast::BACKPASS_VARIABLE.to_string(), } .into(), pattern, - kind: AssignmentKind::Let { backpassing: false }, + // Erase backpassing while preserving assignment kind. + kind: match kind { + AssignmentKind::Let { .. } => AssignmentKind::let_(), + AssignmentKind::Expect { .. } => AssignmentKind::expect(), + }, annotation, }]; with_assignment.extend(continuation); @@ -2126,7 +2131,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { location: Span::empty(), value: Box::new(subject.clone()), pattern: clauses[0].patterns[0].clone(), - kind: AssignmentKind::Let { backpassing: false }, + kind: AssignmentKind::let_(), annotation: None, }, });