feat: fix inference comp issues

This commit is contained in:
rvcas 2024-07-18 19:57:38 -04:00 committed by KtorZ
parent 0de5cbc74e
commit fff90f7df5
No known key found for this signature in database
GPG Key ID: 33173CB6F77F4277
3 changed files with 118 additions and 114 deletions

View File

@ -459,8 +459,8 @@ fn exhaustiveness_simple() {
#[test] #[test]
fn validator_args_no_annotation() { fn validator_args_no_annotation() {
let source_code = r#" let source_code = r#"
validator(d) { validator hello(d) {
fn foo(a, b, c) { foo (a, b, c) {
True True
} }
} }
@ -477,7 +477,7 @@ fn validator_args_no_annotation() {
assert!(param.tipo.is_data()); assert!(param.tipo.is_data());
}); });
validator.fun.arguments.iter().for_each(|arg| { validator.handlers[0].arguments.iter().for_each(|arg| {
assert!(arg.tipo.is_data()); assert!(arg.tipo.is_data());
}) })
}) })

View File

@ -307,32 +307,37 @@ impl<'a> Environment<'a> {
Definition::Validator(Validator { Definition::Validator(Validator {
doc, doc,
end_position, end_position,
fun, handlers,
other_fun, name,
fallback,
location, location,
params, params,
}) => { }) => {
let handlers = handlers
.into_iter()
.map(|fun| {
let Definition::Fn(fun) = let Definition::Fn(fun) =
self.generalise_definition(Definition::Fn(fun), module_name) self.generalise_definition(Definition::Fn(fun), module_name)
else { else {
unreachable!() unreachable!()
}; };
let other_fun = other_fun.map(|other_fun| { fun
let Definition::Fn(other_fun) = })
self.generalise_definition(Definition::Fn(other_fun), module_name) .collect();
let Definition::Fn(fallback) =
self.generalise_definition(Definition::Fn(fallback), module_name)
else { else {
unreachable!() unreachable!()
}; };
other_fun
});
Definition::Validator(Validator { Definition::Validator(Validator {
doc, doc,
name,
end_position, end_position,
fun, handlers,
other_fun, fallback,
location, location,
params, params,
}) })
@ -1247,9 +1252,10 @@ impl<'a> Environment<'a> {
} }
Definition::Validator(Validator { Definition::Validator(Validator {
fun, handlers,
other_fun, fallback,
params, params,
name,
doc: _, doc: _,
location: _, location: _,
end_position: _, end_position: _,
@ -1264,41 +1270,41 @@ impl<'a> Environment<'a> {
} }
}; };
for handler in handlers {
let temp_params: Vec<UntypedArg> = params let temp_params: Vec<UntypedArg> = params
.iter() .iter()
.cloned() .cloned()
.chain(fun.arguments.clone()) .chain(handler.arguments.clone())
.map(default_annotation) .map(default_annotation)
.collect(); .collect();
self.register_function( self.register_function(
&fun.name, &handler.name,
&temp_params, &temp_params,
&fun.return_annotation, &handler.return_annotation,
module_name, module_name,
hydrators, hydrators,
names, names,
&fun.location, &handler.location,
)?;
if let Some(other) = other_fun {
let temp_params: Vec<UntypedArg> = params
.iter()
.cloned()
.chain(other.arguments.clone())
.map(default_annotation)
.collect();
self.register_function(
&other.name,
&temp_params,
&other.return_annotation,
module_name,
hydrators,
names,
&other.location,
)?; )?;
} }
let temp_params: Vec<UntypedArg> = params
.iter()
.cloned()
.chain(fallback.arguments.clone())
.map(default_annotation)
.collect();
self.register_function(
&fallback.name,
&temp_params,
&fallback.return_annotation,
module_name,
hydrators,
names,
&fallback.location,
)?;
} }
Definition::Validator(Validator { location, .. }) => { Definition::Validator(Validator { location, .. }) => {

View File

@ -172,13 +172,12 @@ fn infer_definition(
doc, doc,
location, location,
end_position, end_position,
mut fun, mut handlers,
other_fun, fallback,
params, params,
name,
}) => { }) => {
let params_length = params.len(); let params_length = params.len();
let temp_params = params.iter().cloned().chain(fun.arguments);
fun.arguments = temp_params.collect();
environment.in_new_scope(|environment| { environment.in_new_scope(|environment| {
let preregistered_fn = environment let preregistered_fn = environment
@ -220,8 +219,14 @@ fn infer_definition(
}; };
} }
let typed_handlers = vec![];
for handler in handlers {
let temp_params = params.iter().cloned().chain(fun.arguments);
fun.arguments = temp_params.collect();
let mut typed_fun = let mut typed_fun =
infer_function(&fun, module_name, hydrators, environment, tracing)?; infer_function(&fun, module_name, hydrators, environment, lines, tracing)?;
if !typed_fun.return_type.is_bool() { if !typed_fun.return_type.is_bool() {
return Err(Error::ValidatorMustReturnBool { return Err(Error::ValidatorMustReturnBool {
@ -230,18 +235,6 @@ fn infer_definition(
}); });
} }
let typed_params = typed_fun
.arguments
.drain(0..params_length)
.map(|mut arg| {
if arg.tipo.is_unbound() {
arg.tipo = builtins::data();
}
arg
})
.collect();
if typed_fun.arguments.len() < 2 || typed_fun.arguments.len() > 3 { if typed_fun.arguments.len() < 2 || typed_fun.arguments.len() > 3 {
return Err(Error::IncorrectValidatorArity { return Err(Error::IncorrectValidatorArity {
count: typed_fun.arguments.len() as u32, count: typed_fun.arguments.len() as u32,
@ -254,56 +247,61 @@ fn infer_definition(
arg.tipo = builtins::data(); arg.tipo = builtins::data();
} }
} }
}
let typed_other_fun = other_fun
.map(|mut other| -> Result<TypedFunction, Error> {
let params = params.into_iter().chain(other.arguments); let params = params.into_iter().chain(other.arguments);
other.arguments = params.collect(); other.arguments = params.collect();
let mut other_typed_fun = let mut typed_fallback =
infer_function(&other, module_name, hydrators, environment, tracing)?; infer_function(&other, module_name, hydrators, environment, lines, tracing)?;
if !other_typed_fun.return_type.is_bool() { if !typed_fallback.return_type.is_bool() {
return Err(Error::ValidatorMustReturnBool { return Err(Error::ValidatorMustReturnBool {
return_type: other_typed_fun.return_type.clone(), return_type: typed_fallback.return_type.clone(),
location: other_typed_fun.location, location: typed_fallback.location,
}); });
} }
other_typed_fun.arguments.drain(0..params_length); typed_fallback.arguments.drain(0..params_length);
if other_typed_fun.arguments.len() < 2 if typed_fallback.arguments.len() < 2 || typed_fallback.arguments.len() > 3 {
|| other_typed_fun.arguments.len() > 3
{
return Err(Error::IncorrectValidatorArity { return Err(Error::IncorrectValidatorArity {
count: other_typed_fun.arguments.len() as u32, count: typed_fallback.arguments.len() as u32,
location: other_typed_fun.location, location: typed_fallback.location,
}); });
} }
if typed_fun.arguments.len() == other_typed_fun.arguments.len() { if typed_fun.arguments.len() == typed_fallback.arguments.len() {
return Err(Error::MultiValidatorEqualArgs { return Err(Error::MultiValidatorEqualArgs {
location: typed_fun.location, location: typed_fun.location,
other_location: other_typed_fun.location, other_location: typed_fallback.location,
count: other_typed_fun.arguments.len(), count: typed_fallback.arguments.len(),
}); });
} }
for arg in other_typed_fun.arguments.iter_mut() { for arg in typed_fallback.arguments.iter_mut() {
if arg.tipo.is_unbound() { if arg.tipo.is_unbound() {
arg.tipo = builtins::data(); arg.tipo = builtins::data();
} }
} }
Ok(other_typed_fun) let typed_params = params
.into_iter()
.map(|mut arg| {
if arg.tipo.is_unbound() {
arg.tipo = builtins::data();
}
arg
}) })
.transpose()?; .collect();
Ok(Definition::Validator(Validator { Ok(Definition::Validator(Validator {
doc, doc,
end_position, end_position,
fun: typed_fun, handlers: typed_handlers,
other_fun: typed_other_fun, fallback: typed_fallback,
name,
location, location,
params: typed_params, params: typed_params,
})) }))