From fff90f7df52f7c9782c50aa86bdafcde776bd909 Mon Sep 17 00:00:00 2001 From: rvcas Date: Thu, 18 Jul 2024 19:57:38 -0400 Subject: [PATCH] feat: fix inference comp issues --- crates/aiken-lang/src/tests/check.rs | 6 +- crates/aiken-lang/src/tipo/environment.rs | 86 ++++++------- crates/aiken-lang/src/tipo/infer.rs | 140 +++++++++++----------- 3 files changed, 118 insertions(+), 114 deletions(-) diff --git a/crates/aiken-lang/src/tests/check.rs b/crates/aiken-lang/src/tests/check.rs index 9c333fee..8049e535 100644 --- a/crates/aiken-lang/src/tests/check.rs +++ b/crates/aiken-lang/src/tests/check.rs @@ -459,8 +459,8 @@ fn exhaustiveness_simple() { #[test] fn validator_args_no_annotation() { let source_code = r#" - validator(d) { - fn foo(a, b, c) { + validator hello(d) { + foo (a, b, c) { True } } @@ -477,7 +477,7 @@ fn validator_args_no_annotation() { assert!(param.tipo.is_data()); }); - validator.fun.arguments.iter().for_each(|arg| { + validator.handlers[0].arguments.iter().for_each(|arg| { assert!(arg.tipo.is_data()); }) }) diff --git a/crates/aiken-lang/src/tipo/environment.rs b/crates/aiken-lang/src/tipo/environment.rs index 87760a9a..c1c309fa 100644 --- a/crates/aiken-lang/src/tipo/environment.rs +++ b/crates/aiken-lang/src/tipo/environment.rs @@ -307,32 +307,37 @@ impl<'a> Environment<'a> { Definition::Validator(Validator { doc, end_position, - fun, - other_fun, + handlers, + name, + fallback, location, params, }) => { - let Definition::Fn(fun) = - self.generalise_definition(Definition::Fn(fun), module_name) + let handlers = handlers + .into_iter() + .map(|fun| { + let Definition::Fn(fun) = + self.generalise_definition(Definition::Fn(fun), module_name) + else { + unreachable!() + }; + + fun + }) + .collect(); + + let Definition::Fn(fallback) = + self.generalise_definition(Definition::Fn(fallback), module_name) else { unreachable!() }; - let other_fun = other_fun.map(|other_fun| { - let Definition::Fn(other_fun) = - self.generalise_definition(Definition::Fn(other_fun), module_name) - else { - unreachable!() - }; - - other_fun - }); - Definition::Validator(Validator { doc, + name, end_position, - fun, - other_fun, + handlers, + fallback, location, params, }) @@ -1247,9 +1252,10 @@ impl<'a> Environment<'a> { } Definition::Validator(Validator { - fun, - other_fun, + handlers, + fallback, params, + name, doc: _, location: _, end_position: _, @@ -1264,41 +1270,41 @@ impl<'a> Environment<'a> { } }; - let temp_params: Vec = params - .iter() - .cloned() - .chain(fun.arguments.clone()) - .map(default_annotation) - .collect(); - - self.register_function( - &fun.name, - &temp_params, - &fun.return_annotation, - module_name, - hydrators, - names, - &fun.location, - )?; - - if let Some(other) = other_fun { + for handler in handlers { let temp_params: Vec = params .iter() .cloned() - .chain(other.arguments.clone()) + .chain(handler.arguments.clone()) .map(default_annotation) .collect(); self.register_function( - &other.name, + &handler.name, &temp_params, - &other.return_annotation, + &handler.return_annotation, module_name, hydrators, names, - &other.location, + &handler.location, )?; } + + let temp_params: Vec = params + .iter() + .cloned() + .chain(fallback.arguments.clone()) + .map(default_annotation) + .collect(); + + self.register_function( + &fallback.name, + &temp_params, + &fallback.return_annotation, + module_name, + hydrators, + names, + &fallback.location, + )?; } Definition::Validator(Validator { location, .. }) => { diff --git a/crates/aiken-lang/src/tipo/infer.rs b/crates/aiken-lang/src/tipo/infer.rs index a3e0ef6b..044a9beb 100644 --- a/crates/aiken-lang/src/tipo/infer.rs +++ b/crates/aiken-lang/src/tipo/infer.rs @@ -172,13 +172,12 @@ fn infer_definition( doc, location, end_position, - mut fun, - other_fun, + mut handlers, + fallback, params, + name, }) => { let params_length = params.len(); - let temp_params = params.iter().cloned().chain(fun.arguments); - fun.arguments = temp_params.collect(); environment.in_new_scope(|environment| { let preregistered_fn = environment @@ -220,19 +219,74 @@ fn infer_definition( }; } - let mut typed_fun = - infer_function(&fun, module_name, hydrators, environment, tracing)?; + let typed_handlers = vec![]; - if !typed_fun.return_type.is_bool() { + for handler in handlers { + let temp_params = params.iter().cloned().chain(fun.arguments); + fun.arguments = temp_params.collect(); + + let mut typed_fun = + infer_function(&fun, module_name, hydrators, environment, lines, tracing)?; + + if !typed_fun.return_type.is_bool() { + return Err(Error::ValidatorMustReturnBool { + return_type: typed_fun.return_type.clone(), + location: typed_fun.location, + }); + } + + if typed_fun.arguments.len() < 2 || typed_fun.arguments.len() > 3 { + return Err(Error::IncorrectValidatorArity { + count: typed_fun.arguments.len() as u32, + location: typed_fun.location, + }); + } + + for arg in typed_fun.arguments.iter_mut() { + if arg.tipo.is_unbound() { + arg.tipo = builtins::data(); + } + } + } + + let params = params.into_iter().chain(other.arguments); + other.arguments = params.collect(); + + let mut typed_fallback = + infer_function(&other, module_name, hydrators, environment, lines, tracing)?; + + if !typed_fallback.return_type.is_bool() { return Err(Error::ValidatorMustReturnBool { - return_type: typed_fun.return_type.clone(), - location: typed_fun.location, + return_type: typed_fallback.return_type.clone(), + location: typed_fallback.location, }); } - let typed_params = typed_fun - .arguments - .drain(0..params_length) + typed_fallback.arguments.drain(0..params_length); + + if typed_fallback.arguments.len() < 2 || typed_fallback.arguments.len() > 3 { + return Err(Error::IncorrectValidatorArity { + count: typed_fallback.arguments.len() as u32, + location: typed_fallback.location, + }); + } + + if typed_fun.arguments.len() == typed_fallback.arguments.len() { + return Err(Error::MultiValidatorEqualArgs { + location: typed_fun.location, + other_location: typed_fallback.location, + count: typed_fallback.arguments.len(), + }); + } + + for arg in typed_fallback.arguments.iter_mut() { + if arg.tipo.is_unbound() { + arg.tipo = builtins::data(); + } + } + + let typed_params = params + .into_iter() .map(|mut arg| { if arg.tipo.is_unbound() { arg.tipo = builtins::data(); @@ -242,68 +296,12 @@ fn infer_definition( }) .collect(); - if typed_fun.arguments.len() < 2 || typed_fun.arguments.len() > 3 { - return Err(Error::IncorrectValidatorArity { - count: typed_fun.arguments.len() as u32, - location: typed_fun.location, - }); - } - - for arg in typed_fun.arguments.iter_mut() { - if arg.tipo.is_unbound() { - arg.tipo = builtins::data(); - } - } - - let typed_other_fun = other_fun - .map(|mut other| -> Result { - let params = params.into_iter().chain(other.arguments); - other.arguments = params.collect(); - - let mut other_typed_fun = - infer_function(&other, module_name, hydrators, environment, tracing)?; - - if !other_typed_fun.return_type.is_bool() { - return Err(Error::ValidatorMustReturnBool { - return_type: other_typed_fun.return_type.clone(), - location: other_typed_fun.location, - }); - } - - other_typed_fun.arguments.drain(0..params_length); - - if other_typed_fun.arguments.len() < 2 - || other_typed_fun.arguments.len() > 3 - { - return Err(Error::IncorrectValidatorArity { - count: other_typed_fun.arguments.len() as u32, - location: other_typed_fun.location, - }); - } - - if typed_fun.arguments.len() == other_typed_fun.arguments.len() { - return Err(Error::MultiValidatorEqualArgs { - location: typed_fun.location, - other_location: other_typed_fun.location, - count: other_typed_fun.arguments.len(), - }); - } - - for arg in other_typed_fun.arguments.iter_mut() { - if arg.tipo.is_unbound() { - arg.tipo = builtins::data(); - } - } - - Ok(other_typed_fun) - }) - .transpose()?; - Ok(Definition::Validator(Validator { doc, end_position, - fun: typed_fun, - other_fun: typed_other_fun, + handlers: typed_handlers, + fallback: typed_fallback, + name, location, params: typed_params, }))