Allow test definition to carry one parameter

The parameter is special as it takes no annotation but a 'via' keyword followed by an expression that should unify to a Fuzzer<a>, where Fuzzer<a> = fn(Seed) -> (Seed, a). The current commit only allow name identifiers for now. Ultimately, this may allow full expressions.
This commit is contained in:
KtorZ 2024-02-24 19:02:45 +01:00
parent bfb4455e0f
commit aadf3cfb48
No known key found for this signature in database
GPG Key ID: 33173CB6F77F4277
13 changed files with 579 additions and 187 deletions

View File

@ -158,12 +158,15 @@ fn str_to_keyword(word: &str) -> Option<Token> {
}
}
pub type TypedFunction = Function<Rc<Type>, TypedExpr>;
pub type UntypedFunction = Function<(), UntypedExpr>;
pub type TypedFunction = Function<Rc<Type>, TypedExpr, TypedArg>;
pub type UntypedFunction = Function<(), UntypedExpr, UntypedArg>;
pub type TypedTest = Function<Rc<Type>, TypedExpr, TypedArgVia>;
pub type UntypedTest = Function<(), UntypedExpr, UntypedArgVia>;
#[derive(Debug, Clone, PartialEq)]
pub struct Function<T, Expr> {
pub arguments: Vec<Arg<T>>,
pub struct Function<T, Expr, Arg> {
pub arguments: Vec<Arg>,
pub body: Expr,
pub doc: Option<String>,
pub location: Span,
@ -178,7 +181,7 @@ pub struct Function<T, Expr> {
pub type TypedTypeAlias = TypeAlias<Rc<Type>>;
pub type UntypedTypeAlias = TypeAlias<()>;
impl TypedFunction {
impl TypedTest {
pub fn test_hint(&self) -> Option<(BinOp, Box<TypedExpr>, Box<TypedExpr>)> {
do_test_hint(&self.body)
}
@ -358,18 +361,24 @@ pub type UntypedValidator = Validator<(), UntypedExpr>;
pub struct Validator<T, Expr> {
pub doc: Option<String>,
pub end_position: usize,
pub fun: Function<T, Expr>,
pub other_fun: Option<Function<T, Expr>>,
pub fun: Function<T, Expr, Arg<T>>,
pub other_fun: Option<Function<T, Expr, Arg<T>>>,
pub location: Span,
pub params: Vec<Arg<T>>,
}
pub type TypedDefinition = Definition<Rc<Type>, TypedExpr, String>;
pub type UntypedDefinition = Definition<(), UntypedExpr, ()>;
#[derive(Debug, Clone, PartialEq)]
pub struct DefinitionIdentifier {
pub module: Option<String>,
pub name: String,
}
pub type TypedDefinition = Definition<Rc<Type>, TypedExpr, String, ()>;
pub type UntypedDefinition = Definition<(), UntypedExpr, (), DefinitionIdentifier>;
#[derive(Debug, Clone, PartialEq)]
pub enum Definition<T, Expr, PackageName> {
Fn(Function<T, Expr>),
pub enum Definition<T, Expr, PackageName, Ann> {
Fn(Function<T, Expr, Arg<T>>),
TypeAlias(TypeAlias<T>),
@ -379,12 +388,12 @@ pub enum Definition<T, Expr, PackageName> {
ModuleConstant(ModuleConstant<T>),
Test(Function<T, Expr>),
Test(Function<T, Expr, ArgVia<T, Ann>>),
Validator(Validator<T, Expr>),
}
impl<A, B, C> Definition<A, B, C> {
impl<A, B, C, D> Definition<A, B, C, D> {
pub fn location(&self) -> Span {
match self {
Definition::Fn(Function { location, .. })
@ -634,6 +643,40 @@ impl<A> Arg<A> {
}
}
pub type TypedArgVia = ArgVia<Rc<Type>, ()>;
pub type UntypedArgVia = ArgVia<(), DefinitionIdentifier>;
#[derive(Debug, Clone, PartialEq)]
pub struct ArgVia<T, Ann> {
pub arg_name: ArgName,
pub location: Span,
pub via: Ann,
pub tipo: T,
}
impl<T, Ann> From<ArgVia<T, Ann>> for Arg<T> {
fn from(arg: ArgVia<T, Ann>) -> Arg<T> {
Arg {
arg_name: arg.arg_name,
location: arg.location,
tipo: arg.tipo,
annotation: None,
doc: None,
}
}
}
impl From<TypedArg> for TypedArgVia {
fn from(arg: TypedArg) -> TypedArgVia {
ArgVia {
arg_name: arg.arg_name,
tipo: arg.tipo,
location: arg.location,
via: (),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ArgName {
Discarded {

View File

@ -1,11 +1,12 @@
use crate::{
ast::{
Annotation, Arg, ArgName, AssignmentKind, BinOp, ByteArrayFormatPreference, CallArg,
ClauseGuard, Constant, CurveType, DataType, Definition, Function, IfBranch,
LogicalOpChainKind, ModuleConstant, Pattern, RecordConstructor, RecordConstructorArg,
RecordUpdateSpread, Span, TraceKind, TypeAlias, TypedArg, UnOp, UnqualifiedImport,
UntypedArg, UntypedClause, UntypedClauseGuard, UntypedDefinition, UntypedFunction,
UntypedModule, UntypedPattern, UntypedRecordUpdateArg, Use, Validator, CAPTURE_VARIABLE,
Annotation, Arg, ArgName, ArgVia, AssignmentKind, BinOp, ByteArrayFormatPreference,
CallArg, ClauseGuard, Constant, CurveType, DataType, Definition, DefinitionIdentifier,
Function, IfBranch, LogicalOpChainKind, ModuleConstant, Pattern, RecordConstructor,
RecordConstructorArg, RecordUpdateSpread, Span, TraceKind, TypeAlias, TypedArg, UnOp,
UnqualifiedImport, UntypedArg, UntypedArgVia, UntypedClause, UntypedClauseGuard,
UntypedDefinition, UntypedFunction, UntypedModule, UntypedPattern, UntypedRecordUpdateArg,
Use, Validator, CAPTURE_VARIABLE,
},
docvec,
expr::{FnStyle, UntypedExpr, DEFAULT_ERROR_STR, DEFAULT_TODO_STR},
@ -231,16 +232,7 @@ impl<'comments> Formatter<'comments> {
return_annotation,
end_position,
..
}) => self.definition_fn(
public,
"fn",
name,
args,
return_annotation,
body,
*end_position,
false,
),
}) => self.definition_fn(public, name, args, return_annotation, body, *end_position),
Definition::Validator(Validator {
end_position,
@ -257,16 +249,7 @@ impl<'comments> Formatter<'comments> {
end_position,
can_error,
..
}) => self.definition_fn(
&false,
"test",
name,
args,
&None,
body,
*end_position,
*can_error,
),
}) => self.definition_test(name, args, body, *end_position, *can_error),
Definition::TypeAlias(TypeAlias {
alias,
@ -488,25 +471,40 @@ impl<'comments> Formatter<'comments> {
commented(doc, comments)
}
fn fn_arg_via<'a, A>(&mut self, arg: &'a ArgVia<A, DefinitionIdentifier>) -> Document<'a> {
let comments = self.pop_comments(arg.location.start);
let doc_comments = self.doc_comments(arg.location.start);
let doc = arg.arg_name.to_doc().append(" via ");
let doc = match arg.via.module {
Some(ref module) => doc.append(module.to_doc()).append("."),
None => doc,
}
.append(arg.via.name.to_doc())
.group();
let doc = doc_comments.append(doc.group()).group();
commented(doc, comments)
}
#[allow(clippy::too_many_arguments)]
fn definition_fn<'a>(
&mut self,
public: &'a bool,
keyword: &'a str,
name: &'a str,
args: &'a [UntypedArg],
return_annotation: &'a Option<Annotation>,
body: &'a UntypedExpr,
end_location: usize,
can_error: bool,
) -> Document<'a> {
// Fn name and args
let head = pub_(*public)
.append(keyword)
.append(" ")
.append("fn ")
.append(name)
.append(wrap_args(args.iter().map(|e| (self.fn_arg(e), false))))
.append(if can_error { " fail" } else { "" });
.append(wrap_args(args.iter().map(|e| (self.fn_arg(e), false))));
// Add return annotation
let head = match return_annotation {
@ -531,6 +529,39 @@ impl<'comments> Formatter<'comments> {
.append("}")
}
#[allow(clippy::too_many_arguments)]
fn definition_test<'a>(
&mut self,
name: &'a str,
args: &'a [UntypedArgVia],
body: &'a UntypedExpr,
end_location: usize,
can_error: bool,
) -> Document<'a> {
// Fn name and args
let head = "test "
.to_doc()
.append(name)
.append(wrap_args(args.iter().map(|e| (self.fn_arg_via(e), false))))
.append(if can_error { " fail" } else { "" })
.group();
// Format body
let body = self.expr(body, true);
// Add any trailing comments
let body = match printed_comments(self.pop_comments(end_location), false) {
Some(comments) => body.append(line()).append(comments),
None => body,
};
// Stick it all together
head.append(" {")
.append(line().append(body).nest(INDENT).group())
.append(line())
.append("}")
}
fn definition_validator<'a>(
&mut self,
params: &'a [UntypedArg],
@ -550,13 +581,11 @@ impl<'comments> Formatter<'comments> {
let first_fn = self
.definition_fn(
&false,
"fn",
&fun.name,
&fun.arguments,
&fun.return_annotation,
&fun.body,
fun.end_position,
false,
)
.group();
let first_fn = commented(fun_doc_comments.append(first_fn).group(), fun_comments);
@ -570,13 +599,11 @@ impl<'comments> Formatter<'comments> {
let other_fn = self
.definition_fn(
&false,
"fn",
&other.name,
&other.arguments,
&other.return_annotation,
&other.body,
other.end_position,
false,
)
.group();

View File

@ -0,0 +1,50 @@
---
source: crates/aiken-lang/src/parser/definition/test.rs
description: "Code:\n\ntest foo(x via f, y via g) {\n True\n}\n"
---
Test(
Function {
arguments: [
ArgVia {
arg_name: Named {
name: "x",
label: "x",
location: 9..10,
is_validator_param: false,
},
location: 9..16,
via: DefinitionIdentifier {
module: None,
name: "f",
},
tipo: (),
},
ArgVia {
arg_name: Named {
name: "y",
label: "y",
location: 18..19,
is_validator_param: false,
},
location: 18..25,
via: DefinitionIdentifier {
module: None,
name: "g",
},
tipo: (),
},
],
body: Var {
location: 33..37,
name: "True",
},
doc: None,
location: 0..26,
name: "foo",
public: false,
return_annotation: None,
return_type: (),
end_position: 38,
can_error: false,
},
)

View File

@ -0,0 +1,38 @@
---
source: crates/aiken-lang/src/parser/definition/test.rs
description: "Code:\n\ntest foo(x via fuzz.any_int) {\n True\n}\n"
---
Test(
Function {
arguments: [
ArgVia {
arg_name: Named {
name: "x",
label: "x",
location: 9..10,
is_validator_param: false,
},
location: 9..27,
via: DefinitionIdentifier {
module: Some(
"fuzz",
),
name: "any_int",
},
tipo: (),
},
],
body: Var {
location: 35..39,
name: "True",
},
doc: None,
location: 0..28,
name: "foo",
public: false,
return_annotation: None,
return_type: (),
end_position: 40,
can_error: false,
},
)

View File

@ -0,0 +1,21 @@
---
source: crates/aiken-lang/src/parser/definition/test.rs
description: "Code:\n\ntest foo() {\n True\n}\n"
---
Test(
Function {
arguments: [],
body: Var {
location: 17..21,
name: "True",
},
doc: None,
location: 0..10,
name: "foo",
public: false,
return_annotation: None,
return_type: (),
end_position: 22,
can_error: false,
},
)

View File

@ -13,8 +13,12 @@ pub fn parser() -> impl Parser<Token, ast::UntypedDefinition, Error = ParseError
.or_not()
.then_ignore(just(Token::Test))
.then(select! {Token::Name {name} => name})
.then_ignore(just(Token::LeftParen))
.then_ignore(just(Token::RightParen))
.then(
via()
.separated_by(just(Token::Comma))
.allow_trailing()
.delimited_by(just(Token::LeftParen), just(Token::RightParen)),
)
.then(just(Token::Fail).ignored().or_not())
.map_with_span(|name, span| (name, span))
.then(
@ -22,26 +26,72 @@ pub fn parser() -> impl Parser<Token, ast::UntypedDefinition, Error = ParseError
.or_not()
.delimited_by(just(Token::LeftBrace), just(Token::RightBrace)),
)
.map_with_span(|((((old_fail, name), fail), span_end), body), span| {
ast::UntypedDefinition::Test(ast::Function {
arguments: vec![],
body: body.unwrap_or_else(|| UntypedExpr::todo(None, span)),
doc: None,
location: span_end,
end_position: span.end - 1,
.map_with_span(
|(((((old_fail, name), arguments), fail), span_end), body), span| {
ast::UntypedDefinition::Test(ast::Function {
arguments,
body: body.unwrap_or_else(|| UntypedExpr::todo(None, span)),
doc: None,
location: span_end,
end_position: span.end - 1,
name,
public: false,
return_annotation: Some(ast::Annotation::boolean(span)),
return_type: (),
can_error: fail.is_some() || old_fail.is_some(),
})
},
)
}
pub fn via() -> impl Parser<Token, ast::UntypedArgVia, Error = ParseError> {
choice((
select! {Token::DiscardName {name} => name}.map_with_span(|name, span| {
ast::ArgName::Discarded {
label: name.clone(),
name,
public: false,
return_annotation: None,
return_type: (),
can_error: fail.is_some() || old_fail.is_some(),
})
})
location: span,
}
}),
select! {Token::Name {name} => name}.map_with_span(move |name, location| {
ast::ArgName::Named {
label: name.clone(),
name,
location,
is_validator_param: false,
}
}),
))
.then_ignore(just(Token::Via))
.then(
select! { Token::Name { name } => name }
.then_ignore(just(Token::Dot))
.or_not(),
)
.then(select! { Token::Name { name } => name })
.map_with_span(|((arg_name, module), name), location| ast::ArgVia {
arg_name,
via: ast::DefinitionIdentifier { module, name },
tipo: (),
location,
})
}
#[cfg(test)]
mod tests {
use crate::assert_definition;
#[test]
fn def_test() {
assert_definition!(
r#"
test foo() {
True
}
"#
);
}
#[test]
fn def_test_fail() {
assert_definition!(
@ -54,4 +104,26 @@ mod tests {
"#
);
}
#[test]
fn def_property_test() {
assert_definition!(
r#"
test foo(x via fuzz.any_int) {
True
}
"#
);
}
#[test]
fn def_invalid_property_test() {
assert_definition!(
r#"
test foo(x via f, y via g) {
True
}
"#
);
}
}

View File

@ -240,6 +240,7 @@ pub fn lexer() -> impl Parser<char, Vec<(Token, Span)>, Error = ParseError> {
"type" => Token::Type,
"when" => Token::When,
"validator" => Token::Validator,
"via" => Token::Via,
_ => {
if s.chars().next().map_or(false, |c| c.is_uppercase()) {
Token::UpName {

View File

@ -89,6 +89,7 @@ pub enum Token {
When,
Trace,
Validator,
Via,
}
impl fmt::Display for Token {
@ -176,6 +177,7 @@ impl fmt::Display for Token {
Token::Test => "test",
Token::Fail => "fail",
Token::Validator => "validator",
Token::Via => "via",
};
write!(f, "\"{s}\"")
}

View File

@ -10,7 +10,7 @@ use crate::{
RecordConstructor, RecordConstructorArg, Span, TypeAlias, TypedDefinition, TypedPattern,
UnqualifiedImport, UntypedArg, UntypedDefinition, Use, Validator, PIPE_VARIABLE,
},
builtins::{self, function, generic_var, tuple, unbound_var},
builtins::{function, generic_var, tuple, unbound_var},
tipo::fields::FieldMap,
IdGenerator,
};
@ -1185,23 +1185,22 @@ impl<'a> Environment<'a> {
})
}
Definition::Test(Function { name, location, .. }) => {
assert_unique_value_name(names, name, location)?;
hydrators.insert(name.clone(), Hydrator::new());
let arg_types = vec![];
let return_type = builtins::bool();
self.insert_variable(
name.clone(),
ValueConstructorVariant::ModuleFn {
name: name.clone(),
field_map: None,
module: module_name.to_owned(),
arity: 0,
location: *location,
builtin: None,
},
function(arg_types, return_type),
);
Definition::Test(test) => {
let arguments = test
.arguments
.iter()
.map(|arg| arg.clone().into())
.collect::<Vec<_>>();
self.register_function(
&test.name,
&arguments,
&test.return_annotation,
module_name,
hydrators,
names,
&test.location,
)?;
}
Definition::DataType(DataType {

View File

@ -946,6 +946,17 @@ The best thing to do from here is to remove it."#))]
#[label("{} arguments", if *count < 2 { "not enough" } else { "too many" })]
location: Span,
},
#[error("I caught a test with too many arguments.\n")]
#[diagnostic(code("illegal::test_arity"))]
#[diagnostic(help(
"Tests are allowed to have 0 or 1 argument, but no more. Here I've found a test definition with {count} arguments. If you need to provide multiple values to a test, use a Record or a Tuple.",
))]
IncorrectTestArity {
count: usize,
#[label("too many arguments")]
location: Span,
},
}
impl ExtraData for Error {
@ -997,6 +1008,7 @@ impl ExtraData for Error {
| Error::UnnecessarySpreadOperator { .. }
| Error::UpdateMultiConstructorType { .. }
| Error::ValidatorImported { .. }
| Error::IncorrectTestArity { .. }
| Error::ValidatorMustReturnBool { .. } => None,
Error::UnknownType { name, .. }

View File

@ -1860,7 +1860,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> {
}
}
fn infer_value_constructor(
pub fn infer_value_constructor(
&mut self,
module: &Option<String>,
name: &str,

View File

@ -2,15 +2,19 @@ use std::collections::HashMap;
use crate::{
ast::{
ArgName, DataType, Definition, Function, Layer, ModuleConstant, ModuleKind,
RecordConstructor, RecordConstructorArg, Tracing, TypeAlias, TypedDefinition,
TypedFunction, TypedModule, UntypedDefinition, UntypedModule, Use, Validator,
Annotation, Arg, ArgName, DataType, Definition, Function, Layer, ModuleConstant,
ModuleKind, RecordConstructor, RecordConstructorArg, Tracing, TypeAlias, TypedArg,
TypedDefinition, TypedFunction, TypedModule, UntypedArg, UntypedDefinition, UntypedModule,
Use, Validator,
},
builtins,
builtins::function,
expr::{TypedExpr, UntypedExpr},
line_numbers::LineNumbers,
tipo::{Span, Type},
IdGenerator,
};
use std::rc::Rc;
use super::{
environment::{generalise, EntityKind, Environment},
@ -159,97 +163,14 @@ fn infer_definition(
tracing: Tracing,
) -> Result<TypedDefinition, Error> {
match def {
Definition::Fn(Function {
doc,
location,
name,
public,
arguments: args,
body,
return_annotation,
end_position,
can_error,
..
}) => {
let preregistered_fn = environment
.get_variable(&name)
.expect("Could not find preregistered type for function");
let field_map = preregistered_fn.field_map().cloned();
let preregistered_type = preregistered_fn.tipo.clone();
let (args_types, return_type) = preregistered_type
.function_types()
.expect("Preregistered type for fn was not a fn");
// Infer the type using the preregistered args + return types as a starting point
let (tipo, args, body, safe_to_generalise) =
environment.in_new_scope(|environment| {
let args = args
.into_iter()
.zip(&args_types)
.map(|(arg_name, tipo)| arg_name.set_type(tipo.clone()))
.collect();
let mut expr_typer = ExprTyper::new(environment, lines, tracing);
expr_typer.hydrator = hydrators
.remove(&name)
.expect("Could not find hydrator for fn");
let (args, body) =
expr_typer.infer_fn_with_known_types(args, body, Some(return_type))?;
let args_types = args.iter().map(|a| a.tipo.clone()).collect();
let tipo = function(args_types, body.tipo());
let safe_to_generalise = !expr_typer.ungeneralised_function_used;
Ok::<_, Error>((tipo, args, body, safe_to_generalise))
})?;
// Assert that the inferred type matches the type of any recursive call
environment.unify(preregistered_type, tipo.clone(), location, false)?;
// Generalise the function if safe to do so
let tipo = if safe_to_generalise {
environment.ungeneralised_functions.remove(&name);
let tipo = generalise(tipo, 0);
let module_fn = ValueConstructorVariant::ModuleFn {
name: name.clone(),
field_map,
module: module_name.to_owned(),
arity: args.len(),
location,
builtin: None,
};
environment.insert_variable(name.clone(), module_fn, tipo.clone());
tipo
} else {
tipo
};
Ok(Definition::Fn(Function {
doc,
location,
name,
public,
arguments: args,
return_annotation,
return_type: tipo
.return_type()
.expect("Could not find return type for fn"),
body,
can_error,
end_position,
}))
}
Definition::Fn(f) => Ok(Definition::Fn(infer_function(
f,
module_name,
hydrators,
environment,
lines,
tracing,
)?)),
Definition::Validator(Validator {
doc,
@ -412,20 +333,127 @@ fn infer_definition(
}
Definition::Test(f) => {
if let Definition::Fn(f) = infer_definition(
Definition::Fn(f),
fn annotate_fuzzer(tipo: &Type, location: &Span) -> Result<Annotation, Error> {
match tipo {
// TODO: Ensure args & first returned element is a Prelude's PRNG.
Type::Fn { ret, .. } => {
let ann = tipo_to_annotation(ret, location)?;
match ann {
Annotation::Tuple { elems, .. } if elems.len() == 2 => {
Ok(elems.get(1).expect("Tuple has two elements").to_owned())
}
_ => todo!("Fuzzer returns something else than a 2-tuple? "),
}
}
Type::Var { .. } | Type::App { .. } | Type::Tuple { .. } => {
todo!("Fuzzer type isn't a function?");
}
}
}
fn tipo_to_annotation(tipo: &Type, location: &Span) -> Result<Annotation, Error> {
match tipo {
Type::App {
name, module, args, ..
} => {
let arguments = args
.iter()
.map(|arg| tipo_to_annotation(arg, location))
.collect::<Result<Vec<Annotation>, _>>()?;
Ok(Annotation::Constructor {
name: name.to_owned(),
module: Some(module.to_owned()),
arguments,
location: *location,
})
}
Type::Tuple { elems } => {
let elems = elems
.iter()
.map(|arg| tipo_to_annotation(arg, location))
.collect::<Result<Vec<Annotation>, _>>()?;
Ok(Annotation::Tuple {
elems,
location: *location,
})
}
Type::Fn { .. } | Type::Var { .. } => {
todo!("Fuzzer contains functions and/or non-concrete data-types?");
}
}
}
let annotation = match f.arguments.first() {
Some(arg) => {
if f.arguments.len() > 1 {
return Err(Error::IncorrectTestArity {
count: f.arguments.len(),
location: f.arguments.get(1).unwrap().location,
});
}
let ValueConstructor { tipo, .. } = ExprTyper::new(environment, lines, tracing)
.infer_value_constructor(&arg.via.module, &arg.via.name, &arg.location)?;
Ok(Some(annotate_fuzzer(&tipo, &arg.location)?))
}
None => Ok(None),
}?;
let typed_f = infer_function(
Function {
doc: f.doc,
location: f.location,
name: f.name,
public: f.public,
arguments: f
.arguments
.into_iter()
.map(|arg| Arg {
annotation: annotation.clone(),
..arg.into()
})
.collect(),
return_annotation: f.return_annotation,
return_type: f.return_type,
body: f.body,
can_error: f.can_error,
end_position: f.end_position,
},
module_name,
hydrators,
environment,
lines,
tracing,
)? {
environment.unify(f.return_type.clone(), builtins::bool(), f.location, false)?;
)?;
Ok(Definition::Test(f))
} else {
unreachable!("test definition inferred as something other than a function?")
}
environment.unify(
typed_f.return_type.clone(),
builtins::bool(),
typed_f.location,
false,
)?;
Ok(Definition::Test(Function {
doc: typed_f.doc,
location: typed_f.location,
name: typed_f.name,
public: typed_f.public,
arguments: match annotation {
Some(_) => vec![typed_f
.arguments
.first()
.expect("has exactly one argument")
.to_owned()
.into()],
None => vec![],
},
return_annotation: typed_f.return_annotation,
return_type: typed_f.return_type,
body: typed_f.body,
can_error: typed_f.can_error,
end_position: typed_f.end_position,
}))
}
Definition::TypeAlias(TypeAlias {
@ -640,3 +668,102 @@ fn infer_definition(
}
}
}
fn infer_function(
f: Function<(), UntypedExpr, UntypedArg>,
module_name: &String,
hydrators: &mut HashMap<String, Hydrator>,
environment: &mut Environment<'_>,
lines: &LineNumbers,
tracing: Tracing,
) -> Result<Function<Rc<Type>, TypedExpr, TypedArg>, Error> {
let Function {
doc,
location,
name,
public,
arguments,
body,
return_annotation,
end_position,
can_error,
..
} = f;
let preregistered_fn = environment
.get_variable(&name)
.expect("Could not find preregistered type for function");
let field_map = preregistered_fn.field_map().cloned();
let preregistered_type = preregistered_fn.tipo.clone();
let (args_types, return_type) = preregistered_type
.function_types()
.expect("Preregistered type for fn was not a fn");
// Infer the type using the preregistered args + return types as a starting point
let (tipo, arguments, body, safe_to_generalise) = environment.in_new_scope(|environment| {
let args = arguments
.into_iter()
.zip(&args_types)
.map(|(arg_name, tipo)| arg_name.set_type(tipo.clone()))
.collect();
let mut expr_typer = ExprTyper::new(environment, lines, tracing);
expr_typer.hydrator = hydrators
.remove(&name)
.expect("Could not find hydrator for fn");
let (args, body) = expr_typer.infer_fn_with_known_types(args, body, Some(return_type))?;
let args_types = args.iter().map(|a| a.tipo.clone()).collect();
let tipo = function(args_types, body.tipo());
let safe_to_generalise = !expr_typer.ungeneralised_function_used;
Ok::<_, Error>((tipo, args, body, safe_to_generalise))
})?;
// Assert that the inferred type matches the type of any recursive call
environment.unify(preregistered_type, tipo.clone(), location, false)?;
// Generalise the function if safe to do so
let tipo = if safe_to_generalise {
environment.ungeneralised_functions.remove(&name);
let tipo = generalise(tipo, 0);
let module_fn = ValueConstructorVariant::ModuleFn {
name: name.clone(),
field_map,
module: module_name.to_owned(),
arity: arguments.len(),
location,
builtin: None,
};
environment.insert_variable(name.clone(), module_fn, tipo.clone());
tipo
} else {
tipo
};
Ok(Function {
doc,
location,
name,
public,
arguments,
return_annotation,
return_type: tipo
.return_type()
.expect("Could not find return type for fn"),
body,
can_error,
end_position,
})
}

View File

@ -1,6 +1,6 @@
use pretty_assertions::assert_eq;
use aiken_lang::ast::{Definition, Function, TraceLevel, Tracing, TypedFunction, TypedValidator};
use aiken_lang::ast::{Definition, Function, TraceLevel, Tracing, TypedTest, TypedValidator};
use uplc::{
ast::{Constant, Data, DeBruijn, Name, Program, Term, Type},
builder::{CONSTR_FIELDS_EXPOSER, CONSTR_INDEX_EXPOSER},
@ -13,7 +13,7 @@ use crate::module::CheckedModules;
use super::TestProject;
enum TestType {
Func(TypedFunction),
Func(TypedTest),
Validator(TypedValidator),
}