feat: fix inference comp issues

This commit is contained in:
rvcas 2024-07-18 19:57:38 -04:00 committed by KtorZ
parent 0de5cbc74e
commit fff90f7df5
No known key found for this signature in database
GPG Key ID: 33173CB6F77F4277
3 changed files with 118 additions and 114 deletions

View File

@ -459,8 +459,8 @@ fn exhaustiveness_simple() {
#[test]
fn validator_args_no_annotation() {
let source_code = r#"
validator(d) {
fn foo(a, b, c) {
validator hello(d) {
foo (a, b, c) {
True
}
}
@ -477,7 +477,7 @@ fn validator_args_no_annotation() {
assert!(param.tipo.is_data());
});
validator.fun.arguments.iter().for_each(|arg| {
validator.handlers[0].arguments.iter().for_each(|arg| {
assert!(arg.tipo.is_data());
})
})

View File

@ -307,32 +307,37 @@ impl<'a> Environment<'a> {
Definition::Validator(Validator {
doc,
end_position,
fun,
other_fun,
handlers,
name,
fallback,
location,
params,
}) => {
let Definition::Fn(fun) =
self.generalise_definition(Definition::Fn(fun), module_name)
let handlers = handlers
.into_iter()
.map(|fun| {
let Definition::Fn(fun) =
self.generalise_definition(Definition::Fn(fun), module_name)
else {
unreachable!()
};
fun
})
.collect();
let Definition::Fn(fallback) =
self.generalise_definition(Definition::Fn(fallback), module_name)
else {
unreachable!()
};
let other_fun = other_fun.map(|other_fun| {
let Definition::Fn(other_fun) =
self.generalise_definition(Definition::Fn(other_fun), module_name)
else {
unreachable!()
};
other_fun
});
Definition::Validator(Validator {
doc,
name,
end_position,
fun,
other_fun,
handlers,
fallback,
location,
params,
})
@ -1247,9 +1252,10 @@ impl<'a> Environment<'a> {
}
Definition::Validator(Validator {
fun,
other_fun,
handlers,
fallback,
params,
name,
doc: _,
location: _,
end_position: _,
@ -1264,41 +1270,41 @@ impl<'a> Environment<'a> {
}
};
let temp_params: Vec<UntypedArg> = params
.iter()
.cloned()
.chain(fun.arguments.clone())
.map(default_annotation)
.collect();
self.register_function(
&fun.name,
&temp_params,
&fun.return_annotation,
module_name,
hydrators,
names,
&fun.location,
)?;
if let Some(other) = other_fun {
for handler in handlers {
let temp_params: Vec<UntypedArg> = params
.iter()
.cloned()
.chain(other.arguments.clone())
.chain(handler.arguments.clone())
.map(default_annotation)
.collect();
self.register_function(
&other.name,
&handler.name,
&temp_params,
&other.return_annotation,
&handler.return_annotation,
module_name,
hydrators,
names,
&other.location,
&handler.location,
)?;
}
let temp_params: Vec<UntypedArg> = params
.iter()
.cloned()
.chain(fallback.arguments.clone())
.map(default_annotation)
.collect();
self.register_function(
&fallback.name,
&temp_params,
&fallback.return_annotation,
module_name,
hydrators,
names,
&fallback.location,
)?;
}
Definition::Validator(Validator { location, .. }) => {

View File

@ -172,13 +172,12 @@ fn infer_definition(
doc,
location,
end_position,
mut fun,
other_fun,
mut handlers,
fallback,
params,
name,
}) => {
let params_length = params.len();
let temp_params = params.iter().cloned().chain(fun.arguments);
fun.arguments = temp_params.collect();
environment.in_new_scope(|environment| {
let preregistered_fn = environment
@ -220,19 +219,74 @@ fn infer_definition(
};
}
let mut typed_fun =
infer_function(&fun, module_name, hydrators, environment, tracing)?;
let typed_handlers = vec![];
if !typed_fun.return_type.is_bool() {
for handler in handlers {
let temp_params = params.iter().cloned().chain(fun.arguments);
fun.arguments = temp_params.collect();
let mut typed_fun =
infer_function(&fun, module_name, hydrators, environment, lines, tracing)?;
if !typed_fun.return_type.is_bool() {
return Err(Error::ValidatorMustReturnBool {
return_type: typed_fun.return_type.clone(),
location: typed_fun.location,
});
}
if typed_fun.arguments.len() < 2 || typed_fun.arguments.len() > 3 {
return Err(Error::IncorrectValidatorArity {
count: typed_fun.arguments.len() as u32,
location: typed_fun.location,
});
}
for arg in typed_fun.arguments.iter_mut() {
if arg.tipo.is_unbound() {
arg.tipo = builtins::data();
}
}
}
let params = params.into_iter().chain(other.arguments);
other.arguments = params.collect();
let mut typed_fallback =
infer_function(&other, module_name, hydrators, environment, lines, tracing)?;
if !typed_fallback.return_type.is_bool() {
return Err(Error::ValidatorMustReturnBool {
return_type: typed_fun.return_type.clone(),
location: typed_fun.location,
return_type: typed_fallback.return_type.clone(),
location: typed_fallback.location,
});
}
let typed_params = typed_fun
.arguments
.drain(0..params_length)
typed_fallback.arguments.drain(0..params_length);
if typed_fallback.arguments.len() < 2 || typed_fallback.arguments.len() > 3 {
return Err(Error::IncorrectValidatorArity {
count: typed_fallback.arguments.len() as u32,
location: typed_fallback.location,
});
}
if typed_fun.arguments.len() == typed_fallback.arguments.len() {
return Err(Error::MultiValidatorEqualArgs {
location: typed_fun.location,
other_location: typed_fallback.location,
count: typed_fallback.arguments.len(),
});
}
for arg in typed_fallback.arguments.iter_mut() {
if arg.tipo.is_unbound() {
arg.tipo = builtins::data();
}
}
let typed_params = params
.into_iter()
.map(|mut arg| {
if arg.tipo.is_unbound() {
arg.tipo = builtins::data();
@ -242,68 +296,12 @@ fn infer_definition(
})
.collect();
if typed_fun.arguments.len() < 2 || typed_fun.arguments.len() > 3 {
return Err(Error::IncorrectValidatorArity {
count: typed_fun.arguments.len() as u32,
location: typed_fun.location,
});
}
for arg in typed_fun.arguments.iter_mut() {
if arg.tipo.is_unbound() {
arg.tipo = builtins::data();
}
}
let typed_other_fun = other_fun
.map(|mut other| -> Result<TypedFunction, Error> {
let params = params.into_iter().chain(other.arguments);
other.arguments = params.collect();
let mut other_typed_fun =
infer_function(&other, module_name, hydrators, environment, tracing)?;
if !other_typed_fun.return_type.is_bool() {
return Err(Error::ValidatorMustReturnBool {
return_type: other_typed_fun.return_type.clone(),
location: other_typed_fun.location,
});
}
other_typed_fun.arguments.drain(0..params_length);
if other_typed_fun.arguments.len() < 2
|| other_typed_fun.arguments.len() > 3
{
return Err(Error::IncorrectValidatorArity {
count: other_typed_fun.arguments.len() as u32,
location: other_typed_fun.location,
});
}
if typed_fun.arguments.len() == other_typed_fun.arguments.len() {
return Err(Error::MultiValidatorEqualArgs {
location: typed_fun.location,
other_location: other_typed_fun.location,
count: other_typed_fun.arguments.len(),
});
}
for arg in other_typed_fun.arguments.iter_mut() {
if arg.tipo.is_unbound() {
arg.tipo = builtins::data();
}
}
Ok(other_typed_fun)
})
.transpose()?;
Ok(Definition::Validator(Validator {
doc,
end_position,
fun: typed_fun,
other_fun: typed_other_fun,
handlers: typed_handlers,
fallback: typed_fallback,
name,
location,
params: typed_params,
}))