diff --git a/CHANGELOG.md b/CHANGELOG.md
index e7a73ed2..e507b2e4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,7 @@
- **aiken-lang**: Strings can contain a nul byte using the escape sequence `\0`. @KtorZ
- **aiken**: The `check` command now accept an extra (optional) option `--max-success` to control the number of property-test iterations to perform. @KtorZ
- **aiken**: The `docs` command now accept an optional flag `--include-dependencies` to include all dependencies in the generated documentation. @KtorZ
+- **aiken-lang**: Implement [function backpassing](https://www.roc-lang.org/tutorial#backpassing) as a syntactic sugar. @KtorZ
### Fixed
diff --git a/crates/aiken-lang/src/ast.rs b/crates/aiken-lang/src/ast.rs
index 409b7304..c603b685 100644
--- a/crates/aiken-lang/src/ast.rs
+++ b/crates/aiken-lang/src/ast.rs
@@ -16,6 +16,7 @@ use std::{
use uplc::machine::runtime::Compressable;
use vec1::Vec1;
+pub const BACKPASS_VARIABLE: &str = "_backpass";
pub const CAPTURE_VARIABLE: &str = "_capture";
pub const PIPE_VARIABLE: &str = "_pipe";
@@ -792,6 +793,19 @@ impl Arg {
self.arg_name.get_variable_name()
}
+ pub fn is_capture(&self) -> bool {
+ if let ArgName::Named {
+ ref name, location, ..
+ } = self.arg_name
+ {
+ return name.starts_with(CAPTURE_VARIABLE)
+ && location == Span::empty()
+ && self.location == Span::empty();
+ }
+
+ false
+ }
+
pub fn put_doc(&mut self, new_doc: String) {
self.doc = Some(new_doc);
}
diff --git a/crates/aiken-lang/src/expr.rs b/crates/aiken-lang/src/expr.rs
index 28db5a8a..7c310fe4 100644
--- a/crates/aiken-lang/src/expr.rs
+++ b/crates/aiken-lang/src/expr.rs
@@ -1,10 +1,10 @@
use crate::{
ast::{
- self, Annotation, Arg, AssignmentKind, BinOp, Bls12_381Point, ByteArrayFormatPreference,
- CallArg, Curve, DataType, DataTypeKey, DefinitionLocation, IfBranch, Located,
- LogicalOpChainKind, ParsedCallArg, Pattern, RecordConstructorArg, RecordUpdateSpread, Span,
- TraceKind, TypedClause, TypedDataType, TypedRecordUpdateArg, UnOp, UntypedClause,
- UntypedRecordUpdateArg,
+ self, Annotation, Arg, ArgName, AssignmentKind, BinOp, Bls12_381Point,
+ ByteArrayFormatPreference, CallArg, Curve, DataType, DataTypeKey, DefinitionLocation,
+ IfBranch, Located, LogicalOpChainKind, ParsedCallArg, Pattern, RecordConstructorArg,
+ RecordUpdateSpread, Span, TraceKind, TypedClause, TypedDataType, TypedRecordUpdateArg,
+ UnOp, UntypedClause, UntypedRecordUpdateArg,
},
builtins::void,
parser::token::Base,
@@ -1299,4 +1299,29 @@ impl UntypedExpr {
Self::String { .. } | Self::UInt { .. } | Self::ByteArray { .. }
)
}
+
+ pub fn lambda(name: String, expressions: Vec, location: Span) -> Self {
+ Self::Fn {
+ location,
+ fn_style: FnStyle::Plain,
+ arguments: vec![Arg {
+ location,
+ doc: None,
+ annotation: None,
+ tipo: (),
+ arg_name: ArgName::Named {
+ label: name.clone(),
+ name,
+ location,
+ is_validator_param: false,
+ },
+ }],
+ body: Self::Sequence {
+ location,
+ expressions,
+ }
+ .into(),
+ return_annotation: None,
+ }
+ }
}
diff --git a/crates/aiken-lang/src/format.rs b/crates/aiken-lang/src/format.rs
index 33861153..e5c67046 100644
--- a/crates/aiken-lang/src/format.rs
+++ b/crates/aiken-lang/src/format.rs
@@ -1844,7 +1844,11 @@ impl<'a> Documentable<'a> for &'a ArgName {
}
fn pub_(public: bool) -> Document<'static> {
- if public { "pub ".to_doc() } else { nil() }
+ if public {
+ "pub ".to_doc()
+ } else {
+ nil()
+ }
}
impl<'a> Documentable<'a> for &'a UnqualifiedImport {
diff --git a/crates/aiken-lang/src/tests/check.rs b/crates/aiken-lang/src/tests/check.rs
index 33afd01a..89184c28 100644
--- a/crates/aiken-lang/src/tests/check.rs
+++ b/crates/aiken-lang/src/tests/check.rs
@@ -1185,6 +1185,141 @@ fn trace_if_false_ok() {
assert!(check(parse(source_code)).is_ok())
}
+#[test]
+fn backpassing_basic() {
+ let source_code = r#"
+ fn and_then(opt: Option, then: fn(a) -> Option) -> Option {
+ when opt is {
+ None -> None
+ Some(a) -> then(a)
+ }
+ }
+
+ fn backpassing(opt_i: Option, opt_j: Option) -> Option {
+ let i <- and_then(opt_i)
+ let j <- and_then(opt_j)
+ Some(i + j)
+ }
+ "#;
+
+ assert!(check(parse(source_code)).is_ok())
+}
+
+#[test]
+fn backpassing_interleaved_capture() {
+ let source_code = r#"
+ fn and_then(opt: Option, then: fn(a) -> Option) -> Option {
+ when opt is {
+ None -> None
+ Some(a) -> then(a)
+ }
+ }
+
+ fn backpassing(opt_i: Option, opt_j: Option) -> Option {
+ let f = and_then(opt_i, _)
+ let i <- f
+ let g = and_then(opt_j, _)
+ let j <- g
+ Some(i + j)
+ }
+ "#;
+
+ assert!(check(parse(source_code)).is_ok())
+}
+
+#[test]
+fn backpassing_patterns() {
+ let source_code = r#"
+ fn and_then(opt: Option, then: fn(a) -> Option) -> Option {
+ when opt is {
+ None -> None
+ Some(a) -> then(a)
+ }
+ }
+
+ type Foo {
+ foo: Int,
+ }
+
+ fn backpassing(opt_i: Option, opt_j: Option) -> Option {
+ let Foo { foo: i } <- and_then(opt_i)
+ let Foo { foo: j } <- and_then(opt_j)
+ Some(i + j)
+ }
+ "#;
+
+ assert!(check(parse(source_code)).is_ok())
+}
+
+#[test]
+fn backpassing_not_a_function() {
+ let source_code = r#"
+ fn and_then(opt: Option, then: fn(a) -> Option) -> Option {
+ when opt is {
+ None -> None
+ Some(a) -> then(a)
+ }
+ }
+
+ fn backpassing(opt_i: Option, opt_j: Option) -> Option {
+ let i <- opt_i
+ let j <- and_then(opt_j)
+ Some(i + j)
+ }
+ "#;
+
+ assert!(matches!(
+ check(parse(source_code)),
+ Err((_, Error::NotFn { .. }))
+ ))
+}
+
+#[test]
+fn backpassing_non_exhaustive_pattern() {
+ let source_code = r#"
+ fn and_then(opt: Option, then: fn(a) -> Option) -> Option {
+ when opt is {
+ None -> None
+ Some(a) -> then(a)
+ }
+ }
+
+ fn backpassing(opt_i: Option, opt_j: Option) -> Option {
+ let 42 <- and_then(opt_i)
+ let j <- and_then(opt_j)
+ Some(i + j)
+ }
+ "#;
+
+ assert!(matches!(
+ check(parse(source_code)),
+ Err((_, Error::NotExhaustivePatternMatch { .. }))
+ ))
+}
+
+#[test]
+fn backpassing_unsaturated_fn() {
+ let source_code = r#"
+ fn and_then(opt: Option, then: fn(a) -> Option) -> Option {
+ when opt is {
+ None -> None
+ Some(a) -> then(a)
+ }
+ }
+
+ fn backpassing(opt_i: Option, opt_j: Option) -> Option {
+ let i <- and_then
+ let j <- and_then(opt_j)
+ Some(i + j)
+ }
+ "#;
+
+ assert!(matches!(
+ check(parse(source_code)),
+ Err((_, Error::IncorrectFieldsArity { .. }))
+ ))
+}
+
#[test]
fn trace_if_false_ko() {
let source_code = r#"
diff --git a/crates/aiken-lang/src/tipo/expr.rs b/crates/aiken-lang/src/tipo/expr.rs
index fe706e08..8608b2f6 100644
--- a/crates/aiken-lang/src/tipo/expr.rs
+++ b/crates/aiken-lang/src/tipo/expr.rs
@@ -8,12 +8,12 @@ use super::{
};
use crate::{
ast::{
- Annotation, Arg, ArgName, AssignmentKind, BinOp, Bls12_381Point, ByteArrayFormatPreference,
- CallArg, ClauseGuard, Constant, Curve, IfBranch, LogicalOpChainKind, Pattern,
- RecordUpdateSpread, Span, TraceKind, TraceLevel, Tracing, TypedArg, TypedCallArg,
- TypedClause, TypedClauseGuard, TypedIfBranch, TypedPattern, TypedRecordUpdateArg, UnOp,
- UntypedArg, UntypedClause, UntypedClauseGuard, UntypedIfBranch, UntypedPattern,
- UntypedRecordUpdateArg,
+ self, Annotation, Arg, ArgName, AssignmentKind, BinOp, Bls12_381Point,
+ ByteArrayFormatPreference, CallArg, ClauseGuard, Constant, Curve, IfBranch,
+ LogicalOpChainKind, Pattern, RecordUpdateSpread, Span, TraceKind, TraceLevel, Tracing,
+ TypedArg, TypedCallArg, TypedClause, TypedClauseGuard, TypedIfBranch, TypedPattern,
+ TypedRecordUpdateArg, UnOp, UntypedArg, UntypedClause, UntypedClauseGuard, UntypedIfBranch,
+ UntypedPattern, UntypedRecordUpdateArg,
},
builtins::{
bool, byte_array, function, g1_element, g2_element, int, list, string, tuple, void,
@@ -24,7 +24,7 @@ use crate::{
tipo::{fields::FieldMap, PatternConstructor, TypeVar},
};
use std::{cmp::Ordering, collections::HashMap, ops::Deref, rc::Rc};
-use vec1::{vec1, Vec1};
+use vec1::Vec1;
#[derive(Debug)]
pub(crate) struct ExprTyper<'a, 'b> {
@@ -1711,27 +1711,150 @@ impl<'a, 'b> ExprTyper<'a, 'b> {
PipeTyper::infer(self, expressions)
}
- fn infer_seq(&mut self, location: Span, untyped: Vec) -> Result {
- let mut breakpoint = None;
+ fn backpass(&mut self, breakpoint: UntypedExpr, continuation: Vec) -> UntypedExpr {
+ let (assign_location, value, pattern, annotation) = match breakpoint {
+ UntypedExpr::Assignment {
+ location,
+ value,
+ pattern,
+ annotation,
+ ..
+ } => (location, value, pattern, annotation),
+ _ => unreachable!("backpass misuse: breakpoint isn't an Assignment ?!"),
+ };
- let mut sequence = self.in_new_scope(|scope| {
- let count = untyped.len();
+ // In case where we have a Pattern that isn't simply a let-binding to a name, we do insert an extra let-binding
+ // in front of the continuation sequence. This is because we do not support patterns in function argument
+ // (which is perhaps something we should support?).
+ let (name, continuation) = match pattern {
+ Pattern::Var { name, .. } | Pattern::Discard { name, .. } => {
+ (name.clone(), continuation)
+ }
+ _ => {
+ let mut with_assignment = vec![UntypedExpr::Assignment {
+ location: assign_location,
+ value: UntypedExpr::Var {
+ location: assign_location,
+ name: ast::BACKPASS_VARIABLE.to_string(),
+ }
+ .into(),
+ pattern,
+ kind: AssignmentKind::Let,
+ annotation,
+ }];
+ with_assignment.extend(continuation);
+ (ast::BACKPASS_VARIABLE.to_string(), with_assignment)
+ }
+ };
+
+ match *value {
+ UntypedExpr::Call {
+ location: call_location,
+ fun,
+ arguments,
+ } => {
+ let mut new_arguments = Vec::new();
+ new_arguments.extend(arguments);
+ new_arguments.push(CallArg {
+ location: assign_location,
+ label: None,
+ value: UntypedExpr::lambda(name, continuation, call_location),
+ });
+
+ UntypedExpr::Call {
+ location: call_location,
+ fun,
+ arguments: new_arguments,
+ }
+ }
+
+ // This typically occurs on function captures. We do not try to assert anything on the
+ // length of the arguments here. We defer that to the rest of the type-checker. The
+ // only thing we have to do is rewrite the AST as-if someone had passed a callback.
+ //
+ // Now, whether this leads to an invalid call usage, that's not *our* immediate
+ // problem.
+ UntypedExpr::Fn {
+ location: call_location,
+ fn_style,
+ ref arguments,
+ ref return_annotation,
+ ..
+ } => {
+ let return_annotation = return_annotation.clone();
+
+ let arguments = arguments.iter().skip(1).cloned().collect::>();
+
+ let call = UntypedExpr::Call {
+ location: call_location,
+ fun: value,
+ arguments: vec![CallArg {
+ location: assign_location,
+ label: None,
+ value: UntypedExpr::lambda(name, continuation, call_location),
+ }],
+ };
+
+ if arguments.is_empty() {
+ call
+ } else {
+ UntypedExpr::Fn {
+ location: call_location,
+ fn_style,
+ arguments,
+ body: call.into(),
+ return_annotation,
+ }
+ }
+ }
+
+ // Similarly to function captures, if we have any other expression we simply call it
+ // with our continuation. If the expression isn't callable? No problem, the
+ // type-checker will catch that eventually in exactly the same way as if the code was
+ // written like that to begin with.
+ _ => UntypedExpr::Call {
+ location: assign_location,
+ fun: value,
+ arguments: vec![CallArg {
+ location: assign_location,
+ label: None,
+ value: UntypedExpr::lambda(name, continuation, assign_location),
+ }],
+ },
+ }
+ }
+
+ fn infer_seq(&mut self, location: Span, untyped: Vec) -> Result {
+ // Search for backpassing.
+ let mut breakpoint = None;
+ let mut prefix = Vec::with_capacity(untyped.len());
+ let mut suffix = Vec::with_capacity(untyped.len());
+ for expression in untyped.into_iter() {
+ match expression {
+ _ if breakpoint.is_some() => suffix.push(expression),
+ UntypedExpr::Assignment {
+ kind: AssignmentKind::Bind,
+ ..
+ } => {
+ breakpoint = Some(expression);
+ }
+ _ => prefix.push(expression),
+ }
+ }
+ if let Some(breakpoint) = breakpoint {
+ prefix.push(self.backpass(breakpoint, suffix));
+ return self.infer_seq(location, prefix);
+ }
+
+ let sequence = self.in_new_scope(|scope| {
+ let count = prefix.len();
let mut expressions = Vec::with_capacity(count);
- for (i, expression) in untyped.iter().enumerate() {
- let no_assignment = assert_no_assignment(expression);
+ for (i, expression) in prefix.into_iter().enumerate() {
+ let no_assignment = assert_no_assignment(&expression);
- let typed_expression = match expression {
- UntypedExpr::Assignment {
- kind: AssignmentKind::Bind,
- ..
- } => {
- breakpoint = Some((i, expression.clone()));
- return Ok(expressions);
- }
- _ => scope.infer(expression.to_owned())?,
- };
+ let typed_expression = scope.infer(expression)?;
expressions.push(match i.cmp(&(count - 1)) {
// When the expression is the last in a sequence, we enforce it is NOT
@@ -1753,74 +1876,6 @@ impl<'a, 'b> ExprTyper<'a, 'b> {
Ok(expressions)
})?;
- if let Some((
- i,
- UntypedExpr::Assignment {
- location,
- value,
- pattern,
- ..
- },
- )) = breakpoint
- {
- let then = UntypedExpr::Sequence {
- location,
- expressions: untyped.into_iter().skip(i + 1).collect::>(),
- };
-
- // TODO: This must be constructed based on the inferred type of *value*.
- //
- // let tipo = self.infer(untyped_value.clone())?.tipo();
- //
- // The following is the `and_then` for Option. The one for Fuzzer is a bit
- // different.
- let desugar = UntypedExpr::When {
- location,
- subject: value.clone(),
- clauses: vec![
- UntypedClause {
- location,
- guard: None,
- patterns: vec1![Pattern::Constructor {
- location,
- is_record: false,
- with_spread: false,
- name: "None".to_string(),
- module: None,
- constructor: (),
- tipo: (),
- arguments: vec![],
- }],
- then: UntypedExpr::Var {
- location,
- name: "None".to_string(),
- },
- },
- UntypedClause {
- location,
- guard: None,
- patterns: vec1![Pattern::Constructor {
- location,
- is_record: false,
- with_spread: false,
- name: "Some".to_string(),
- module: None,
- constructor: (),
- tipo: (),
- arguments: vec![CallArg {
- location,
- label: None,
- value: pattern.clone(),
- }],
- }],
- then,
- },
- ],
- };
-
- sequence.push(self.infer(desugar)?);
- };
-
let unused = self
.environment
.warnings