feat: impl if/is

This commit introduces a new feature into
the parser, typechecker, and formatter.
The work for code gen will be in the next commit.

I was able to leverage some existing infrastructure
by making using of `AssignmentPattern`. A new field
`is` was introduced into `IfBranch`. This field holds
a generic `Option<Is>` meaning a new generic has to be
introduced into `IfBranch`. When used in `UntypedExpr`,
`IfBranch` must use `AssignmentPattern`. When used in
`TypedExpr`, `IfBranch` must use `TypedPattern`.

The parser was updated such that we can support this
kind of psuedo grammar:

`if <expr:condition> [is [<pattern>: ]<annotation>]`

This can be read as, when parsing an `if` expression,
always expect an expression after the keyword `if`. And then
optionally there may be this `is` stuff, and within that you
may optionally expect a pattern followed by a colon. We will
always expect an annotation.

This first expression is still saved as the field
`condition` in `IfBranch`. If `pattern` is not there
AND `expr:condition` is `UntypedExpr::Var` we can set
the pattern to be `Pattern::Var` with the same name. From
there shadowing should allow this syntax sugar to feel
kinda magical within the `IfBranch` block that follow.

The typechecker doesn't need to be aware of the sugar
described above. The typechecker looks at `branch.is`
and if it's `Some(is)` then it'll use `infer_assignment`
for some help. Because of the way that `is` can inject
variables into the scope of the branch's block and since
it's basically just like how `expect` works minus the error
we get to re-use that helper method.

It's important to note that in the typechecker, if `is`
is `Some(_)` then we do not enforce that `condition` is
of type `Bool`. This is because the bool itself will be
whether or not the `is` itself holds true given a PlutusData
payload.

When `is` is None, we do exactly what was being done
previously so that plain `if` expressions remain unaffected
with no semantic changes.

The formatter had to be made aware of the new changes with
some simple changes that need no further explanation.
This commit is contained in:
rvcas 2024-06-11 19:05:55 -04:00 committed by Lucas
parent b2c42febaf
commit 1b8805825b
15 changed files with 599 additions and 74 deletions

View File

@ -1885,13 +1885,14 @@ impl TypedClauseGuard {
} }
} }
pub type TypedIfBranch = IfBranch<TypedExpr>; pub type TypedIfBranch = IfBranch<TypedExpr, TypedPattern>;
pub type UntypedIfBranch = IfBranch<UntypedExpr>; pub type UntypedIfBranch = IfBranch<UntypedExpr, AssignmentPattern>;
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct IfBranch<Expr> { pub struct IfBranch<Expr, Is> {
pub condition: Expr, pub condition: Expr,
pub body: Expr, pub body: Expr,
pub is: Option<Is>,
pub location: Span, pub location: Span,
} }

View File

