diff --git a/crates/aiken-lang/src/format.rs b/crates/aiken-lang/src/format.rs index 4537c25d..d74c6ad8 100644 --- a/crates/aiken-lang/src/format.rs +++ b/crates/aiken-lang/src/format.rs @@ -516,7 +516,7 @@ impl<'comments> Formatter<'comments> { .group(); // Format body - let body = self.expr(body); + let body = self.expr(body, true); // Add any trailing comments let body = match printed_comments(self.pop_comments(end_location), false) { @@ -609,8 +609,10 @@ impl<'comments> Formatter<'comments> { ) -> Document<'a> { let args = wrap_args(args.iter().map(|e| (self.fn_arg(e), false))).group(); let body = match body { - UntypedExpr::Trace { .. } | UntypedExpr::When { .. } => self.expr(body).force_break(), - _ => self.expr(body), + UntypedExpr::Trace { .. } | UntypedExpr::When { .. } => { + self.expr(body, false).force_break() + } + _ => self.expr(body, false), }; let header = "fn".to_doc().append(args); @@ -634,15 +636,19 @@ impl<'comments> Formatter<'comments> { fn sequence<'a>(&mut self, expressions: &'a [UntypedExpr]) -> Document<'a> { let count = expressions.len(); let mut documents = Vec::with_capacity(count * 2); + for (i, expression) in expressions.iter().enumerate() { let preceding_newline = self.pop_empty_lines(expression.start_byte_index()); + if i != 0 && preceding_newline { documents.push(lines(2)); } else if i != 0 { documents.push(lines(1)); } - documents.push(self.expr(expression).group()); + + documents.push(self.expr(expression, false).group()); } + documents.to_doc().force_break() } @@ -782,7 +788,7 @@ impl<'comments> Formatter<'comments> { } } - pub fn expr<'a>(&mut self, expr: &'a UntypedExpr) -> Document<'a> { + pub fn expr<'a>(&mut self, expr: &'a UntypedExpr, is_top_level: bool) -> Document<'a> { let comments = self.pop_comments(expr.start_byte_index()); let document = match expr { @@ -821,7 +827,18 @@ impl<'comments> Formatter<'comments> { UntypedExpr::String { value, .. } => self.string(value), - UntypedExpr::Sequence { expressions, .. } => self.sequence(expressions), + UntypedExpr::Sequence { expressions, .. } => { + let sequence = self.sequence(expressions); + + if is_top_level { + sequence + } else { + "{".to_doc() + .append(line().append(sequence).nest(INDENT).group()) + .append(line()) + .append("}") + } + } UntypedExpr::Var { name, .. } if name.contains(CAPTURE_VARIABLE) => "_".to_doc(), @@ -878,7 +895,10 @@ impl<'comments> Formatter<'comments> { UntypedExpr::FieldAccess { label, container, .. - } => self.expr(container).append(".").append(label.as_str()), + } => self + .expr(container, false) + .append(".") + .append(label.as_str()), UntypedExpr::RecordUpdate { constructor, @@ -893,7 +913,7 @@ impl<'comments> Formatter<'comments> { UntypedExpr::TupleIndex { index, tuple, .. } => { let suffix = Ordinal(*index + 1).suffix().to_doc(); - self.expr(tuple) + self.expr(tuple, false) .append(".".to_doc()) .append((index + 1).to_doc()) .append(suffix) @@ -953,7 +973,7 @@ impl<'comments> Formatter<'comments> { } else { line() }) - .append(self.expr(then)), + .append(self.expr(then, false)), } } @@ -1027,7 +1047,7 @@ impl<'comments> Formatter<'comments> { false }; - self.expr(fun) + self.expr(fun, false) .append(wrap_args( args.iter() .map(|a| (self.call_arg(a, needs_curly), needs_curly)), @@ -1051,7 +1071,7 @@ impl<'comments> Formatter<'comments> { let else_begin = line().append("} else {"); - let else_body = line().append(self.expr(final_else)).nest(INDENT); + let else_body = line().append(self.expr(final_else, false)).nest(INDENT); let else_end = line().append("}"); @@ -1072,7 +1092,7 @@ impl<'comments> Formatter<'comments> { .append(break_("{", " {")) .group(); - let if_body = line().append(self.expr(&branch.body)).nest(INDENT); + let if_body = line().append(self.expr(&branch.body, false)).nest(INDENT); if_begin.append(if_body) } @@ -1110,8 +1130,8 @@ impl<'comments> Formatter<'comments> { args: &'a [UntypedRecordUpdateArg], ) -> Document<'a> { use std::iter::once; - let constructor_doc = self.expr(constructor); - let spread_doc = "..".to_doc().append(self.expr(&spread.base)); + let constructor_doc = self.expr(constructor, false); + let spread_doc = "..".to_doc().append(self.expr(&spread.base, false)); let arg_docs = args.iter().map(|a| (self.record_update_arg(a), true)); let all_arg_docs = once((spread_doc, true)).chain(arg_docs); constructor_doc.append(wrap_args(all_arg_docs)).group() @@ -1128,8 +1148,8 @@ impl<'comments> Formatter<'comments> { let left_precedence = left.binop_precedence(); let right_precedence = right.binop_precedence(); - let left = self.expr(left); - let right = self.expr(right); + let left = self.expr(left, false); + let right = self.expr(right, false); self.operator_side(left, precedence, left_precedence) .append(" ") @@ -1161,7 +1181,9 @@ impl<'comments> Formatter<'comments> { .append( line() .append(join( - expressions.iter().map(|expression| self.expr(expression)), + expressions + .iter() + .map(|expression| self.expr(expression, false)), ",".to_doc().append(line()), )) .nest(INDENT) @@ -1241,10 +1263,10 @@ impl<'comments> Formatter<'comments> { if hole_in_first_position && args.len() == 1 { // x |> fun(_) - self.expr(fun) + self.expr(fun, false) } else if hole_in_first_position { // x |> fun(_, 2, 3) - self.expr(fun).append( + self.expr(fun, false).append( wrap_args( args.iter() .skip(1) @@ -1254,7 +1276,7 @@ impl<'comments> Formatter<'comments> { ) } else { // x |> fun(1, _, 3) - self.expr(fun) + self.expr(fun, false) .append(wrap_args(args.iter().map(|a| (self.call_arg(a, false), false))).group()) } } @@ -1267,14 +1289,14 @@ impl<'comments> Formatter<'comments> { .. } => match args.as_slice() { [first, second] if is_breakable_expr(&second.value) && first.is_capture_hole() => { - self.expr(fun) + self.expr(fun, false) .append("(_, ") .append(self.call_arg(second, false)) .append(")") .group() } - _ => self.expr(fun).append( + _ => self.expr(fun, false).append( wrap_args(args.iter().map(|a| (self.call_arg(a, false), false))).group(), ), }, @@ -1556,12 +1578,12 @@ impl<'comments> Formatter<'comments> { | UntypedExpr::Sequence { .. } | UntypedExpr::Assignment { .. } => "{" .to_doc() - .append(line().append(self.expr(expr)).nest(INDENT)) + .append(line().append(self.expr(expr, false)).nest(INDENT)) .append(line()) .append("}") .force_break(), - _ => self.expr(expr), + _ => self.expr(expr, false), } } @@ -1598,18 +1620,21 @@ impl<'comments> Formatter<'comments> { | UntypedExpr::Sequence { .. } | UntypedExpr::Assignment { .. } => " {" .to_doc() - .append(line().append(self.expr(expr)).nest(INDENT).group()) + .append(line().append(self.expr(expr, true)).nest(INDENT).group()) .append(line()) .append("}") .force_break(), UntypedExpr::Fn { .. } | UntypedExpr::List { .. } => { - line().append(self.expr(expr)).nest(INDENT).group() + line().append(self.expr(expr, false)).nest(INDENT).group() } - UntypedExpr::When { .. } => line().append(self.expr(expr)).nest(INDENT).group(), + UntypedExpr::When { .. } => line().append(self.expr(expr, false)).nest(INDENT).group(), - _ => break_("", " ").append(self.expr(expr)).nest(INDENT).group(), + _ => break_("", " ") + .append(self.expr(expr, false)) + .nest(INDENT) + .group(), } } @@ -1647,7 +1672,7 @@ impl<'comments> Formatter<'comments> { || break_(",", ", ") }; let elements_document = join(elements.iter().map(|e| self.wrap_expr(e)), comma()); - let tail = tail.map(|e| self.expr(e)); + let tail = tail.map(|e| self.expr(e, false)); list(elements_document, elements.len(), tail) } @@ -1771,7 +1796,7 @@ impl<'comments> Formatter<'comments> { fn wrap_unary_op<'a>(&mut self, expr: &'a UntypedExpr) -> Document<'a> { match expr { - UntypedExpr::BinOp { .. } => "(".to_doc().append(self.expr(expr)).append(")"), + UntypedExpr::BinOp { .. } => "(".to_doc().append(self.expr(expr, false)).append(")"), _ => self.wrap_expr(expr), } } diff --git a/crates/aiken-lang/src/tests/format.rs b/crates/aiken-lang/src/tests/format.rs index ed11c321..d796b09c 100644 --- a/crates/aiken-lang/src/tests/format.rs +++ b/crates/aiken-lang/src/tests/format.rs @@ -62,6 +62,31 @@ fn format_if() { ); } +#[test] +fn format_logic_op_with_code_block() { + assert_format!( + r#" + fn foo() { + True || { + let bar = 1 + bar == bar + } + } + "# + ); +} + +#[test] +fn format_grouped_expression() { + assert_format!( + r#" + fn foo() { + y == { x |> f } + } + "# + ); +} + #[test] fn format_validator() { assert_format!( diff --git a/crates/aiken-lang/src/tests/snapshots/format_logic_op_with_code_block.snap b/crates/aiken-lang/src/tests/snapshots/format_logic_op_with_code_block.snap new file mode 100644 index 00000000..1129fb7d --- /dev/null +++ b/crates/aiken-lang/src/tests/snapshots/format_logic_op_with_code_block.snap @@ -0,0 +1,11 @@ +--- +source: crates/aiken-lang/src/tests/format.rs +description: "Code:\n\nfn foo() {\n True || {\n let bar = 1\n bar == bar\n }\n}\n" +--- +fn foo() { + True || { + let bar = 1 + bar == bar + } +} + diff --git a/crates/aiken-lang/src/tipo/error.rs b/crates/aiken-lang/src/tipo/error.rs index d298244a..b2fa7b56 100644 --- a/crates/aiken-lang/src/tipo/error.rs +++ b/crates/aiken-lang/src/tipo/error.rs @@ -1630,7 +1630,7 @@ pub enum UnknownRecordFieldSituation { fn format_suggestion(sample: &UntypedExpr) -> String { Formatter::new() - .expr(sample) + .expr(sample, false) .to_pretty_string(70) .lines() .enumerate() diff --git a/crates/aiken-lang/src/tipo/expr.rs b/crates/aiken-lang/src/tipo/expr.rs index bef14653..681d262d 100644 --- a/crates/aiken-lang/src/tipo/expr.rs +++ b/crates/aiken-lang/src/tipo/expr.rs @@ -426,7 +426,9 @@ impl<'a, 'b> ExprTyper<'a, 'b> { tipo: string(), value: format!( "{} ? False", - format::Formatter::new().expr(&value).to_pretty_string(999) + format::Formatter::new() + .expr(&value, false) + .to_pretty_string(999) ), };