From ba911d48ea86a764e38281d6d1a36cd6af9139b0 Mon Sep 17 00:00:00 2001 From: KtorZ Date: Sat, 17 Jun 2023 07:26:46 +0200 Subject: [PATCH] Refactor 'is_capture' field on function expressions. Refactored into an enum to make it easier to extend with a new variant to support binary operators. --- crates/aiken-lang/src/expr.rs | 9 +++++++- crates/aiken-lang/src/format.rs | 7 +++--- crates/aiken-lang/src/parser.rs | 4 ++-- crates/aiken-lang/src/tests/parser.rs | 6 ++--- crates/aiken-lang/src/tipo/expr.rs | 33 +++++++++++++++++---------- 5 files changed, 38 insertions(+), 21 deletions(-) diff --git a/crates/aiken-lang/src/expr.rs b/crates/aiken-lang/src/expr.rs index d669f171..e42c225c 100644 --- a/crates/aiken-lang/src/expr.rs +++ b/crates/aiken-lang/src/expr.rs @@ -399,6 +399,13 @@ impl TypedExpr { } } +// Represent how a function was written so that we can format it back. +#[derive(Debug, Clone, PartialEq, Copy)] +pub enum FnStyle { + Plain, + Capture, +} + #[derive(Debug, Clone, PartialEq)] pub enum UntypedExpr { Int { @@ -424,7 +431,7 @@ pub enum UntypedExpr { Fn { location: Span, - is_capture: bool, + fn_style: FnStyle, arguments: Vec>, body: Box, return_annotation: Option, diff --git a/crates/aiken-lang/src/format.rs b/crates/aiken-lang/src/format.rs index 7bd755aa..31341e1c 100644 --- a/crates/aiken-lang/src/format.rs +++ b/crates/aiken-lang/src/format.rs @@ -8,7 +8,7 @@ use crate::{ Use, Validator, CAPTURE_VARIABLE, }, docvec, - expr::{UntypedExpr, DEFAULT_ERROR_STR, DEFAULT_TODO_STR}, + expr::{FnStyle, UntypedExpr, DEFAULT_ERROR_STR, DEFAULT_TODO_STR}, parser::{ extra::{Comment, ModuleExtra}, token::Base, @@ -768,12 +768,13 @@ impl<'comments> Formatter<'comments> { UntypedExpr::UnOp { value, op, .. } => self.un_op(value, op), UntypedExpr::Fn { - is_capture: true, + fn_style: FnStyle::Capture, body, .. } => self.fn_capture(body), UntypedExpr::Fn { + fn_style: FnStyle::Plain, return_annotation, arguments: args, body, @@ -1093,7 +1094,7 @@ impl<'comments> Formatter<'comments> { let comments = self.pop_comments(expr.location().start); let doc = match expr { UntypedExpr::Fn { - is_capture: true, + fn_style: FnStyle::Capture, body, .. } => self.pipe_capture_right_hand_side(body), diff --git a/crates/aiken-lang/src/parser.rs b/crates/aiken-lang/src/parser.rs index b24a738c..806704a9 100644 --- a/crates/aiken-lang/src/parser.rs +++ b/crates/aiken-lang/src/parser.rs @@ -935,7 +935,7 @@ pub fn expr_parser( arguments, body: Box::new(body), location: span, - is_capture: false, + fn_style: expr::FnStyle::Plain, return_annotation, }, ); @@ -1205,7 +1205,7 @@ pub fn expr_parser( } else { expr::UntypedExpr::Fn { location: call.location(), - is_capture: true, + fn_style: expr::FnStyle::Capture, arguments: holes, body: Box::new(call), return_annotation: None, diff --git a/crates/aiken-lang/src/tests/parser.rs b/crates/aiken-lang/src/tests/parser.rs index 19479451..9cd9bae0 100644 --- a/crates/aiken-lang/src/tests/parser.rs +++ b/crates/aiken-lang/src/tests/parser.rs @@ -1369,7 +1369,7 @@ fn anonymous_function() { location: Span::new((), 25..67), value: Box::new(expr::UntypedExpr::Fn { location: Span::new((), 39..67), - is_capture: false, + fn_style: expr::FnStyle::Plain, arguments: vec![ast::Arg { arg_name: ast::ArgName::Named { label: "a".to_string(), @@ -1547,7 +1547,7 @@ fn call() { location: Span::new((), 37..82), value: Box::new(expr::UntypedExpr::Fn { location: Span::new((), 53..82), - is_capture: true, + fn_style: expr::FnStyle::Capture, arguments: vec![ast::Arg { arg_name: ast::ArgName::Named { label: "_capture__0".to_string(), @@ -1574,7 +1574,7 @@ fn call() { location: Span::new((), 65..81), value: expr::UntypedExpr::Fn { location: Span::new((), 65..81), - is_capture: false, + fn_style: expr::FnStyle::Plain, arguments: vec![ast::Arg { arg_name: ast::ArgName::Named { label: "y".to_string(), diff --git a/crates/aiken-lang/src/tipo/expr.rs b/crates/aiken-lang/src/tipo/expr.rs index aa613e71..f4a39b06 100644 --- a/crates/aiken-lang/src/tipo/expr.rs +++ b/crates/aiken-lang/src/tipo/expr.rs @@ -12,7 +12,7 @@ use crate::{ UntypedRecordUpdateArg, }, builtins::{bool, byte_array, function, int, list, string, tuple}, - expr::{TypedExpr, UntypedExpr}, + expr::{FnStyle, TypedExpr, UntypedExpr}, format, tipo::fields::FieldMap, }; @@ -220,12 +220,19 @@ impl<'a, 'b> ExprTyper<'a, 'b> { UntypedExpr::Fn { location, - is_capture, + fn_style, arguments: args, body, return_annotation, .. - } => self.infer_fn(args, &[], *body, is_capture, return_annotation, location), + } => self.infer_fn( + args, + &[], + *body, + fn_style == FnStyle::Capture, + return_annotation, + location, + ), UntypedExpr::If { location, @@ -1011,17 +1018,19 @@ impl<'a, 'b> ExprTyper<'a, 'b> { body, return_annotation, location, - is_capture: false, + fn_style, .. }, - ) if expected_arguments.len() == arguments.len() => self.infer_fn( - arguments, - expected_arguments, - *body, - false, - return_annotation, - location, - ), + ) if fn_style != FnStyle::Capture && expected_arguments.len() == arguments.len() => { + self.infer_fn( + arguments, + expected_arguments, + *body, + false, + return_annotation, + location, + ) + } // Otherwise just perform normal type inference. (_, value) => self.infer(value),