@ -2,10 +2,10 @@ use crate::{
ast::{ ast::{
self, Annotation, ArgBy, ArgName, AssignmentPattern, BinOp, Bls12_381Point, self, Annotation, ArgBy, ArgName, AssignmentPattern, BinOp, Bls12_381Point,
ByteArrayFormatPreference, CallArg, Curve, DataType, DataTypeKey, DefinitionLocation, ByteArrayFormatPreference, CallArg, Curve, DataType, DataTypeKey, DefinitionLocation,
IfBranch, Located, LogicalOpChainKind, ParsedCallArg, Pattern, RecordConstructorArg, Located, LogicalOpChainKind, ParsedCallArg, Pattern, RecordConstructorArg,
RecordUpdateSpread, Span, TraceKind, TypedArg, TypedAssignmentKind, TypedClause, RecordUpdateSpread, Span, TraceKind, TypedArg, TypedAssignmentKind, TypedClause,
TypedDataType, TypedRecordUpdateArg, UnOp, UntypedArg, UntypedAssignmentKind, TypedDataType, TypedIfBranch, TypedRecordUpdateArg, UnOp, UntypedArg,
UntypedClause, UntypedRecordUpdateArg, UntypedAssignmentKind, UntypedClause, UntypedIfBranch, UntypedRecordUpdateArg,
}, },
builtins::void, builtins::void,
parser::token::Base, parser::token::Base,
@ -127,7 +127,7 @@ pub enum TypedExpr {
If { If {
location: Span, location: Span,
#[serde(with = "Vec1Ref")] #[serde(with = "Vec1Ref")]
branches: Vec1<IfBranch<Self>>, branches: Vec1<TypedIfBranch>,
final_else: Box<Self>, final_else: Box<Self>,
tipo: Rc<Type>, tipo: Rc<Type>,
}, },
@ -564,7 +564,7 @@ pub enum UntypedExpr {
If { If {
location: Span, location: Span,
branches: Vec1<IfBranch<Self>>, branches: Vec1<UntypedIfBranch>,
final_else: Box<Self>, final_else: Box<Self>,
}, },

View File

@ -2,10 +2,10 @@ use crate::{
ast::{ ast::{
Annotation, ArgBy, ArgName, ArgVia, AssignmentKind, AssignmentPattern, BinOp, Annotation, ArgBy, ArgName, ArgVia, AssignmentKind, AssignmentPattern, BinOp,
ByteArrayFormatPreference, CallArg, ClauseGuard, Constant, CurveType, DataType, Definition, ByteArrayFormatPreference, CallArg, ClauseGuard, Constant, CurveType, DataType, Definition,
Function, IfBranch, LogicalOpChainKind, ModuleConstant, OnTestFailure, Pattern, Function, LogicalOpChainKind, ModuleConstant, OnTestFailure, Pattern, RecordConstructor,
RecordConstructor, RecordConstructorArg, RecordUpdateSpread, Span, TraceKind, TypeAlias, RecordConstructorArg, RecordUpdateSpread, Span, TraceKind, TypeAlias, TypedArg, UnOp,
TypedArg, UnOp, UnqualifiedImport, UntypedArg, UntypedArgVia, UntypedAssignmentKind, UnqualifiedImport, UntypedArg, UntypedArgVia, UntypedAssignmentKind, UntypedClause,
UntypedClause, UntypedClauseGuard, UntypedDefinition, UntypedFunction, UntypedModule, UntypedClauseGuard, UntypedDefinition, UntypedFunction, UntypedIfBranch, UntypedModule,
UntypedPattern, UntypedRecordUpdateArg, Use, Validator, CAPTURE_VARIABLE, UntypedPattern, UntypedRecordUpdateArg, Use, Validator, CAPTURE_VARIABLE,
}, },
docvec, docvec,
@ -1195,7 +1195,7 @@ impl<'comments> Formatter<'comments> {
pub fn if_expr<'a>( pub fn if_expr<'a>(
&mut self, &mut self,
branches: &'a Vec1<IfBranch<UntypedExpr>>, branches: &'a Vec1<UntypedIfBranch>,
final_else: &'a UntypedExpr, final_else: &'a UntypedExpr,
) -> Document<'a> { ) -> Document<'a> {
let if_branches = self let if_branches = self
@ -1223,10 +1223,44 @@ impl<'comments> Formatter<'comments> {
pub fn if_branch<'a>( pub fn if_branch<'a>(
&mut self, &mut self,
if_keyword: Document<'a>, if_keyword: Document<'a>,
branch: &'a IfBranch<UntypedExpr>, branch: &'a UntypedIfBranch,
) -> Document<'a> { ) -> Document<'a> {
let if_begin = if_keyword let if_begin = if_keyword
.append(self.wrap_expr(&branch.condition)) .append(self.wrap_expr(&branch.condition))
.append(match &branch.is {
Some(AssignmentPattern {
pattern,
annotation,
..
}) => {
let is_sugar = matches!(
(&pattern, &branch.condition),
(
Pattern::Var { name, .. },
UntypedExpr::Var { name: var_name, .. }
) if name == var_name
);
let Some(annotation) = &annotation else {
unreachable!()
};
let is = if is_sugar {
self.annotation(annotation)
} else {
self.pattern(pattern)
.append(": ")
.append(self.annotation(annotation))
.group()
};
break_("", " ")
.append("is")
.append(break_("", " "))
.append(is)
}
None => nil(),
})
.append(break_("{", " {")) .append(break_("{", " {"))
.group(); .group();

View File

@ -33,6 +33,13 @@ pub fn let_(
} }
fn assignment_patterns() -> impl Parser<Token, Vec<ast::AssignmentPattern>, Error = ParseError> { fn assignment_patterns() -> impl Parser<Token, Vec<ast::AssignmentPattern>, Error = ParseError> {
assignment_pattern()
.separated_by(just(Token::Comma))
.allow_trailing()
.at_least(1)
}
pub fn assignment_pattern() -> impl Parser<Token, ast::AssignmentPattern, Error = ParseError> {
pattern() pattern()
.then(just(Token::Colon).ignore_then(annotation()).or_not()) .then(just(Token::Colon).ignore_then(annotation()).or_not())
.map_with_span(|(pattern, annotation), span| ast::AssignmentPattern { .map_with_span(|(pattern, annotation), span| ast::AssignmentPattern {
@ -40,9 +47,6 @@ fn assignment_patterns() -> impl Parser<Token, Vec<ast::AssignmentPattern>, Erro
annotation, annotation,
location: span, location: span,
}) })
.separated_by(just(Token::Comma))
.allow_trailing()
.at_least(1)
} }
pub fn expect( pub fn expect(

View File

@ -3,7 +3,7 @@ use chumsky::prelude::*;
use crate::{ use crate::{
ast, ast,
expr::UntypedExpr, expr::UntypedExpr,
parser::{error::ParseError, token::Token}, parser::{annotation, error::ParseError, pattern, token::Token},
}; };
use super::block; use super::block;
@ -40,11 +40,44 @@ fn if_branch<'a>(
expression: Recursive<'a, Token, UntypedExpr, ParseError>, expression: Recursive<'a, Token, UntypedExpr, ParseError>,
) -> impl Parser<Token, ast::UntypedIfBranch, Error = ParseError> + 'a { ) -> impl Parser<Token, ast::UntypedIfBranch, Error = ParseError> + 'a {
expression expression
.then(
just(Token::Is)
.ignore_then(
pattern()
.then_ignore(just(Token::Colon))
.or_not()
.then(annotation())
.map_with_span(|(pattern, annotation), span| (pattern, annotation, span)),
)
.or_not(),
)
.then(block(sequence)) .then(block(sequence))
.map_with_span(|(condition, body), span| ast::IfBranch { .map_with_span(|((condition, is), body), span| {
condition, let is = is.map(|(pattern, annotation, is_span)| {
body, let pattern = pattern.unwrap_or_else(|| match &condition {
location: span, UntypedExpr::Var { name, location } => ast::Pattern::Var {
name: name.clone(),
location: *location,
},
_ => ast::Pattern::Discard {
location: is_span,
name: "_".to_string(),
},
});
ast::AssignmentPattern {
pattern,
annotation: Some(annotation),
location: is_span,
}
});
ast::IfBranch {
condition,
body,
is,
location: span,
}
}) })
} }
@ -81,4 +114,23 @@ mod tests {
"# "#
); );
} }
#[test]
fn if_else_with_soft_cast() {
assert_expr!(
r#"
if ec1 is Some(x): Option<Int> {
ec2
} else if ec1 is Foo { foo }: Foo {
ec1
} else if ec1 is Option<Int> {
let Some(x) = ec1
x
} else {
Infinity
}
"#
);
}
} }

View File

@ -22,6 +22,7 @@ If {
location: 23..26, location: 23..26,
name: "ec2", name: "ec2",
}, },
is: None,
location: 3..28, location: 3..28,
}, },
IfBranch { IfBranch {
@ -56,6 +57,7 @@ If {
location: 60..63, location: 60..63,
name: "ec1", name: "ec1",
}, },
is: None,
location: 37..65, location: 37..65,
}, },
], ],

