diff --git a/CHANGELOG.md b/CHANGELOG.md index e7a73ed2..e507b2e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ - **aiken-lang**: Strings can contain a nul byte using the escape sequence `\0`. @KtorZ - **aiken**: The `check` command now accept an extra (optional) option `--max-success` to control the number of property-test iterations to perform. @KtorZ - **aiken**: The `docs` command now accept an optional flag `--include-dependencies` to include all dependencies in the generated documentation. @KtorZ +- **aiken-lang**: Implement [function backpassing](https://www.roc-lang.org/tutorial#backpassing) as a syntactic sugar. @KtorZ ### Fixed diff --git a/crates/aiken-lang/src/ast.rs b/crates/aiken-lang/src/ast.rs index 409b7304..c603b685 100644 --- a/crates/aiken-lang/src/ast.rs +++ b/crates/aiken-lang/src/ast.rs @@ -16,6 +16,7 @@ use std::{ use uplc::machine::runtime::Compressable; use vec1::Vec1; +pub const BACKPASS_VARIABLE: &str = "_backpass"; pub const CAPTURE_VARIABLE: &str = "_capture"; pub const PIPE_VARIABLE: &str = "_pipe"; @@ -792,6 +793,19 @@ impl Arg { self.arg_name.get_variable_name() } + pub fn is_capture(&self) -> bool { + if let ArgName::Named { + ref name, location, .. + } = self.arg_name + { + return name.starts_with(CAPTURE_VARIABLE) + && location == Span::empty() + && self.location == Span::empty(); + } + + false + } + pub fn put_doc(&mut self, new_doc: String) { self.doc = Some(new_doc); } diff --git a/crates/aiken-lang/src/expr.rs b/crates/aiken-lang/src/expr.rs index 28db5a8a..7c310fe4 100644 --- a/crates/aiken-lang/src/expr.rs +++ b/crates/aiken-lang/src/expr.rs @@ -1,10 +1,10 @@ use crate::{ ast::{ - self, Annotation, Arg, AssignmentKind, BinOp, Bls12_381Point, ByteArrayFormatPreference, - CallArg, Curve, DataType, DataTypeKey, DefinitionLocation, IfBranch, Located, - LogicalOpChainKind, ParsedCallArg, Pattern, RecordConstructorArg, RecordUpdateSpread, Span, - TraceKind, TypedClause, TypedDataType, TypedRecordUpdateArg, UnOp, UntypedClause, - UntypedRecordUpdateArg, + self, Annotation, Arg, ArgName, AssignmentKind, BinOp, Bls12_381Point, + ByteArrayFormatPreference, CallArg, Curve, DataType, DataTypeKey, DefinitionLocation, + IfBranch, Located, LogicalOpChainKind, ParsedCallArg, Pattern, RecordConstructorArg, + RecordUpdateSpread, Span, TraceKind, TypedClause, TypedDataType, TypedRecordUpdateArg, + UnOp, UntypedClause, UntypedRecordUpdateArg, }, builtins::void, parser::token::Base, @@ -1299,4 +1299,29 @@ impl UntypedExpr { Self::String { .. } | Self::UInt { .. } | Self::ByteArray { .. } ) } + + pub fn lambda(name: String, expressions: Vec, location: Span) -> Self { + Self::Fn { + location, + fn_style: FnStyle::Plain, + arguments: vec![Arg { + location, + doc: None, + annotation: None, + tipo: (), + arg_name: ArgName::Named { + label: name.clone(), + name, + location, + is_validator_param: false, + }, + }], + body: Self::Sequence { + location, + expressions, + } + .into(), + return_annotation: None, + } + } } diff --git a/crates/aiken-lang/src/format.rs b/crates/aiken-lang/src/format.rs index 33861153..e5c67046 100644 --- a/crates/aiken-lang/src/format.rs +++ b/crates/aiken-lang/src/format.rs @@ -1844,7 +1844,11 @@ impl<'a> Documentable<'a> for &'a ArgName { } fn pub_(public: bool) -> Document<'static> { - if public { "pub ".to_doc() } else { nil() } + if public { + "pub ".to_doc() + } else { + nil() + } } impl<'a> Documentable<'a> for &'a UnqualifiedImport { diff --git a/crates/aiken-lang/src/tests/check.rs b/crates/aiken-lang/src/tests/check.rs index 33afd01a..89184c28 100644 --- a/crates/aiken-lang/src/tests/check.rs +++ b/crates/aiken-lang/src/tests/check.rs @@ -1185,6 +1185,141 @@ fn trace_if_false_ok() { assert!(check(parse(source_code)).is_ok()) } +#[test] +fn backpassing_basic() { + let source_code = r#" + fn and_then(opt: Option, then: fn(a) -> Option) -> Option { + when opt is { + None -> None + Some(a) -> then(a) + } + } + + fn backpassing(opt_i: Option, opt_j: Option) -> Option { + let i <- and_then(opt_i) + let j <- and_then(opt_j) + Some(i + j) + } + "#; + + assert!(check(parse(source_code)).is_ok()) +} + +#[test] +fn backpassing_interleaved_capture() { + let source_code = r#" + fn and_then(opt: Option, then: fn(a) -> Option) -> Option { + when opt is { + None -> None + Some(a) -> then(a) + } + } + + fn backpassing(opt_i: Option, opt_j: Option) -> Option { + let f = and_then(opt_i, _) + let i <- f + let g = and_then(opt_j, _) + let j <- g + Some(i + j) + } + "#; + + assert!(check(parse(source_code)).is_ok()) +} + +#[test] +fn backpassing_patterns() { + let source_code = r#" + fn and_then(opt: Option, then: fn(a) -> Option) -> Option { + when opt is { + None -> None + Some(a) -> then(a) + } + } + + type Foo { + foo: Int, + } + + fn backpassing(opt_i: Option, opt_j: Option) -> Option { + let Foo { foo: i } <- and_then(opt_i) + let Foo { foo: j } <- and_then(opt_j) + Some(i + j) + } + "#; + + assert!(check(parse(source_code)).is_ok()) +} + +#[test] +fn backpassing_not_a_function() { + let source_code = r#" + fn and_then(opt: Option, then: fn(a) -> Option) -> Option { + when opt is { + None -> None + Some(a) -> then(a) + } + } + + fn backpassing(opt_i: Option, opt_j: Option) -> Option { + let i <- opt_i + let j <- and_then(opt_j) + Some(i + j) + } + "#; + + assert!(matches!( + check(parse(source_code)), + Err((_, Error::NotFn { .. })) + )) +} + +#[test] +fn backpassing_non_exhaustive_pattern() { + let source_code = r#" + fn and_then(opt: Option, then: fn(a) -> Option) -> Option { + when opt is { + None -> None + Some(a) -> then(a) + } + } + + fn backpassing(opt_i: Option, opt_j: Option) -> Option { + let 42 <- and_then(opt_i) + let j <- and_then(opt_j) + Some(i + j) + } + "#; + + assert!(matches!( + check(parse(source_code)), + Err((_, Error::NotExhaustivePatternMatch { .. })) + )) +} + +#[test] +fn backpassing_unsaturated_fn() { + let source_code = r#" + fn and_then(opt: Option, then: fn(a) -> Option) -> Option { + when opt is { + None -> None + Some(a) -> then(a) + } + } + + fn backpassing(opt_i: Option, opt_j: Option) -> Option { + let i <- and_then + let j <- and_then(opt_j) + Some(i + j) + } + "#; + + assert!(matches!( + check(parse(source_code)), + Err((_, Error::IncorrectFieldsArity { .. })) + )) +} + #[test] fn trace_if_false_ko() { let source_code = r#" diff --git a/crates/aiken-lang/src/tipo/expr.rs b/crates/aiken-lang/src/tipo/expr.rs index fe706e08..8608b2f6 100644 --- a/crates/aiken-lang/src/tipo/expr.rs +++ b/crates/aiken-lang/src/tipo/expr.rs @@ -8,12 +8,12 @@ use super::{ }; use crate::{ ast::{ - Annotation, Arg, ArgName, AssignmentKind, BinOp, Bls12_381Point, ByteArrayFormatPreference, - CallArg, ClauseGuard, Constant, Curve, IfBranch, LogicalOpChainKind, Pattern, - RecordUpdateSpread, Span, TraceKind, TraceLevel, Tracing, TypedArg, TypedCallArg, - TypedClause, TypedClauseGuard, TypedIfBranch, TypedPattern, TypedRecordUpdateArg, UnOp, - UntypedArg, UntypedClause, UntypedClauseGuard, UntypedIfBranch, UntypedPattern, - UntypedRecordUpdateArg, + self, Annotation, Arg, ArgName, AssignmentKind, BinOp, Bls12_381Point, + ByteArrayFormatPreference, CallArg, ClauseGuard, Constant, Curve, IfBranch, + LogicalOpChainKind, Pattern, RecordUpdateSpread, Span, TraceKind, TraceLevel, Tracing, + TypedArg, TypedCallArg, TypedClause, TypedClauseGuard, TypedIfBranch, TypedPattern, + TypedRecordUpdateArg, UnOp, UntypedArg, UntypedClause, UntypedClauseGuard, UntypedIfBranch, + UntypedPattern, UntypedRecordUpdateArg, }, builtins::{ bool, byte_array, function, g1_element, g2_element, int, list, string, tuple, void, @@ -24,7 +24,7 @@ use crate::{ tipo::{fields::FieldMap, PatternConstructor, TypeVar}, }; use std::{cmp::Ordering, collections::HashMap, ops::Deref, rc::Rc}; -use vec1::{vec1, Vec1}; +use vec1::Vec1; #[derive(Debug)] pub(crate) struct ExprTyper<'a, 'b> { @@ -1711,27 +1711,150 @@ impl<'a, 'b> ExprTyper<'a, 'b> { PipeTyper::infer(self, expressions) } - fn infer_seq(&mut self, location: Span, untyped: Vec) -> Result { - let mut breakpoint = None; + fn backpass(&mut self, breakpoint: UntypedExpr, continuation: Vec) -> UntypedExpr { + let (assign_location, value, pattern, annotation) = match breakpoint { + UntypedExpr::Assignment { + location, + value, + pattern, + annotation, + .. + } => (location, value, pattern, annotation), + _ => unreachable!("backpass misuse: breakpoint isn't an Assignment ?!"), + }; - let mut sequence = self.in_new_scope(|scope| { - let count = untyped.len(); + // In case where we have a Pattern that isn't simply a let-binding to a name, we do insert an extra let-binding + // in front of the continuation sequence. This is because we do not support patterns in function argument + // (which is perhaps something we should support?). + let (name, continuation) = match pattern { + Pattern::Var { name, .. } | Pattern::Discard { name, .. } => { + (name.clone(), continuation) + } + _ => { + let mut with_assignment = vec![UntypedExpr::Assignment { + location: assign_location, + value: UntypedExpr::Var { + location: assign_location, + name: ast::BACKPASS_VARIABLE.to_string(), + } + .into(), + pattern, + kind: AssignmentKind::Let, + annotation, + }]; + with_assignment.extend(continuation); + (ast::BACKPASS_VARIABLE.to_string(), with_assignment) + } + }; + + match *value { + UntypedExpr::Call { + location: call_location, + fun, + arguments, + } => { + let mut new_arguments = Vec::new(); + new_arguments.extend(arguments); + new_arguments.push(CallArg { + location: assign_location, + label: None, + value: UntypedExpr::lambda(name, continuation, call_location), + }); + + UntypedExpr::Call { + location: call_location, + fun, + arguments: new_arguments, + } + } + + // This typically occurs on function captures. We do not try to assert anything on the + // length of the arguments here. We defer that to the rest of the type-checker. The + // only thing we have to do is rewrite the AST as-if someone had passed a callback. + // + // Now, whether this leads to an invalid call usage, that's not *our* immediate + // problem. + UntypedExpr::Fn { + location: call_location, + fn_style, + ref arguments, + ref return_annotation, + .. + } => { + let return_annotation = return_annotation.clone(); + + let arguments = arguments.iter().skip(1).cloned().collect::>(); + + let call = UntypedExpr::Call { + location: call_location, + fun: value, + arguments: vec![CallArg { + location: assign_location, + label: None, + value: UntypedExpr::lambda(name, continuation, call_location), + }], + }; + + if arguments.is_empty() { + call + } else { + UntypedExpr::Fn { + location: call_location, + fn_style, + arguments, + body: call.into(), + return_annotation, + } + } + } + + // Similarly to function captures, if we have any other expression we simply call it + // with our continuation. If the expression isn't callable? No problem, the + // type-checker will catch that eventually in exactly the same way as if the code was + // written like that to begin with. + _ => UntypedExpr::Call { + location: assign_location, + fun: value, + arguments: vec![CallArg { + location: assign_location, + label: None, + value: UntypedExpr::lambda(name, continuation, assign_location), + }], + }, + } + } + + fn infer_seq(&mut self, location: Span, untyped: Vec) -> Result { + // Search for backpassing. + let mut breakpoint = None; + let mut prefix = Vec::with_capacity(untyped.len()); + let mut suffix = Vec::with_capacity(untyped.len()); + for expression in untyped.into_iter() { + match expression { + _ if breakpoint.is_some() => suffix.push(expression), + UntypedExpr::Assignment { + kind: AssignmentKind::Bind, + .. + } => { + breakpoint = Some(expression); + } + _ => prefix.push(expression), + } + } + if let Some(breakpoint) = breakpoint { + prefix.push(self.backpass(breakpoint, suffix)); + return self.infer_seq(location, prefix); + } + + let sequence = self.in_new_scope(|scope| { + let count = prefix.len(); let mut expressions = Vec::with_capacity(count); - for (i, expression) in untyped.iter().enumerate() { - let no_assignment = assert_no_assignment(expression); + for (i, expression) in prefix.into_iter().enumerate() { + let no_assignment = assert_no_assignment(&expression); - let typed_expression = match expression { - UntypedExpr::Assignment { - kind: AssignmentKind::Bind, - .. - } => { - breakpoint = Some((i, expression.clone())); - return Ok(expressions); - } - _ => scope.infer(expression.to_owned())?, - }; + let typed_expression = scope.infer(expression)?; expressions.push(match i.cmp(&(count - 1)) { // When the expression is the last in a sequence, we enforce it is NOT @@ -1753,74 +1876,6 @@ impl<'a, 'b> ExprTyper<'a, 'b> { Ok(expressions) })?; - if let Some(( - i, - UntypedExpr::Assignment { - location, - value, - pattern, - .. - }, - )) = breakpoint - { - let then = UntypedExpr::Sequence { - location, - expressions: untyped.into_iter().skip(i + 1).collect::>(), - }; - - // TODO: This must be constructed based on the inferred type of *value*. - // - // let tipo = self.infer(untyped_value.clone())?.tipo(); - // - // The following is the `and_then` for Option. The one for Fuzzer is a bit - // different. - let desugar = UntypedExpr::When { - location, - subject: value.clone(), - clauses: vec![ - UntypedClause { - location, - guard: None, - patterns: vec1![Pattern::Constructor { - location, - is_record: false, - with_spread: false, - name: "None".to_string(), - module: None, - constructor: (), - tipo: (), - arguments: vec![], - }], - then: UntypedExpr::Var { - location, - name: "None".to_string(), - }, - }, - UntypedClause { - location, - guard: None, - patterns: vec1![Pattern::Constructor { - location, - is_record: false, - with_spread: false, - name: "Some".to_string(), - module: None, - constructor: (), - tipo: (), - arguments: vec![CallArg { - location, - label: None, - value: pattern.clone(), - }], - }], - then, - }, - ], - }; - - sequence.push(self.infer(desugar)?); - }; - let unused = self .environment .warnings