diff --git a/crates/aiken-lang/src/ast.rs b/crates/aiken-lang/src/ast.rs index f3b0782d..2a14012f 100644 --- a/crates/aiken-lang/src/ast.rs +++ b/crates/aiken-lang/src/ast.rs @@ -1423,6 +1423,26 @@ impl UntypedPattern { is_record: false, } } + + /// Returns Some() if the pattern is a [`Boolean`] literal, + /// holding the target value. None if it isn't a bool pattern. + pub fn get_bool(&self) -> Option { + match self { + Self::Constructor { + module: None, + name, + constructor: (), + .. + } if name == "True" => Some(true), + Self::Constructor { + module: None, + name, + constructor: (), + .. + } if name == "False" => Some(false), + _ => None, + } + } } impl TypedPattern { diff --git a/crates/aiken-lang/src/expr.rs b/crates/aiken-lang/src/expr.rs index 6e131a5b..2749d9bf 100644 --- a/crates/aiken-lang/src/expr.rs +++ b/crates/aiken-lang/src/expr.rs @@ -1,4 +1,5 @@ -use crate::{ +use crate::tipo::ValueConstructorVariant; +pub(crate) use crate::{ ast::{ self, Annotation, ArgBy, ArgName, AssignmentPattern, BinOp, Bls12_381Point, ByteArrayFormatPreference, CallArg, Curve, DataType, DataTypeKey, DefinitionLocation, @@ -7,7 +8,7 @@ use crate::{ TypedDataType, TypedIfBranch, TypedRecordUpdateArg, UnOp, UntypedArg, UntypedAssignmentKind, UntypedClause, UntypedIfBranch, UntypedRecordUpdateArg, }, - builtins::void, + builtins::{bool, void}, parser::token::Base, tipo::{ check_replaceable_opaque_type, convert_opaque_type, lookup_data_type_by_tipo, @@ -472,6 +473,44 @@ impl TypedExpr { .or(Some(Located::Expression(self))), } } + + pub fn void(location: Span) -> Self { + TypedExpr::Var { + name: "Void".to_string(), + constructor: ValueConstructor { + public: true, + variant: ValueConstructorVariant::Record { + name: "Void".to_string(), + arity: 0, + field_map: None, + location: Span::empty(), + module: String::new(), + constructors_count: 1, + }, + tipo: void(), + }, + location, + } + } + + pub fn bool(value: bool, location: Span) -> Self { + TypedExpr::Var { + name: "Bool".to_string(), + constructor: ValueConstructor { + public: true, + variant: ValueConstructorVariant::Record { + name: if value { "True" } else { "False" }.to_string(), + arity: 0, + field_map: None, + location: Span::empty(), + module: String::new(), + constructors_count: 2, + }, + tipo: bool(), + }, + location, + } + } } // Represent how a function was written so that we can format it back. diff --git a/crates/aiken-lang/src/tests/check.rs b/crates/aiken-lang/src/tests/check.rs index 906240e7..69d18311 100644 --- a/crates/aiken-lang/src/tests/check.rs +++ b/crates/aiken-lang/src/tests/check.rs @@ -1009,85 +1009,6 @@ fn anonymous_function_dupicate_args() { )) } -#[test] -fn assignement_last_expr_when() { - let source_code = r#" - pub fn foo() { - let bar = None - - when bar is { - Some(_) -> { - let wow = 1 - } - None -> { - 2 - } - } - } - "#; - - assert!(matches!( - check(parse(source_code)), - Err((_, Error::LastExpressionIsAssignment { .. })) - )) -} - -#[test] -fn assignement_last_expr_if_first_branch() { - let source_code = r#" - pub fn foo() { - if True { - let thing = 1 - } else { - 1 - } - } - "#; - - assert!(matches!( - check(parse(source_code)), - Err((_, Error::LastExpressionIsAssignment { .. })) - )) -} - -#[test] -fn assignement_last_expr_if_branches() { - let source_code = r#" - pub fn foo() { - if True { - 2 - } else if False { - let thing = 1 - } else { - 1 - } - } - "#; - - assert!(matches!( - check(parse(source_code)), - Err((_, Error::LastExpressionIsAssignment { .. })) - )) -} - -#[test] -fn assignement_last_expr_if_final_else() { - let source_code = r#" - pub fn foo() { - if True { - 1 - } else { - let thing = 1 - } - } - "#; - - assert!(matches!( - check(parse(source_code)), - Err((_, Error::LastExpressionIsAssignment { .. })) - )) -} - #[test] fn if_scoping() { let source_code = r#" @@ -2956,3 +2877,79 @@ fn pattern_bytearray_not_unify_subject() { Err((_, Error::CouldNotUnify { .. })) )) } + +#[test] +fn recover_no_assignment_sequence() { + let source_code = r#" + pub fn main() { + let result = 42 + expect result + 1 == 43 + } + "#; + + assert!(check(parse(source_code)).is_ok()); +} + +#[test] +fn recover_no_assignment_fn_body() { + let source_code = r#" + pub fn is_bool(foo: Data) -> Void { + expect _: Bool = foo + } + "#; + + assert!(check(parse(source_code)).is_ok()); +} + +#[test] +fn recover_no_assignment_when_clause() { + let source_code = r#" + pub fn main(foo) { + when foo is { + [] -> let bar = foo + [x, ..] -> expect _: Int = x + } + } + "#; + + let (warnings, _) = check(parse(source_code)).unwrap(); + + assert!(matches!( + &warnings[..], + [Warning::UnusedVariable { name, .. }] if name == "bar", + )) +} + +#[test] +fn recover_no_assignment_fn_if_then_else() { + let source_code = r#" + pub fn foo(weird_maths) -> Bool { + if weird_maths { + expect 1 == 2 + } else { + expect 1 + 1 == 2 + } + } + "#; + + assert!(check(parse(source_code)).is_ok()); +} + +#[test] +fn recover_no_assignment_logical_chain_op() { + let source_code = r#" + pub fn foo() -> Bool { + and { + expect 1 + 1 == 2, + True, + 2 > 0, + or { + expect True, + False, + } + } + } + "#; + + assert!(check(parse(source_code)).is_ok()); +} diff --git a/crates/aiken-lang/src/tipo/error.rs b/crates/aiken-lang/src/tipo/error.rs index 91dbb988..0fe9293d 100644 --- a/crates/aiken-lang/src/tipo/error.rs +++ b/crates/aiken-lang/src/tipo/error.rs @@ -2,7 +2,7 @@ use super::Type; use crate::{ ast::{Annotation, BinOp, CallArg, LogicalOpChainKind, Span, UntypedFunction, UntypedPattern}, error::ExtraData, - expr::{self, UntypedExpr}, + expr::{self, AssignmentPattern, UntypedExpr}, format::Formatter, levenshtein, pretty::Documentable, @@ -15,6 +15,7 @@ use owo_colors::{ Stream::{Stderr, Stdout}, }; use std::{collections::HashMap, fmt::Display, rc::Rc}; +use vec1::Vec1; #[derive(Debug, Clone, thiserror::Error)] #[error( @@ -470,6 +471,7 @@ If you really meant to return that last expression, try to replace it with the f #[label("let-binding as last expression")] location: Span, expr: expr::UntypedExpr, + patterns: Vec1, }, #[error( diff --git a/crates/aiken-lang/src/tipo/expr.rs b/crates/aiken-lang/src/tipo/expr.rs index cfa63a1b..7c0041fc 100644 --- a/crates/aiken-lang/src/tipo/expr.rs +++ b/crates/aiken-lang/src/tipo/expr.rs @@ -1397,9 +1397,16 @@ impl<'a, 'b> ExprTyper<'a, 'b> { let (then, typed_patterns) = self.in_new_scope(|scope| { let typed_patterns = scope.infer_clause_pattern(patterns, subject, &location)?; - assert_no_assignment(&then)?; - - let then = scope.infer(then)?; + let then = if let Some(filler) = + recover_from_no_assignment(assert_no_assignment(&then), then.location())? + { + TypedExpr::Sequence { + location, + expressions: vec![scope.infer(then)?, filler], + } + } else { + scope.infer(then)? + }; Ok::<_, Error>((then, typed_patterns)) })?; @@ -1521,8 +1528,16 @@ impl<'a, 'b> ExprTyper<'a, 'b> { typed_branches.push(typed_branch); } - assert_no_assignment(&final_else)?; - let typed_final_else = self.infer(final_else)?; + let typed_final_else = if let Some(filler) = + recover_from_no_assignment(assert_no_assignment(&final_else), final_else.location())? + { + TypedExpr::Sequence { + location: final_else.location(), + expressions: vec![self.infer(final_else)?, filler], + } + } else { + self.infer(final_else)? + }; self.unify( first_body_type.clone(), @@ -1569,8 +1584,18 @@ impl<'a, 'b> ExprTyper<'a, 'b> { location: branch.condition.location().union(location), }) } - assert_no_assignment(&branch.body)?; - let body = typer.infer(branch.body.clone())?; + + let body = if let Some(filler) = recover_from_no_assignment( + assert_no_assignment(&branch.body), + branch.body.location(), + )? { + TypedExpr::Sequence { + location: branch.body.location(), + expressions: vec![typer.infer(branch.body.clone())?, filler], + } + } else { + typer.infer(branch.body.clone())? + }; Ok((*value, body, Some((pattern, tipo)))) })?, @@ -1584,8 +1609,17 @@ impl<'a, 'b> ExprTyper<'a, 'b> { false, )?; - assert_no_assignment(&branch.body)?; - let body = self.infer(branch.body.clone())?; + let body = if let Some(filler) = recover_from_no_assignment( + assert_no_assignment(&branch.body), + branch.body.location(), + )? { + TypedExpr::Sequence { + location: branch.body.location(), + expressions: vec![self.infer(branch.body.clone())?, filler], + } + } else { + self.infer(branch.body.clone())? + }; (condition, body, None) } @@ -1631,7 +1665,9 @@ impl<'a, 'b> ExprTyper<'a, 'b> { body: UntypedExpr, return_type: Option>, ) -> Result<(Vec, TypedExpr, Rc), Error> { - assert_no_assignment(&body)?; + let location = body.location(); + + let no_assignment = assert_no_assignment(&body); let (body_rigid_names, body_infer) = self.in_new_scope(|body_typer| { let mut argument_names = HashMap::with_capacity(args.len()); @@ -1668,7 +1704,17 @@ impl<'a, 'b> ExprTyper<'a, 'b> { Ok((body_typer.hydrator.rigid_names(), body_typer.infer(body))) })?; - let body = body_infer.map_err(|e| e.with_unify_error_rigid_names(&body_rigid_names))?; + let inferred_body = + body_infer.map_err(|e| e.with_unify_error_rigid_names(&body_rigid_names)); + + let body = if let Some(filler) = recover_from_no_assignment(no_assignment, location)? { + TypedExpr::Sequence { + location, + expressions: vec![inferred_body?, filler], + } + } else { + inferred_body? + }; // Check that any return type is accurate. let return_type = match return_type { @@ -1757,9 +1803,17 @@ impl<'a, 'b> ExprTyper<'a, 'b> { let mut typed_expressions = vec![]; for expression in expressions { - assert_no_assignment(&expression)?; - - let typed_expression = self.infer(expression)?; + let typed_expression = if let Some(filler) = recover_from_no_assignment( + assert_no_assignment(&expression), + expression.location(), + )? { + TypedExpr::Sequence { + location: expression.location(), + expressions: vec![self.infer(expression)?, filler], + } + } else { + self.infer(expression)? + }; self.unify( bool(), @@ -2046,21 +2100,29 @@ impl<'a, 'b> ExprTyper<'a, 'b> { let typed_expression = scope.infer(expression)?; - expressions.push(match i.cmp(&(count - 1)) { + match i.cmp(&(count - 1)) { // When the expression is the last in a sequence, we enforce it is NOT // an assignment (kind of treat assignments like statements). Ordering::Equal => { - no_assignment?; - typed_expression + if let Some(filler) = + recover_from_no_assignment(no_assignment, typed_expression.location())? + { + expressions.push(typed_expression); + expressions.push(filler); + } else { + expressions.push(typed_expression); + } } // This isn't the final expression in the sequence, so it *must* // be a let-binding; we do not allow anything else. - Ordering::Less => assert_assignment(typed_expression)?, + Ordering::Less => { + expressions.push(assert_assignment(typed_expression)?); + } // Can't actually happen - Ordering::Greater => typed_expression, - }) + Ordering::Greater => unreachable!(), + } } Ok(expressions) @@ -2479,11 +2541,28 @@ impl<'a, 'b> ExprTyper<'a, 'b> { } } +fn recover_from_no_assignment( + result: Result<(), Error>, + span: Span, +) -> Result, Error> { + if let Err(Error::LastExpressionIsAssignment { patterns, .. }) = result { + match patterns.first().pattern.get_bool() { + Some(expected) if patterns.len() == 1 => Ok(Some(TypedExpr::bool(expected, span))), + _ => Ok(Some(TypedExpr::void(span))), + } + } else { + result.map(|()| None) + } +} + fn assert_no_assignment(expr: &UntypedExpr) -> Result<(), Error> { match expr { - UntypedExpr::Assignment { value, .. } => Err(Error::LastExpressionIsAssignment { + UntypedExpr::Assignment { + value, patterns, .. + } => Err(Error::LastExpressionIsAssignment { location: expr.location(), expr: *value.clone(), + patterns: patterns.clone(), }), UntypedExpr::Trace { then, .. } => assert_no_assignment(then), UntypedExpr::Fn { .. }