View File

@ -28,6 +28,7 @@ If {
}, },
}, },
}, },
is: None,
location: 3..19, location: 3..19,
}, },
IfBranch { IfBranch {
@ -53,6 +54,7 @@ If {
numeric_underscore: false, numeric_underscore: false,
}, },
}, },
is: None,
location: 28..41, location: 28..41,
}, },
], ],

View File

@ -0,0 +1,183 @@
---
source: crates/aiken-lang/src/parser/expr/if_else.rs
description: "Code:\n\nif ec1 is Some(x): Option<Int> {\n ec2\n} else if ec1 is Foo { foo }: Foo {\n ec1\n} else if ec1 is Option<Int> {\n let Some(x) = ec1\n\n x\n} else {\n Infinity\n}\n"
---
If {
location: 0..158,
branches: [
IfBranch {
condition: Var {
location: 3..6,
name: "ec1",
},
body: Var {
location: 35..38,
name: "ec2",
},
is: Some(
AssignmentPattern {
pattern: Constructor {
is_record: false,
location: 10..17,
name: "Some",
arguments: [
CallArg {
label: None,
location: 15..16,
value: Var {
location: 15..16,
name: "x",
},
},
],
module: None,
constructor: (),
spread_location: None,
tipo: (),
},
annotation: Some(
Constructor {
location: 19..30,
module: None,
name: "Option",
arguments: [
Constructor {
location: 26..29,
module: None,
name: "Int",
arguments: [],
},
],
},
),
location: 10..30,
},
),
location: 3..40,
},
IfBranch {
condition: Var {
location: 49..52,
name: "ec1",
},
body: Var {
location: 77..80,
name: "ec1",
},
is: Some(
AssignmentPattern {
pattern: Constructor {
is_record: true,
location: 56..67,
name: "Foo",
arguments: [
CallArg {
label: Some(
"foo",
),
location: 62..65,
value: Var {
location: 62..65,
name: "foo",
},
},
],
module: None,
constructor: (),
spread_location: None,
tipo: (),
},
annotation: Some(
Constructor {
location: 69..72,
module: None,
name: "Foo",
arguments: [],
},
),
location: 56..72,
},
),
location: 49..82,
},
IfBranch {
condition: Var {
location: 91..94,
name: "ec1",
},
body: Sequence {
location: 114..136,
expressions: [
Assignment {
location: 114..131,
value: Var {
location: 128..131,
name: "ec1",
},
patterns: [
AssignmentPattern {
pattern: Constructor {
is_record: false,
location: 118..125,
name: "Some",
arguments: [
CallArg {
label: None,
location: 123..124,
value: Var {
location: 123..124,
name: "x",
},
},
],
module: None,
constructor: (),
spread_location: None,
tipo: (),
},
annotation: None,
location: 118..125,
},
],
kind: Let {
backpassing: false,
},
},
Var {
location: 135..136,
name: "x",
},
],
},
is: Some(
AssignmentPattern {
pattern: Var {
location: 91..94,
name: "ec1",
},
annotation: Some(
Constructor {
location: 98..109,
module: None,
name: "Option",
arguments: [
Constructor {
location: 105..108,
module: None,
name: "Int",
arguments: [],
},
],
},
),
location: 98..109,
},
),
location: 91..138,
},
],
final_else: Var {
location: 148..156,
name: "Infinity",
},
}

