diff --git a/crates/aiken-lang/src/ast.rs b/crates/aiken-lang/src/ast.rs index 1811243b..68af1566 100644 --- a/crates/aiken-lang/src/ast.rs +++ b/crates/aiken-lang/src/ast.rs @@ -158,12 +158,15 @@ fn str_to_keyword(word: &str) -> Option { } } -pub type TypedFunction = Function, TypedExpr>; -pub type UntypedFunction = Function<(), UntypedExpr>; +pub type TypedFunction = Function, TypedExpr, TypedArg>; +pub type UntypedFunction = Function<(), UntypedExpr, UntypedArg>; + +pub type TypedTest = Function, TypedExpr, TypedArgVia>; +pub type UntypedTest = Function<(), UntypedExpr, UntypedArgVia>; #[derive(Debug, Clone, PartialEq)] -pub struct Function { - pub arguments: Vec>, +pub struct Function { + pub arguments: Vec, pub body: Expr, pub doc: Option, pub location: Span, @@ -178,7 +181,7 @@ pub struct Function { pub type TypedTypeAlias = TypeAlias>; pub type UntypedTypeAlias = TypeAlias<()>; -impl TypedFunction { +impl TypedTest { pub fn test_hint(&self) -> Option<(BinOp, Box, Box)> { do_test_hint(&self.body) } @@ -358,18 +361,24 @@ pub type UntypedValidator = Validator<(), UntypedExpr>; pub struct Validator { pub doc: Option, pub end_position: usize, - pub fun: Function, - pub other_fun: Option>, + pub fun: Function>, + pub other_fun: Option>>, pub location: Span, pub params: Vec>, } -pub type TypedDefinition = Definition, TypedExpr, String>; -pub type UntypedDefinition = Definition<(), UntypedExpr, ()>; +#[derive(Debug, Clone, PartialEq)] +pub struct DefinitionIdentifier { + pub module: Option, + pub name: String, +} + +pub type TypedDefinition = Definition, TypedExpr, String, ()>; +pub type UntypedDefinition = Definition<(), UntypedExpr, (), DefinitionIdentifier>; #[derive(Debug, Clone, PartialEq)] -pub enum Definition { - Fn(Function), +pub enum Definition { + Fn(Function>), TypeAlias(TypeAlias), @@ -379,12 +388,12 @@ pub enum Definition { ModuleConstant(ModuleConstant), - Test(Function), + Test(Function>), Validator(Validator), } -impl Definition { +impl Definition { pub fn location(&self) -> Span { match self { Definition::Fn(Function { location, .. }) @@ -634,6 +643,40 @@ impl Arg { } } +pub type TypedArgVia = ArgVia, ()>; +pub type UntypedArgVia = ArgVia<(), DefinitionIdentifier>; + +#[derive(Debug, Clone, PartialEq)] +pub struct ArgVia { + pub arg_name: ArgName, + pub location: Span, + pub via: Ann, + pub tipo: T, +} + +impl From> for Arg { + fn from(arg: ArgVia) -> Arg { + Arg { + arg_name: arg.arg_name, + location: arg.location, + tipo: arg.tipo, + annotation: None, + doc: None, + } + } +} + +impl From for TypedArgVia { + fn from(arg: TypedArg) -> TypedArgVia { + ArgVia { + arg_name: arg.arg_name, + tipo: arg.tipo, + location: arg.location, + via: (), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum ArgName { Discarded { diff --git a/crates/aiken-lang/src/format.rs b/crates/aiken-lang/src/format.rs index 176aa02c..9d3045fb 100644 --- a/crates/aiken-lang/src/format.rs +++ b/crates/aiken-lang/src/format.rs @@ -1,11 +1,12 @@ use crate::{ ast::{ - Annotation, Arg, ArgName, AssignmentKind, BinOp, ByteArrayFormatPreference, CallArg, - ClauseGuard, Constant, CurveType, DataType, Definition, Function, IfBranch, - LogicalOpChainKind, ModuleConstant, Pattern, RecordConstructor, RecordConstructorArg, - RecordUpdateSpread, Span, TraceKind, TypeAlias, TypedArg, UnOp, UnqualifiedImport, - UntypedArg, UntypedClause, UntypedClauseGuard, UntypedDefinition, UntypedFunction, - UntypedModule, UntypedPattern, UntypedRecordUpdateArg, Use, Validator, CAPTURE_VARIABLE, + Annotation, Arg, ArgName, ArgVia, AssignmentKind, BinOp, ByteArrayFormatPreference, + CallArg, ClauseGuard, Constant, CurveType, DataType, Definition, DefinitionIdentifier, + Function, IfBranch, LogicalOpChainKind, ModuleConstant, Pattern, RecordConstructor, + RecordConstructorArg, RecordUpdateSpread, Span, TraceKind, TypeAlias, TypedArg, UnOp, + UnqualifiedImport, UntypedArg, UntypedArgVia, UntypedClause, UntypedClauseGuard, + UntypedDefinition, UntypedFunction, UntypedModule, UntypedPattern, UntypedRecordUpdateArg, + Use, Validator, CAPTURE_VARIABLE, }, docvec, expr::{FnStyle, UntypedExpr, DEFAULT_ERROR_STR, DEFAULT_TODO_STR}, @@ -231,16 +232,7 @@ impl<'comments> Formatter<'comments> { return_annotation, end_position, .. - }) => self.definition_fn( - public, - "fn", - name, - args, - return_annotation, - body, - *end_position, - false, - ), + }) => self.definition_fn(public, name, args, return_annotation, body, *end_position), Definition::Validator(Validator { end_position, @@ -257,16 +249,7 @@ impl<'comments> Formatter<'comments> { end_position, can_error, .. - }) => self.definition_fn( - &false, - "test", - name, - args, - &None, - body, - *end_position, - *can_error, - ), + }) => self.definition_test(name, args, body, *end_position, *can_error), Definition::TypeAlias(TypeAlias { alias, @@ -488,25 +471,40 @@ impl<'comments> Formatter<'comments> { commented(doc, comments) } + fn fn_arg_via<'a, A>(&mut self, arg: &'a ArgVia) -> Document<'a> { + let comments = self.pop_comments(arg.location.start); + + let doc_comments = self.doc_comments(arg.location.start); + + let doc = arg.arg_name.to_doc().append(" via "); + + let doc = match arg.via.module { + Some(ref module) => doc.append(module.to_doc()).append("."), + None => doc, + } + .append(arg.via.name.to_doc()) + .group(); + + let doc = doc_comments.append(doc.group()).group(); + + commented(doc, comments) + } + #[allow(clippy::too_many_arguments)] fn definition_fn<'a>( &mut self, public: &'a bool, - keyword: &'a str, name: &'a str, args: &'a [UntypedArg], return_annotation: &'a Option, body: &'a UntypedExpr, end_location: usize, - can_error: bool, ) -> Document<'a> { // Fn name and args let head = pub_(*public) - .append(keyword) - .append(" ") + .append("fn ") .append(name) - .append(wrap_args(args.iter().map(|e| (self.fn_arg(e), false)))) - .append(if can_error { " fail" } else { "" }); + .append(wrap_args(args.iter().map(|e| (self.fn_arg(e), false)))); // Add return annotation let head = match return_annotation { @@ -531,6 +529,39 @@ impl<'comments> Formatter<'comments> { .append("}") } + #[allow(clippy::too_many_arguments)] + fn definition_test<'a>( + &mut self, + name: &'a str, + args: &'a [UntypedArgVia], + body: &'a UntypedExpr, + end_location: usize, + can_error: bool, + ) -> Document<'a> { + // Fn name and args + let head = "test " + .to_doc() + .append(name) + .append(wrap_args(args.iter().map(|e| (self.fn_arg_via(e), false)))) + .append(if can_error { " fail" } else { "" }) + .group(); + + // Format body + let body = self.expr(body, true); + + // Add any trailing comments + let body = match printed_comments(self.pop_comments(end_location), false) { + Some(comments) => body.append(line()).append(comments), + None => body, + }; + + // Stick it all together + head.append(" {") + .append(line().append(body).nest(INDENT).group()) + .append(line()) + .append("}") + } + fn definition_validator<'a>( &mut self, params: &'a [UntypedArg], @@ -550,13 +581,11 @@ impl<'comments> Formatter<'comments> { let first_fn = self .definition_fn( &false, - "fn", &fun.name, &fun.arguments, &fun.return_annotation, &fun.body, fun.end_position, - false, ) .group(); let first_fn = commented(fun_doc_comments.append(first_fn).group(), fun_comments); @@ -570,13 +599,11 @@ impl<'comments> Formatter<'comments> { let other_fn = self .definition_fn( &false, - "fn", &other.name, &other.arguments, &other.return_annotation, &other.body, other.end_position, - false, ) .group(); diff --git a/crates/aiken-lang/src/parser/definition/snapshots/def_invalid_property_test.snap b/crates/aiken-lang/src/parser/definition/snapshots/def_invalid_property_test.snap new file mode 100644 index 00000000..7cd61a30 --- /dev/null +++ b/crates/aiken-lang/src/parser/definition/snapshots/def_invalid_property_test.snap @@ -0,0 +1,50 @@ +--- +source: crates/aiken-lang/src/parser/definition/test.rs +description: "Code:\n\ntest foo(x via f, y via g) {\n True\n}\n" +--- +Test( + Function { + arguments: [ + ArgVia { + arg_name: Named { + name: "x", + label: "x", + location: 9..10, + is_validator_param: false, + }, + location: 9..16, + via: DefinitionIdentifier { + module: None, + name: "f", + }, + tipo: (), + }, + ArgVia { + arg_name: Named { + name: "y", + label: "y", + location: 18..19, + is_validator_param: false, + }, + location: 18..25, + via: DefinitionIdentifier { + module: None, + name: "g", + }, + tipo: (), + }, + ], + body: Var { + location: 33..37, + name: "True", + }, + doc: None, + location: 0..26, + name: "foo", + public: false, + return_annotation: None, + return_type: (), + end_position: 38, + can_error: false, + }, +) diff --git a/crates/aiken-lang/src/parser/definition/snapshots/def_property_test.snap b/crates/aiken-lang/src/parser/definition/snapshots/def_property_test.snap new file mode 100644 index 00000000..8047f21e --- /dev/null +++ b/crates/aiken-lang/src/parser/definition/snapshots/def_property_test.snap @@ -0,0 +1,38 @@ +--- +source: crates/aiken-lang/src/parser/definition/test.rs +description: "Code:\n\ntest foo(x via fuzz.any_int) {\n True\n}\n" +--- +Test( + Function { + arguments: [ + ArgVia { + arg_name: Named { + name: "x", + label: "x", + location: 9..10, + is_validator_param: false, + }, + location: 9..27, + via: DefinitionIdentifier { + module: Some( + "fuzz", + ), + name: "any_int", + }, + tipo: (), + }, + ], + body: Var { + location: 35..39, + name: "True", + }, + doc: None, + location: 0..28, + name: "foo", + public: false, + return_annotation: None, + return_type: (), + end_position: 40, + can_error: false, + }, +) diff --git a/crates/aiken-lang/src/parser/definition/snapshots/def_test.snap b/crates/aiken-lang/src/parser/definition/snapshots/def_test.snap new file mode 100644 index 00000000..de24dc47 --- /dev/null +++ b/crates/aiken-lang/src/parser/definition/snapshots/def_test.snap @@ -0,0 +1,21 @@ +--- +source: crates/aiken-lang/src/parser/definition/test.rs +description: "Code:\n\ntest foo() {\n True\n}\n" +--- +Test( + Function { + arguments: [], + body: Var { + location: 17..21, + name: "True", + }, + doc: None, + location: 0..10, + name: "foo", + public: false, + return_annotation: None, + return_type: (), + end_position: 22, + can_error: false, + }, +) diff --git a/crates/aiken-lang/src/parser/definition/test.rs b/crates/aiken-lang/src/parser/definition/test.rs index 6691edb2..33d48628 100644 --- a/crates/aiken-lang/src/parser/definition/test.rs +++ b/crates/aiken-lang/src/parser/definition/test.rs @@ -13,8 +13,12 @@ pub fn parser() -> impl Parser name}) - .then_ignore(just(Token::LeftParen)) - .then_ignore(just(Token::RightParen)) + .then( + via() + .separated_by(just(Token::Comma)) + .allow_trailing() + .delimited_by(just(Token::LeftParen), just(Token::RightParen)), + ) .then(just(Token::Fail).ignored().or_not()) .map_with_span(|name, span| (name, span)) .then( @@ -22,26 +26,72 @@ pub fn parser() -> impl Parser impl Parser { + choice(( + select! {Token::DiscardName {name} => name}.map_with_span(|name, span| { + ast::ArgName::Discarded { + label: name.clone(), name, - public: false, - return_annotation: None, - return_type: (), - can_error: fail.is_some() || old_fail.is_some(), - }) - }) + location: span, + } + }), + select! {Token::Name {name} => name}.map_with_span(move |name, location| { + ast::ArgName::Named { + label: name.clone(), + name, + location, + is_validator_param: false, + } + }), + )) + .then_ignore(just(Token::Via)) + .then( + select! { Token::Name { name } => name } + .then_ignore(just(Token::Dot)) + .or_not(), + ) + .then(select! { Token::Name { name } => name }) + .map_with_span(|((arg_name, module), name), location| ast::ArgVia { + arg_name, + via: ast::DefinitionIdentifier { module, name }, + tipo: (), + location, + }) } #[cfg(test)] mod tests { use crate::assert_definition; + #[test] + fn def_test() { + assert_definition!( + r#" + test foo() { + True + } + "# + ); + } + #[test] fn def_test_fail() { assert_definition!( @@ -54,4 +104,26 @@ mod tests { "# ); } + + #[test] + fn def_property_test() { + assert_definition!( + r#" + test foo(x via fuzz.any_int) { + True + } + "# + ); + } + + #[test] + fn def_invalid_property_test() { + assert_definition!( + r#" + test foo(x via f, y via g) { + True + } + "# + ); + } } diff --git a/crates/aiken-lang/src/parser/lexer.rs b/crates/aiken-lang/src/parser/lexer.rs index 9dea075a..3e0df9cf 100644 --- a/crates/aiken-lang/src/parser/lexer.rs +++ b/crates/aiken-lang/src/parser/lexer.rs @@ -240,6 +240,7 @@ pub fn lexer() -> impl Parser, Error = ParseError> { "type" => Token::Type, "when" => Token::When, "validator" => Token::Validator, + "via" => Token::Via, _ => { if s.chars().next().map_or(false, |c| c.is_uppercase()) { Token::UpName { diff --git a/crates/aiken-lang/src/parser/token.rs b/crates/aiken-lang/src/parser/token.rs index cae2665b..d48e1179 100644 --- a/crates/aiken-lang/src/parser/token.rs +++ b/crates/aiken-lang/src/parser/token.rs @@ -89,6 +89,7 @@ pub enum Token { When, Trace, Validator, + Via, } impl fmt::Display for Token { @@ -176,6 +177,7 @@ impl fmt::Display for Token { Token::Test => "test", Token::Fail => "fail", Token::Validator => "validator", + Token::Via => "via", }; write!(f, "\"{s}\"") } diff --git a/crates/aiken-lang/src/tipo/environment.rs b/crates/aiken-lang/src/tipo/environment.rs index 3ba8e904..8bb9b383 100644 --- a/crates/aiken-lang/src/tipo/environment.rs +++ b/crates/aiken-lang/src/tipo/environment.rs @@ -10,7 +10,7 @@ use crate::{ RecordConstructor, RecordConstructorArg, Span, TypeAlias, TypedDefinition, TypedPattern, UnqualifiedImport, UntypedArg, UntypedDefinition, Use, Validator, PIPE_VARIABLE, }, - builtins::{self, function, generic_var, tuple, unbound_var}, + builtins::{function, generic_var, tuple, unbound_var}, tipo::fields::FieldMap, IdGenerator, }; @@ -1185,23 +1185,22 @@ impl<'a> Environment<'a> { }) } - Definition::Test(Function { name, location, .. }) => { - assert_unique_value_name(names, name, location)?; - hydrators.insert(name.clone(), Hydrator::new()); - let arg_types = vec![]; - let return_type = builtins::bool(); - self.insert_variable( - name.clone(), - ValueConstructorVariant::ModuleFn { - name: name.clone(), - field_map: None, - module: module_name.to_owned(), - arity: 0, - location: *location, - builtin: None, - }, - function(arg_types, return_type), - ); + Definition::Test(test) => { + let arguments = test + .arguments + .iter() + .map(|arg| arg.clone().into()) + .collect::>(); + + self.register_function( + &test.name, + &arguments, + &test.return_annotation, + module_name, + hydrators, + names, + &test.location, + )?; } Definition::DataType(DataType { diff --git a/crates/aiken-lang/src/tipo/error.rs b/crates/aiken-lang/src/tipo/error.rs index 1acbc454..1ed18458 100644 --- a/crates/aiken-lang/src/tipo/error.rs +++ b/crates/aiken-lang/src/tipo/error.rs @@ -946,6 +946,17 @@ The best thing to do from here is to remove it."#))] #[label("{} arguments", if *count < 2 { "not enough" } else { "too many" })] location: Span, }, + + #[error("I caught a test with too many arguments.\n")] + #[diagnostic(code("illegal::test_arity"))] + #[diagnostic(help( + "Tests are allowed to have 0 or 1 argument, but no more. Here I've found a test definition with {count} arguments. If you need to provide multiple values to a test, use a Record or a Tuple.", + ))] + IncorrectTestArity { + count: usize, + #[label("too many arguments")] + location: Span, + }, } impl ExtraData for Error { @@ -997,6 +1008,7 @@ impl ExtraData for Error { | Error::UnnecessarySpreadOperator { .. } | Error::UpdateMultiConstructorType { .. } | Error::ValidatorImported { .. } + | Error::IncorrectTestArity { .. } | Error::ValidatorMustReturnBool { .. } => None, Error::UnknownType { name, .. } diff --git a/crates/aiken-lang/src/tipo/expr.rs b/crates/aiken-lang/src/tipo/expr.rs index 83e55f33..b7bee950 100644 --- a/crates/aiken-lang/src/tipo/expr.rs +++ b/crates/aiken-lang/src/tipo/expr.rs @@ -1860,7 +1860,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { } } - fn infer_value_constructor( + pub fn infer_value_constructor( &mut self, module: &Option, name: &str, diff --git a/crates/aiken-lang/src/tipo/infer.rs b/crates/aiken-lang/src/tipo/infer.rs index 9b94181b..106a9aa3 100644 --- a/crates/aiken-lang/src/tipo/infer.rs +++ b/crates/aiken-lang/src/tipo/infer.rs @@ -2,15 +2,19 @@ use std::collections::HashMap; use crate::{ ast::{ - ArgName, DataType, Definition, Function, Layer, ModuleConstant, ModuleKind, - RecordConstructor, RecordConstructorArg, Tracing, TypeAlias, TypedDefinition, - TypedFunction, TypedModule, UntypedDefinition, UntypedModule, Use, Validator, + Annotation, Arg, ArgName, DataType, Definition, Function, Layer, ModuleConstant, + ModuleKind, RecordConstructor, RecordConstructorArg, Tracing, TypeAlias, TypedArg, + TypedDefinition, TypedFunction, TypedModule, UntypedArg, UntypedDefinition, UntypedModule, + Use, Validator, }, builtins, builtins::function, + expr::{TypedExpr, UntypedExpr}, line_numbers::LineNumbers, + tipo::{Span, Type}, IdGenerator, }; +use std::rc::Rc; use super::{ environment::{generalise, EntityKind, Environment}, @@ -159,97 +163,14 @@ fn infer_definition( tracing: Tracing, ) -> Result { match def { - Definition::Fn(Function { - doc, - location, - name, - public, - arguments: args, - body, - return_annotation, - end_position, - can_error, - .. - }) => { - let preregistered_fn = environment - .get_variable(&name) - .expect("Could not find preregistered type for function"); - - let field_map = preregistered_fn.field_map().cloned(); - - let preregistered_type = preregistered_fn.tipo.clone(); - - let (args_types, return_type) = preregistered_type - .function_types() - .expect("Preregistered type for fn was not a fn"); - - // Infer the type using the preregistered args + return types as a starting point - let (tipo, args, body, safe_to_generalise) = - environment.in_new_scope(|environment| { - let args = args - .into_iter() - .zip(&args_types) - .map(|(arg_name, tipo)| arg_name.set_type(tipo.clone())) - .collect(); - - let mut expr_typer = ExprTyper::new(environment, lines, tracing); - - expr_typer.hydrator = hydrators - .remove(&name) - .expect("Could not find hydrator for fn"); - - let (args, body) = - expr_typer.infer_fn_with_known_types(args, body, Some(return_type))?; - - let args_types = args.iter().map(|a| a.tipo.clone()).collect(); - - let tipo = function(args_types, body.tipo()); - - let safe_to_generalise = !expr_typer.ungeneralised_function_used; - - Ok::<_, Error>((tipo, args, body, safe_to_generalise)) - })?; - - // Assert that the inferred type matches the type of any recursive call - environment.unify(preregistered_type, tipo.clone(), location, false)?; - - // Generalise the function if safe to do so - let tipo = if safe_to_generalise { - environment.ungeneralised_functions.remove(&name); - - let tipo = generalise(tipo, 0); - - let module_fn = ValueConstructorVariant::ModuleFn { - name: name.clone(), - field_map, - module: module_name.to_owned(), - arity: args.len(), - location, - builtin: None, - }; - - environment.insert_variable(name.clone(), module_fn, tipo.clone()); - - tipo - } else { - tipo - }; - - Ok(Definition::Fn(Function { - doc, - location, - name, - public, - arguments: args, - return_annotation, - return_type: tipo - .return_type() - .expect("Could not find return type for fn"), - body, - can_error, - end_position, - })) - } + Definition::Fn(f) => Ok(Definition::Fn(infer_function( + f, + module_name, + hydrators, + environment, + lines, + tracing, + )?)), Definition::Validator(Validator { doc, @@ -412,20 +333,127 @@ fn infer_definition( } Definition::Test(f) => { - if let Definition::Fn(f) = infer_definition( - Definition::Fn(f), + fn annotate_fuzzer(tipo: &Type, location: &Span) -> Result { + match tipo { + // TODO: Ensure args & first returned element is a Prelude's PRNG. + Type::Fn { ret, .. } => { + let ann = tipo_to_annotation(ret, location)?; + match ann { + Annotation::Tuple { elems, .. } if elems.len() == 2 => { + Ok(elems.get(1).expect("Tuple has two elements").to_owned()) + } + _ => todo!("Fuzzer returns something else than a 2-tuple? "), + } + } + Type::Var { .. } | Type::App { .. } | Type::Tuple { .. } => { + todo!("Fuzzer type isn't a function?"); + } + } + } + + fn tipo_to_annotation(tipo: &Type, location: &Span) -> Result { + match tipo { + Type::App { + name, module, args, .. + } => { + let arguments = args + .iter() + .map(|arg| tipo_to_annotation(arg, location)) + .collect::, _>>()?; + Ok(Annotation::Constructor { + name: name.to_owned(), + module: Some(module.to_owned()), + arguments, + location: *location, + }) + } + Type::Tuple { elems } => { + let elems = elems + .iter() + .map(|arg| tipo_to_annotation(arg, location)) + .collect::, _>>()?; + Ok(Annotation::Tuple { + elems, + location: *location, + }) + } + Type::Fn { .. } | Type::Var { .. } => { + todo!("Fuzzer contains functions and/or non-concrete data-types?"); + } + } + } + + let annotation = match f.arguments.first() { + Some(arg) => { + if f.arguments.len() > 1 { + return Err(Error::IncorrectTestArity { + count: f.arguments.len(), + location: f.arguments.get(1).unwrap().location, + }); + } + + let ValueConstructor { tipo, .. } = ExprTyper::new(environment, lines, tracing) + .infer_value_constructor(&arg.via.module, &arg.via.name, &arg.location)?; + + Ok(Some(annotate_fuzzer(&tipo, &arg.location)?)) + } + None => Ok(None), + }?; + + let typed_f = infer_function( + Function { + doc: f.doc, + location: f.location, + name: f.name, + public: f.public, + arguments: f + .arguments + .into_iter() + .map(|arg| Arg { + annotation: annotation.clone(), + ..arg.into() + }) + .collect(), + return_annotation: f.return_annotation, + return_type: f.return_type, + body: f.body, + can_error: f.can_error, + end_position: f.end_position, + }, module_name, hydrators, environment, lines, tracing, - )? { - environment.unify(f.return_type.clone(), builtins::bool(), f.location, false)?; + )?; - Ok(Definition::Test(f)) - } else { - unreachable!("test definition inferred as something other than a function?") - } + environment.unify( + typed_f.return_type.clone(), + builtins::bool(), + typed_f.location, + false, + )?; + + Ok(Definition::Test(Function { + doc: typed_f.doc, + location: typed_f.location, + name: typed_f.name, + public: typed_f.public, + arguments: match annotation { + Some(_) => vec![typed_f + .arguments + .first() + .expect("has exactly one argument") + .to_owned() + .into()], + None => vec![], + }, + return_annotation: typed_f.return_annotation, + return_type: typed_f.return_type, + body: typed_f.body, + can_error: typed_f.can_error, + end_position: typed_f.end_position, + })) } Definition::TypeAlias(TypeAlias { @@ -640,3 +668,102 @@ fn infer_definition( } } } + +fn infer_function( + f: Function<(), UntypedExpr, UntypedArg>, + module_name: &String, + hydrators: &mut HashMap, + environment: &mut Environment<'_>, + lines: &LineNumbers, + tracing: Tracing, +) -> Result, TypedExpr, TypedArg>, Error> { + let Function { + doc, + location, + name, + public, + arguments, + body, + return_annotation, + end_position, + can_error, + .. + } = f; + + let preregistered_fn = environment + .get_variable(&name) + .expect("Could not find preregistered type for function"); + + let field_map = preregistered_fn.field_map().cloned(); + + let preregistered_type = preregistered_fn.tipo.clone(); + + let (args_types, return_type) = preregistered_type + .function_types() + .expect("Preregistered type for fn was not a fn"); + + // Infer the type using the preregistered args + return types as a starting point + let (tipo, arguments, body, safe_to_generalise) = environment.in_new_scope(|environment| { + let args = arguments + .into_iter() + .zip(&args_types) + .map(|(arg_name, tipo)| arg_name.set_type(tipo.clone())) + .collect(); + + let mut expr_typer = ExprTyper::new(environment, lines, tracing); + + expr_typer.hydrator = hydrators + .remove(&name) + .expect("Could not find hydrator for fn"); + + let (args, body) = expr_typer.infer_fn_with_known_types(args, body, Some(return_type))?; + + let args_types = args.iter().map(|a| a.tipo.clone()).collect(); + + let tipo = function(args_types, body.tipo()); + + let safe_to_generalise = !expr_typer.ungeneralised_function_used; + + Ok::<_, Error>((tipo, args, body, safe_to_generalise)) + })?; + + // Assert that the inferred type matches the type of any recursive call + environment.unify(preregistered_type, tipo.clone(), location, false)?; + + // Generalise the function if safe to do so + let tipo = if safe_to_generalise { + environment.ungeneralised_functions.remove(&name); + + let tipo = generalise(tipo, 0); + + let module_fn = ValueConstructorVariant::ModuleFn { + name: name.clone(), + field_map, + module: module_name.to_owned(), + arity: arguments.len(), + location, + builtin: None, + }; + + environment.insert_variable(name.clone(), module_fn, tipo.clone()); + + tipo + } else { + tipo + }; + + Ok(Function { + doc, + location, + name, + public, + arguments, + return_annotation, + return_type: tipo + .return_type() + .expect("Could not find return type for fn"), + body, + can_error, + end_position, + }) +} diff --git a/crates/aiken-project/src/tests/gen_uplc.rs b/crates/aiken-project/src/tests/gen_uplc.rs index 6647e7f9..a9ab36d5 100644 --- a/crates/aiken-project/src/tests/gen_uplc.rs +++ b/crates/aiken-project/src/tests/gen_uplc.rs @@ -1,6 +1,6 @@ use pretty_assertions::assert_eq; -use aiken_lang::ast::{Definition, Function, TraceLevel, Tracing, TypedFunction, TypedValidator}; +use aiken_lang::ast::{Definition, Function, TraceLevel, Tracing, TypedTest, TypedValidator}; use uplc::{ ast::{Constant, Data, DeBruijn, Name, Program, Term, Type}, builder::{CONSTR_FIELDS_EXPOSER, CONSTR_INDEX_EXPOSER}, @@ -13,7 +13,7 @@ use crate::module::CheckedModules; use super::TestProject; enum TestType { - Func(TypedFunction), + Func(TypedTest), Validator(TypedValidator), }