diff --git a/crates/aiken-lang/src/ast.rs b/crates/aiken-lang/src/ast.rs index 2d5102fe..6953f6a0 100644 --- a/crates/aiken-lang/src/ast.rs +++ b/crates/aiken-lang/src/ast.rs @@ -564,6 +564,7 @@ pub enum ArgName { name: String, label: String, location: Span, + is_validator_param: bool, }, } diff --git a/crates/aiken-lang/src/builtins.rs b/crates/aiken-lang/src/builtins.rs index c7ba3f37..b6e8dca5 100644 --- a/crates/aiken-lang/src/builtins.rs +++ b/crates/aiken-lang/src/builtins.rs @@ -675,6 +675,7 @@ pub fn prelude_functions(id_gen: &IdGenerator) -> IndexMap IndexMap IndexMap IndexMap IndexMap IndexMap impl Parser impl Parser { - let func_parser = just(Token::Fn) - .ignore_then(select! {Token::Name {name} => name}) - .then( - fn_param_parser() - .separated_by(just(Token::Comma)) - .allow_trailing() - .delimited_by(just(Token::LeftParen), just(Token::RightParen)) - .map_with_span(|arguments, span| (arguments, span)), - ) - .then(just(Token::RArrow).ignore_then(type_parser()).or_not()) - .then( - expr_seq_parser() - .or_not() - .delimited_by(just(Token::LeftBrace), just(Token::RightBrace)), - ) - .map_with_span( - |(((name, (arguments, args_span)), return_annotation), body), span| ast::Function { - arguments, - body: body.unwrap_or_else(|| expr::UntypedExpr::todo(span, None)), - doc: None, - location: Span { - start: span.start, - end: return_annotation - .as_ref() - .map(|l| l.location().end) - .unwrap_or_else(|| args_span.end), - }, - end_position: span.end - 1, - name, - public: false, - return_annotation, - return_type: (), - }, - ); - just(Token::Validator) .ignore_then( - fn_param_parser() + fn_param_parser(true) .separated_by(just(Token::Comma)) .allow_trailing() .delimited_by(just(Token::LeftParen), just(Token::RightParen)) @@ -287,12 +252,20 @@ pub fn validator_parser() -> impl Parser impl Parser name}) .then( - fn_param_parser() + fn_param_parser(false) .separated_by(just(Token::Comma)) .allow_trailing() .delimited_by(just(Token::LeftParen), just(Token::RightParen)) @@ -489,7 +462,9 @@ pub fn bytearray_parser( )) } -pub fn fn_param_parser() -> impl Parser { +pub fn fn_param_parser( + is_validator_param: bool, +) -> impl Parser { choice(( select! {Token::Name {name} => name} .then(select! {Token::DiscardName {name} => name}) @@ -507,15 +482,17 @@ pub fn fn_param_parser() -> impl Parser name} .then(select! {Token::Name {name} => name}) - .map_with_span(|(label, name), span| ast::ArgName::Named { + .map_with_span(move |(label, name), span| ast::ArgName::Named { label, name, location: span, + is_validator_param, }), - select! {Token::Name {name} => name}.map_with_span(|name, span| ast::ArgName::Named { + select! {Token::Name {name} => name}.map_with_span(move |name, span| ast::ArgName::Named { label: name.clone(), name, location: span, + is_validator_param, }), )) .then(just(Token::Colon).ignore_then(type_parser()).or_not()) @@ -541,6 +518,7 @@ pub fn anon_fn_param_parser() -> impl Parser ExprTyper<'a, 'b> { ) -> Result<(Vec, TypedExpr), Error> { self.assert_no_assignment(&body)?; - for (arg, t) in args.iter().zip(args.iter().map(|arg| arg.tipo.clone())) { + for arg in &args { match &arg.arg_name { - ArgName::Named { name, .. } => { + ArgName::Named { + name, + is_validator_param, + .. + } if !is_validator_param => { self.environment.insert_variable( name.to_string(), ValueConstructorVariant::LocalVariable { location: arg.location, }, - t, + arg.tipo.clone(), ); self.environment.init_usage( @@ -1497,7 +1501,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { arg.location, ); } - ArgName::Discarded { .. } => (), + ArgName::Named { .. } | ArgName::Discarded { .. } => (), }; } diff --git a/crates/aiken-lang/src/tipo/infer.rs b/crates/aiken-lang/src/tipo/infer.rs index a8e030c0..e63586e7 100644 --- a/crates/aiken-lang/src/tipo/infer.rs +++ b/crates/aiken-lang/src/tipo/infer.rs @@ -2,9 +2,9 @@ use std::collections::HashMap; use crate::{ ast::{ - DataType, Definition, Function, Layer, ModuleConstant, ModuleKind, RecordConstructor, - RecordConstructorArg, Span, Tracing, TypeAlias, TypedDefinition, TypedFunction, - TypedModule, UntypedDefinition, UntypedModule, Use, Validator, + ArgName, DataType, Definition, Function, Layer, ModuleConstant, ModuleKind, + RecordConstructor, RecordConstructorArg, Span, Tracing, TypeAlias, TypedDefinition, + TypedFunction, TypedModule, UntypedDefinition, UntypedModule, Use, Validator, }, builtins, builtins::function, @@ -262,87 +262,127 @@ fn infer_definition( let temp_params = params.iter().cloned().chain(fun.arguments); fun.arguments = temp_params.collect(); - let Definition::Fn(mut typed_fun) = infer_definition( - Definition::Fn(fun), - module_name, - hydrators, - environment, - tracing, - kind, - )? else { - unreachable!("validator definition inferred as something other than a function?") - }; + environment.in_new_scope(|environment| { + let preregistered_fn = environment + .get_variable(&fun.name) + .expect("Could not find preregistered type for function"); - if !typed_fun.return_type.is_bool() { - return Err(Error::ValidatorMustReturnBool { - return_type: typed_fun.return_type.clone(), - location: typed_fun.location, - }); - } - let typed_params = typed_fun.arguments.drain(0..params_length).collect(); + let preregistered_type = preregistered_fn.tipo.clone(); - if typed_fun.arguments.len() < 2 || typed_fun.arguments.len() > 3 { - return Err(Error::IncorrectValidatorArity { - count: typed_fun.arguments.len() as u32, - location: typed_fun.location, - }); - } + let (args_types, _return_type) = preregistered_type + .function_types() + .expect("Preregistered type for fn was not a fn"); - let typed_other_fun = other_fun - .map(|mut other| -> Result { - let params = params.into_iter().chain(other.arguments); - other.arguments = params.collect(); + for (arg, t) in params.iter().zip(args_types[0..params.len()].iter()) { + match &arg.arg_name { + ArgName::Named { + name, + is_validator_param, + .. + } if *is_validator_param => { + environment.insert_variable( + name.to_string(), + ValueConstructorVariant::LocalVariable { + location: arg.location, + }, + t.clone(), + ); - let Definition::Fn(mut other_typed_fun) = infer_definition( - Definition::Fn(other), - module_name, - hydrators, - environment, - tracing, - kind, - )? else { - unreachable!( - "validator definition inferred as something other than a function?" - ) + environment.init_usage( + name.to_string(), + EntityKind::Variable, + arg.location, + ); + } + ArgName::Named { .. } | ArgName::Discarded { .. } => (), }; + } - if !other_typed_fun.return_type.is_bool() { - return Err(Error::ValidatorMustReturnBool { - return_type: other_typed_fun.return_type.clone(), - location: other_typed_fun.location, - }); - } + let Definition::Fn(mut typed_fun) = infer_definition( + Definition::Fn(fun), + module_name, + hydrators, + environment, + tracing, + kind, + )? else { + unreachable!("validator definition inferred as something other than a function?") + }; - other_typed_fun.arguments.drain(0..params_length); + if !typed_fun.return_type.is_bool() { + return Err(Error::ValidatorMustReturnBool { + return_type: typed_fun.return_type.clone(), + location: typed_fun.location, + }); + } - if other_typed_fun.arguments.len() < 2 || other_typed_fun.arguments.len() > 3 { - return Err(Error::IncorrectValidatorArity { - count: other_typed_fun.arguments.len() as u32, - location: other_typed_fun.location, - }); - } + let typed_params = typed_fun.arguments.drain(0..params_length).collect(); - if typed_fun.arguments.len() == other_typed_fun.arguments.len() { - return Err(Error::MultiValidatorEqualArgs { - location: typed_fun.location, - other_location: other_typed_fun.location, - count: other_typed_fun.arguments.len(), - }); - } + if typed_fun.arguments.len() < 2 || typed_fun.arguments.len() > 3 { + return Err(Error::IncorrectValidatorArity { + count: typed_fun.arguments.len() as u32, + location: typed_fun.location, + }); + } - Ok(other_typed_fun) - }) - .transpose(); + let typed_other_fun = other_fun + .map(|mut other| -> Result { + let params = params.into_iter().chain(other.arguments); + other.arguments = params.collect(); - Ok(Definition::Validator(Validator { - doc, - end_position, - fun: typed_fun, - other_fun: typed_other_fun?, - location, - params: typed_params, - })) + let Definition::Fn(mut other_typed_fun) = infer_definition( + Definition::Fn(other), + module_name, + hydrators, + environment, + tracing, + kind, + )? else { + unreachable!( + "validator definition inferred as something other than a function?" + ) + }; + + if !other_typed_fun.return_type.is_bool() { + return Err(Error::ValidatorMustReturnBool { + return_type: other_typed_fun.return_type.clone(), + location: other_typed_fun.location, + }); + } + + other_typed_fun.arguments.drain(0..params_length); + + if other_typed_fun.arguments.len() < 2 + || other_typed_fun.arguments.len() > 3 + { + return Err(Error::IncorrectValidatorArity { + count: other_typed_fun.arguments.len() as u32, + location: other_typed_fun.location, + }); + } + + if typed_fun.arguments.len() == other_typed_fun.arguments.len() { + return Err(Error::MultiValidatorEqualArgs { + location: typed_fun.location, + other_location: other_typed_fun.location, + count: other_typed_fun.arguments.len(), + }); + } + + Ok(other_typed_fun) + }) + .transpose(); + + Ok(Definition::Validator(Validator { + doc, + end_position, + fun: typed_fun, + other_fun: typed_other_fun?, + location, + params: typed_params, + })) + }) } Definition::Test(f) => {