View File

@ -198,7 +198,7 @@ fn illegal_function_comparison() {
"#; "#;
assert!(matches!( assert!(matches!(
dbg!(check_validator(parse(source_code))), check_validator(parse(source_code)),
Err((_, Error::IllegalComparison { .. })) Err((_, Error::IllegalComparison { .. }))
)) ))
} }
@ -287,7 +287,7 @@ fn illegal_unserialisable_in_generic_miller_loop() {
"#; "#;
assert!(matches!( assert!(matches!(
dbg!(check(parse(source_code))), check(parse(source_code)),
Err((_, Error::IllegalTypeInData { .. })) Err((_, Error::IllegalTypeInData { .. }))
)) ))
} }
@ -2417,7 +2417,7 @@ fn partial_eq_call_args() {
"#; "#;
assert!(matches!( assert!(matches!(
dbg!(check(parse(source_code))), check(parse(source_code)),
Err((_, Error::IncorrectFieldsArity { .. })) Err((_, Error::IncorrectFieldsArity { .. }))
)); ));
} }
@ -2435,7 +2435,7 @@ fn partial_eq_callback_args() {
"#; "#;
assert!(matches!( assert!(matches!(
dbg!(check(parse(source_code))), check(parse(source_code)),
Err((_, Error::CouldNotUnify { .. })) Err((_, Error::CouldNotUnify { .. }))
)); ));
} }
@ -2453,7 +2453,7 @@ fn partial_eq_callback_return() {
"#; "#;
assert!(matches!( assert!(matches!(
dbg!(check(parse(source_code))), check(parse(source_code)),
Err((_, Error::CouldNotUnify { .. })) Err((_, Error::CouldNotUnify { .. }))
)); ));
} }
@ -2488,7 +2488,7 @@ fn pair_index_out_of_bound() {
"#; "#;
assert!(matches!( assert!(matches!(
dbg!(check_validator(parse(source_code))), check_validator(parse(source_code)),
Err((_, Error::PairIndexOutOfBound { .. })) Err((_, Error::PairIndexOutOfBound { .. }))
)) ))
} }
@ -2502,7 +2502,7 @@ fn not_indexable() {
"#; "#;
assert!(matches!( assert!(matches!(
dbg!(check_validator(parse(source_code))), check_validator(parse(source_code)),
Err((_, Error::NotIndexable { .. })) Err((_, Error::NotIndexable { .. }))
)) ))
} }
@ -2520,7 +2520,7 @@ fn out_of_scope_access() {
"#; "#;
assert!(matches!( assert!(matches!(
dbg!(check_validator(parse(source_code))), check_validator(parse(source_code)),
Err((_, Error::UnknownVariable { .. })) Err((_, Error::UnknownVariable { .. }))
)) ))
} }
@ -2552,7 +2552,7 @@ fn fn_single_variant_pattern() {
} }
"#; "#;
assert!(dbg!(check(parse(source_code))).is_ok()); assert!(check(parse(source_code)).is_ok());
} }
#[test] #[test]
@ -2569,7 +2569,132 @@ fn fn_multi_variant_pattern() {
"#; "#;
assert!(matches!( assert!(matches!(
dbg!(check_validator(parse(source_code))), check_validator(parse(source_code)),
Err((_, Error::NotExhaustivePatternMatch { .. })) Err((_, Error::NotExhaustivePatternMatch { .. }))
)) ))
} }
#[test]
fn if_soft_cast() {
let source_code = r#"
pub type Foo {
a: Int
}
pub fn foo(foo: Data) -> Int {
if foo is bar: Foo {
bar.a
} else {
0
}
}
"#;
assert!(check(parse(source_code)).is_ok());
}
#[test]
fn if_soft_cast_sugar() {
let source_code = r#"
pub type Foo {
a: Int
}
pub fn foo(foo: Data) -> Int {
if foo is Foo {
foo.a
} else {
0
}
}
"#;
assert!(check(parse(source_code)).is_ok());
}
#[test]
fn if_soft_cast_record() {
let source_code = r#"
pub type Foo {
a: Int
}
pub fn foo(foo: Data) -> Int {
if foo is Foo { a }: Foo {
a
} else {
0
}
}
"#;
assert!(check(parse(source_code)).is_ok());
}
#[test]
fn if_soft_cast_no_scope_leak() {
let source_code = r#"
pub type Foo {
a: Int
}
pub fn foo(foo: Data) -> Int {
if foo is bar: Foo {
bar.a
} else {
bar
}
}
"#;
assert!(matches!(
check_validator(parse(source_code)),
Err((_, Error::UnknownVariable { name, .. })) if name == "bar"
))
}
#[test]
fn if_soft_cast_not_data_single_constr() {
let source_code = r#"
pub type Foo {
a: Int
}
pub fn foo(foo: Foo) -> Int {
if foo is Foo { a }: Foo {
a
} else {
0
}
}
"#;
let (warnings, _ast) = check(parse(source_code)).unwrap();
assert!(matches!(
warnings[0],
Warning::SingleConstructorExpect { .. }
))
}
#[test]
fn if_soft_cast_not_data_multi_constr() {
let source_code = r#"
pub type Foo {
Bar { a: Int }
Buzz { b: Int }
}
pub fn foo(foo: Foo) -> Int {
if foo is Bar { a }: Foo {
a
} else {
0
}
}
"#;
let (warnings, _ast) = dbg!(check(parse(source_code))).unwrap();
assert!(matches!(warnings[0], Warning::UseWhenInstead { .. }))
}

View File

@ -86,6 +86,43 @@ fn format_if() {
); );
} }
#[test]
fn format_if_soft_cast() {
assert_format!(
r#"
pub fn foo(a) {
if a is Option<Int> { 14 } else { 42 }
}
"#
);
}
#[test]
fn format_if_soft_cast_pattern() {
assert_format!(
r#"
pub fn foo(a) {
if a is Some(x): Option<Int> { 14 } else if b is Foo { b } else { 42 }
}
"#
);
}
#[test]
fn format_if_soft_cast_record() {
assert_format!(
r#"
pub fn foo(foo: Data) -> Int {
if foo is Foo { a }: Foo {
a
} else {
0
}
}
"#
);
}
#[test] #[test]
fn format_logic_op_with_code_block() { fn format_logic_op_with_code_block() {
assert_format!( assert_format!(

View File

@ -0,0 +1,11 @@
---
source: crates/aiken-lang/src/tests/format.rs
description: "Code:\n\npub fn foo(a) {\n if a is Option<Int> { 14 } else { 42 }\n }\n"
---
pub fn foo(a) {
if a is Option<Int> {
14
} else {
42
}
}

View File

@ -0,0 +1,13 @@
---
source: crates/aiken-lang/src/tests/format.rs
description: "Code:\n\npub fn foo(a) {\n if a is Some(x): Option<Int> { 14 } else if b is Foo { b } else { 42 }\n }\n"
---
pub fn foo(a) {
if a is Some(x): Option<Int> {
14
} else if b is Foo {
b
} else {
42
}
}

View File

@ -0,0 +1,11 @@
---
source: crates/aiken-lang/src/tests/format.rs
description: "Code:\n\npub fn foo(foo: Data) -> Int {\n if foo is Foo { a }: Foo {\n a\n } else {\n 0\n }\n}\n"
---
pub fn foo(foo: Data) -> Int {
if foo is Foo { a }: Foo {
a
} else {
0
}
}

View File

@ -1671,6 +1671,25 @@ pub enum Warning {
name: String, name: String,
}, },
#[error(
"I found an {} that checks an expression with a known type.",
"if/is".if_supports_color(Stderr, |s| s.purple())
)]
#[diagnostic(
code("if_is_on_non_data"),
help(
"Prefer using a {} to match on all known constructors.",
"when/is".if_supports_color(Stderr, |s| s.purple())
)
)]
UseWhenInstead {
#[label(
"use {} instead",
"when/is".if_supports_color(Stderr, |s| s.purple())
)]
location: Span,
},
#[error( #[error(
"I came across a discarded variable in a let assignment: {}", "I came across a discarded variable in a let assignment: {}",
name.if_supports_color(Stderr, |s| s.default_color()) name.if_supports_color(Stderr, |s| s.default_color())
@ -1755,7 +1774,8 @@ impl ExtraData for Warning {
| Warning::UnusedType { .. } | Warning::UnusedType { .. }
| Warning::UnusedVariable { .. } | Warning::UnusedVariable { .. }
| Warning::DiscardedLetAssignment { .. } | Warning::DiscardedLetAssignment { .. }
| Warning::ValidatorInLibraryModule { .. } => None, | Warning::ValidatorInLibraryModule { .. }
| Warning::UseWhenInstead { .. } => None,
Warning::Utf8ByteArrayIsValidHexString { value, .. } => Some(value.clone()), Warning::Utf8ByteArrayIsValidHexString { value, .. } => Some(value.clone()),
Warning::UnusedImportedModule { location, .. } => { Warning::UnusedImportedModule { location, .. } => {
Some(format!("{},{}", false, location.start)) Some(format!("{},{}", false, location.start))

View File

@ -700,6 +700,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> {
branches: vec1::vec1![IfBranch { branches: vec1::vec1![IfBranch {
condition: typed_value, condition: typed_value,
body: var_true, body: var_true,
is: None,
location, location,
}], }],
final_else: Box::new(TypedExpr::Trace { final_else: Box::new(TypedExpr::Trace {
@ -1191,7 +1192,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> {
let ann_typ = if let Some(ann) = annotation { let ann_typ = if let Some(ann) = annotation {
let ann_typ = self let ann_typ = self
.type_from_annotation(ann) .type_from_annotation(ann)
.map(|t| self.instantiate(t, &mut HashMap::new(), location))??; .and_then(|t| self.instantiate(t, &mut HashMap::new(), location))?;
self.unify( self.unify(
ann_typ.clone(), ann_typ.clone(),
@ -1695,60 +1696,33 @@ impl<'a, 'b> ExprTyper<'a, 'b> {
final_else: UntypedExpr, final_else: UntypedExpr,
location: Span, location: Span,
) -> Result<TypedExpr, Error> { ) -> Result<TypedExpr, Error> {
let first = branches.first(); let mut branches = branches.into_iter();
let first = branches.next().unwrap();
let condition = self.infer(first.condition.clone())?; let first_typed_if_branch = self.infer_if_branch(first)?;
self.unify( let first_body_type = first_typed_if_branch.body.tipo();
bool(),
condition.tipo(),
condition.type_defining_location(),
false,
)?;
assert_no_assignment(&first.body)?; let mut typed_branches = vec1::vec1![first_typed_if_branch];
let body = self.infer(first.body.clone())?;
let tipo = body.tipo(); for branch in branches {
let typed_branch = self.infer_if_branch(branch)?;
let mut typed_branches = vec1::vec1![TypedIfBranch {
body,
condition,
location: first.location,
}];
for branch in &branches[1..] {
let condition = self.infer(branch.condition.clone())?;
self.unify( self.unify(
bool(), first_body_type.clone(),
condition.tipo(), typed_branch.body.tipo(),
condition.type_defining_location(), typed_branch.body.type_defining_location(),
false, false,
)?; )?;
assert_no_assignment(&branch.body)?; typed_branches.push(typed_branch);
let body = self.infer(branch.body.clone())?;
self.unify(
tipo.clone(),
body.tipo(),
body.type_defining_location(),
false,
)?;
typed_branches.push(TypedIfBranch {
body,
condition,
location: branch.location,
});
} }
assert_no_assignment(&final_else)?; assert_no_assignment(&final_else)?;
let typed_final_else = self.infer(final_else)?; let typed_final_else = self.infer(final_else)?;
self.unify( self.unify(
tipo.clone(), first_body_type.clone(),
typed_final_else.tipo(), typed_final_else.tipo(),
typed_final_else.type_defining_location(), typed_final_else.type_defining_location(),
false, false,
@ -1758,7 +1732,63 @@ impl<'a, 'b> ExprTyper<'a, 'b> {
location, location,
branches: typed_branches, branches: typed_branches,
final_else: Box::new(typed_final_else), final_else: Box::new(typed_final_else),
tipo, tipo: first_body_type,
})
}
fn infer_if_branch(&mut self, branch: UntypedIfBranch) -> Result<TypedIfBranch, Error> {
let (condition, body, is) = match branch.is {
Some(is) => self.in_new_scope(|typer| {
let AssignmentPattern {
pattern,
annotation,
location,
} = is;
let TypedExpr::Assignment { value, pattern, .. } = typer.infer_assignment(
pattern,
branch.condition.clone(),
AssignmentKind::expect(),
&annotation,
location,
)?
else {
unreachable!()
};
if !value.tipo().is_data() {
typer.environment.warnings.push(Warning::UseWhenInstead {
location: branch.location,
})
}
assert_no_assignment(&branch.body)?;
let body = typer.infer(branch.body.clone())?;
Ok((*value, body, Some(pattern)))
})?,
None => {
let condition = self.infer(branch.condition.clone())?;
self.unify(
bool(),
condition.tipo(),
condition.type_defining_location(),
false,
)?;
assert_no_assignment(&branch.body)?;
let body = self.infer(branch.body.clone())?;
(condition, body, None)
}
};
Ok(TypedIfBranch {
body,
condition,
is,
location: branch.location,
}) })
} }