Infer callee first in function call

The current inference system walks expressions from "top to bottom",
  starting from definitions higher in the source file and working down.
  When a call is encountered, we use whatever information is known about
  the callee definition at the moment the call is inferred.

  This causes interesting issues in the case where the callee doesn't
  have annotations and is only partially known. For example:

  ```
  pub fn list(fuzzer: Option<a>) -> Option<List<a>> {
    inner(fuzzer, [])
  }

  fn inner(fuzzer, xs) -> Option<List<b>> {
    when fuzzer is {
      None -> Some(xs)
      Some(x) -> Some([x, ..xs])
    }
  }
  ```

  In this small program, we infer `list` first and run into `inner`.
  Yet, the arguments for `inner` are not annotated, so since we haven't
  inferred `inner` yet, we will create two unbound variables.

  And naturally, we will link the type of `[]` to being of the same type
  as `xs` -- which is still unbound at this point. The return type of
  `inner` is given by the annotation, so all-in-all, the unification
  will work without ever having to commit to a type of `[]`.

  It is only later, when `inner` is inferred, that we will generalise
  the unbound type of `xs` to a generic which is the same as `b` in the
  annotation. At this point, `[]` is also typed with this same generic,
  which has a different id than `a` in `list` since it comes from
  another type definition.

  This is unfortunate and will cause issues down the line for the code
  generation. The problem doesn't occur when `inner`'s arguments are
  properly annotated, or when `inner` is actually inferred first.

  Hence, I saw two possible avenues for fixing this problem:

  1. Detect the presence of 'incongruous generics' in definitions after
     they've all been inferred, and raise a user error asking for more
     annotations.

  2. Infer definitions in dependency order, with definitions used in
     others inferred first.

  This commit does (2) (although it may still be a good idea to do (1)
  eventually) since it offers a much better user experience. One way to
  do (2) is to construct a dependency graph between function calls, and
  perform a topological sort.

  Building such a graph is, however, quite tricky as it requires walking
  through the AST while maintaining scope etc. which is more-or-less
  already what the inference step is doing; so it feels like double
  work.

  Thus instead, this commit tries to do a depth-first inference and
  "pause" inference of definitions when encountering a call to fully
  infer the callee first. To achieve this properly, we must ensure that
  we do not infer the same definition again, so we "remember" already
  inferred definitions in the environment now.
This commit is contained in:
KtorZ 2024-05-05 13:12:49 +02:00 committed by Kasey
parent 7b71389519
commit a124bdbb05
3 changed files with 211 additions and 169 deletions

View File

@ -8,8 +8,9 @@ use super::{
use crate::{
ast::{
Annotation, CallArg, DataType, Definition, Function, ModuleConstant, ModuleKind,
RecordConstructor, RecordConstructorArg, Span, TypeAlias, TypedDefinition, TypedPattern,
UnqualifiedImport, UntypedArg, UntypedDefinition, Use, Validator, PIPE_VARIABLE,
RecordConstructor, RecordConstructorArg, Span, TypeAlias, TypedDefinition, TypedFunction,
TypedPattern, UnqualifiedImport, UntypedArg, UntypedDefinition, UntypedFunction, Use,
Validator, PIPE_VARIABLE,
},
builtins::{function, generic_var, pair, tuple, unbound_var},
tipo::{fields::FieldMap, TypeAliasAnnotation},
@ -54,6 +55,12 @@ pub struct Environment<'a> {
/// Values defined in the current module (or the prelude)
pub module_values: HashMap<String, ValueConstructor>,
/// Top-level function definitions from the module
pub module_functions: HashMap<String, &'a UntypedFunction>,
/// Top-level functions that have been inferred
pub inferred_functions: HashMap<String, TypedFunction>,
previous_id: u64,
/// Values defined in the current function (or the prelude)
@ -707,9 +714,11 @@ impl<'a> Environment<'a> {
previous_id: id_gen.next(),
id_gen,
ungeneralised_functions: HashSet::new(),
inferred_functions: HashMap::new(),
module_types: prelude.types.clone(),
module_types_constructors: prelude.types_constructors.clone(),
module_values: HashMap::new(),
module_functions: HashMap::new(),
imported_modules: HashMap::new(),
unused_modules: HashMap::new(),
unqualified_imported_names: HashMap::new(),
@ -1201,6 +1210,8 @@ impl<'a> Environment<'a> {
&fun.location,
)?;
self.module_functions.insert(fun.name.clone(), fun);
if !fun.public {
self.init_usage(fun.name.clone(), EntityKind::PrivateFunction, fun.location);
}

View File

@ -1,5 +1,7 @@
use super::{
environment::{assert_no_labeled_arguments, collapse_links, EntityKind, Environment},
environment::{
assert_no_labeled_arguments, collapse_links, generalise, EntityKind, Environment,
},
error::{Error, Warning},
hydrator::Hydrator,
pattern::PatternTyper,
@ -9,11 +11,12 @@ use super::{
use crate::{
ast::{
self, Annotation, Arg, ArgName, AssignmentKind, AssignmentPattern, BinOp, Bls12_381Point,
ByteArrayFormatPreference, CallArg, ClauseGuard, Constant, Curve, IfBranch,
ByteArrayFormatPreference, CallArg, ClauseGuard, Constant, Curve, Function, IfBranch,
LogicalOpChainKind, Pattern, RecordUpdateSpread, Span, TraceKind, TraceLevel, Tracing,
TypedArg, TypedCallArg, TypedClause, TypedClauseGuard, TypedIfBranch, TypedPattern,
TypedRecordUpdateArg, UnOp, UntypedArg, UntypedAssignmentKind, UntypedClause,
UntypedClauseGuard, UntypedIfBranch, UntypedPattern, UntypedRecordUpdateArg,
UntypedClauseGuard, UntypedFunction, UntypedIfBranch, UntypedPattern,
UntypedRecordUpdateArg,
},
builtins::{
bool, byte_array, function, g1_element, g2_element, int, list, pair, string, tuple, void,
@ -26,12 +29,126 @@ use crate::{
use std::{cmp::Ordering, collections::HashMap, ops::Deref, rc::Rc};
use vec1::Vec1;
/// Infers a top-level function and returns its typed counterpart.
///
/// The result is memoised in `environment.inferred_functions`, keyed by the
/// function name: a second call for the same name returns a clone of the stored
/// typed function without inferring it again.
///
/// Inputs:
/// - `fun`: the untyped function, borrowed; the pieces that end up in the typed
///   function are cloned out of it.
/// - `module_name`: recorded as the `module` of the `ModuleFn` variant that is
///   re-inserted when the function gets generalised.
/// - `hydrators`: per-function hydrators; the entry for `fun.name` is removed
///   from the map and handed to the expression typer.
/// - `environment`: must already hold a preregistered variable for `fun.name`.
/// - `lines`, `tracing`: forwarded as-is to `ExprTyper::new`.
///
/// # Errors
///
/// Propagates errors from inferring the body and from unifying the inferred
/// type with the preregistered one.
///
/// # Panics
///
/// Panics when no preregistered variable exists for the function, when its
/// preregistered type is not a function type, when no hydrator is registered
/// under the function name, or when the final type has no return type.
pub(crate) fn infer_function(
fun: &UntypedFunction,
module_name: &str,
hydrators: &mut HashMap<String, Hydrator>,
environment: &mut Environment<'_>,
lines: &LineNumbers,
tracing: Tracing,
) -> Result<Function<Rc<Type>, TypedExpr, TypedArg>, Error> {
// Already inferred (e.g. reached earlier through a call site): reuse the stored result.
if let Some(typed_fun) = environment.inferred_functions.get(&fun.name) {
return Ok(typed_fun.clone());
};
// `fun` is a reference, so every binding below is a reference into it.
let Function {
doc,
location,
name,
public,
arguments,
body,
return_annotation,
end_position,
can_error,
return_type: _,
} = fun;
let preregistered_fn = environment
.get_variable(name)
.expect("Could not find preregistered type for function");
let field_map = preregistered_fn.field_map().cloned();
let preregistered_type = preregistered_fn.tipo.clone();
let (args_types, return_type) = preregistered_type
.function_types()
.unwrap_or_else(|| panic!("Preregistered type for fn {name} was not a fn"));
// Infer the type using the preregistered args + return types as a starting point
let (tipo, arguments, body, safe_to_generalise) = environment.in_new_scope(|environment| {
// Pair each untyped argument with its preregistered type.
let args = arguments
.iter()
.zip(&args_types)
.map(|(arg_name, tipo)| arg_name.to_owned().set_type(tipo.clone()))
.collect();
// The hydrator is taken out of the map before the body is inferred. The
// call-site check elsewhere in this change only recurses into a function
// whose hydrator is still present, so this also guards against re-entry
// for recursive functions.
let hydrator = hydrators
.remove(name)
.unwrap_or_else(|| panic!("Could not find hydrator for fn {name}"));
let mut expr_typer = ExprTyper::new(environment, hydrators, lines, tracing);
expr_typer.hydrator = hydrator;
let (args, body, return_type) =
expr_typer.infer_fn_with_known_types(args, body.to_owned(), Some(return_type))?;
let args_types = args.iter().map(|a| a.tipo.clone()).collect();
let tipo = function(args_types, return_type);
// Generalising is only considered safe when the body used no ungeneralised function.
let safe_to_generalise = !expr_typer.ungeneralised_function_used;
Ok::<_, Error>((tipo, args, body, safe_to_generalise))
})?;
// Assert that the inferred type matches the type of any recursive call
environment.unify(preregistered_type, tipo.clone(), *location, false)?;
// Generalise the function if safe to do so
let tipo = if safe_to_generalise {
environment.ungeneralised_functions.remove(name);
let tipo = generalise(tipo, 0);
let module_fn = ValueConstructorVariant::ModuleFn {
name: name.clone(),
field_map,
module: module_name.to_owned(),
arity: arguments.len(),
location: *location,
builtin: None,
};
// Replace the preregistered variable with the generalised type.
environment.insert_variable(name.clone(), module_fn, tipo.clone());
tipo
} else {
tipo
};
let inferred_fn = Function {
doc: doc.clone(),
location: *location,
name: name.clone(),
public: *public,
arguments,
return_annotation: return_annotation.clone(),
return_type: tipo
.return_type()
.expect("Could not find return type for fn"),
body,
can_error: *can_error,
end_position: *end_position,
};
// Remember the result so later requests for this function hit the early return above.
environment
.inferred_functions
.insert(name.to_string(), inferred_fn.clone());
Ok(inferred_fn)
}
#[derive(Debug)]
pub(crate) struct ExprTyper<'a, 'b> {
pub(crate) lines: &'a LineNumbers,
pub(crate) environment: &'a mut Environment<'b>,
pub(crate) hydrators: &'a mut HashMap<String, Hydrator>,
// We tweak the tracing behavior during type-check. Traces are either kept or left out of the
// typed AST depending on this setting.
pub(crate) tracing: Tracing,
@ -46,6 +163,22 @@ pub(crate) struct ExprTyper<'a, 'b> {
}
impl<'a, 'b> ExprTyper<'a, 'b> {
/// Builds an expression typer over the given environment.
///
/// `hydrators` is the shared map of per-function hydrators; it is kept on the
/// typer so that inference of another function can be triggered from within
/// expression inference. The typer itself starts with a fresh
/// `Hydrator::new()` and with `ungeneralised_function_used` set to `false`;
/// `infer_function` overwrites `hydrator` right after construction.
pub fn new(
environment: &'a mut Environment<'b>,
hydrators: &'a mut HashMap<String, Hydrator>,
lines: &'a LineNumbers,
tracing: Tracing,
) -> Self {
Self {
hydrator: Hydrator::new(),
environment,
hydrators,
tracing,
ungeneralised_function_used: false,
lines,
}
}
fn check_when_exhaustiveness(
&mut self,
typed_clauses: &[TypedClause],
@ -2184,17 +2317,40 @@ impl<'a, 'b> ExprTyper<'a, 'b> {
variables: self.environment.local_value_names(),
})?;
// Note whether we are using an ungeneralised function so that we can
// tell if it is safe to generalise this function after inference has
// completed.
if matches!(
&constructor.variant,
ValueConstructorVariant::ModuleFn { .. }
) {
if let ValueConstructorVariant::ModuleFn { name: fn_name, .. } =
&constructor.variant
{
// Note whether we are using an ungeneralised function so that we can
// tell if it is safe to generalise this function after inference has
// completed.
let is_ungeneralised = self.environment.ungeneralised_functions.contains(name);
self.ungeneralised_function_used =
self.ungeneralised_function_used || is_ungeneralised;
// In case we use another function, infer it first before going further.
// This ensures we have as much information possible about the function
// when we start inferring expressions using it (i.e. calls).
//
// In a way, this achieves a cheap topological processing of definitions
// where we infer used definitions first. And as a consequence, it solves
// issues where expressions would be wrongly assigned generic variables
// from other definitions.
if let Some(fun) = self.environment.module_functions.remove(fn_name) {
// NOTE: Recursive functions should not run into this multiple time.
// If we have no hydrator for this function, it means that we have already
// encountered it.
if self.hydrators.get(&fun.name).is_some() {
infer_function(
fun,
self.environment.current_module,
self.hydrators,
self.environment,
self.lines,
self.tracing,
)?;
}
}
}
// Register the value as seen for detection of unused values
@ -2323,20 +2479,6 @@ impl<'a, 'b> ExprTyper<'a, 'b> {
self.environment.instantiate(t, ids, &self.hydrator)
}
/// Builds an expression typer over the given environment, starting with a fresh
/// `Hydrator::new()` and with `ungeneralised_function_used` set to `false`.
///
/// NOTE(review): this three-argument form takes no `hydrators` map; it looks
/// like the pre-change constructor shown by the diff -- confirm.
pub fn new(
environment: &'a mut Environment<'b>,
lines: &'a LineNumbers,
tracing: Tracing,
) -> Self {
Self {
hydrator: Hydrator::new(),
environment,
tracing,
ungeneralised_function_used: false,
lines,
}
}
pub fn new_unbound_var(&mut self) -> Rc<Type> {
self.environment.new_unbound_var()
}

View File

@ -1,5 +1,5 @@
use super::{
environment::{generalise, EntityKind, Environment},
environment::{EntityKind, Environment},
error::{Error, UnifyErrorSituation, Warning},
expr::ExprTyper,
hydrator::Hydrator,
@ -8,15 +8,13 @@ use super::{
use crate::{
ast::{
Annotation, Arg, ArgName, ArgVia, DataType, Definition, Function, ModuleConstant,
ModuleKind, RecordConstructor, RecordConstructorArg, Tracing, TypeAlias, TypedArg,
TypedDefinition, TypedFunction, TypedModule, UntypedArg, UntypedDefinition, UntypedModule,
Use, Validator,
ModuleKind, RecordConstructor, RecordConstructorArg, Tracing, TypeAlias, TypedDefinition,
TypedFunction, TypedModule, UntypedDefinition, UntypedModule, Use, Validator,
},
builtins,
builtins::{function, fuzzer, generic_var},
expr::{TypedExpr, UntypedExpr},
builtins::{fuzzer, generic_var},
line_numbers::LineNumbers,
tipo::{Span, Type, TypeVar},
tipo::{expr::infer_function, Span, Type, TypeVar},
IdGenerator,
};
use std::{borrow::Borrow, collections::HashMap, ops::Deref, rc::Rc};
@ -31,9 +29,10 @@ impl UntypedModule {
tracing: Tracing,
warnings: &mut Vec<Warning>,
) -> Result<TypedModule, Error> {
let name = self.name.clone();
let module_name = self.name.clone();
let docs = std::mem::take(&mut self.docs);
let mut environment = Environment::new(id_gen.clone(), &name, &kind, modules, warnings);
let mut environment =
Environment::new(id_gen.clone(), &module_name, &kind, modules, warnings);
let mut type_names = HashMap::with_capacity(self.definitions.len());
let mut value_names = HashMap::with_capacity(self.definitions.len());
@ -50,14 +49,20 @@ impl UntypedModule {
// earlier in the module.
environment.register_types(
self.definitions.iter().collect(),
&name,
&module_name,
&mut hydrators,
&mut type_names,
)?;
// Register values so they can be used in functions earlier in the module.
for def in self.definitions() {
environment.register_values(def, &name, &mut hydrators, &mut value_names, kind)?;
environment.register_values(
def,
&module_name,
&mut hydrators,
&mut value_names,
kind,
)?;
}
// Infer the types of each definition in the module
@ -83,7 +88,7 @@ impl UntypedModule {
for def in consts.into_iter().chain(not_consts) {
let definition = infer_definition(
def,
&name,
&module_name,
&mut hydrators,
&mut environment,
&self.lines,
@ -96,7 +101,7 @@ impl UntypedModule {
// Generalise functions now that the entire module has been inferred
let definitions = definitions
.into_iter()
.map(|def| environment.generalise_definition(def, &name))
.map(|def| environment.generalise_definition(def, &module_name))
.collect();
// Generate warnings for unused items
@ -105,7 +110,7 @@ impl UntypedModule {
// Remove private and imported types and values to create the public interface
environment
.module_types
.retain(|_, info| info.public && info.module == name);
.retain(|_, info| info.public && info.module == module_name);
environment.module_values.retain(|_, info| info.public);
@ -134,12 +139,12 @@ impl UntypedModule {
Ok(TypedModule {
docs,
name: name.clone(),
name: module_name.clone(),
definitions,
kind,
lines: self.lines,
type_info: TypeInfo {
name,
name: module_name,
types,
types_constructors,
values,
@ -162,7 +167,7 @@ fn infer_definition(
) -> Result<TypedDefinition, Error> {
match def {
Definition::Fn(f) => Ok(Definition::Fn(infer_function(
f,
&f,
module_name,
hydrators,
environment,
@ -219,19 +224,8 @@ fn infer_definition(
};
}
let Definition::Fn(mut typed_fun) = infer_definition(
Definition::Fn(fun),
module_name,
hydrators,
environment,
lines,
tracing,
)?
else {
unreachable!(
"validator definition inferred as something other than a function?"
)
};
let mut typed_fun =
infer_function(&fun, module_name, hydrators, environment, lines, tracing)?;
if !typed_fun.return_type.is_bool() {
return Err(Error::ValidatorMustReturnBool {
@ -270,19 +264,14 @@ fn infer_definition(
let params = params.into_iter().chain(other.arguments);
other.arguments = params.collect();
let Definition::Fn(mut other_typed_fun) = infer_definition(
Definition::Fn(other),
let mut other_typed_fun = infer_function(
&other,
module_name,
hydrators,
environment,
lines,
tracing,
)?
else {
unreachable!(
"validator definition inferred as something other than a function?"
)
};
)?;
if !other_typed_fun.return_type.is_bool() {
return Err(Error::ValidatorMustReturnBool {
@ -341,8 +330,8 @@ fn infer_definition(
});
}
let typed_via =
ExprTyper::new(environment, lines, tracing).infer(arg.via.clone())?;
let typed_via = ExprTyper::new(environment, hydrators, lines, tracing)
.infer(arg.via.clone())?;
let hydrator: &mut Hydrator = hydrators.get_mut(&f.name).unwrap();
@ -406,7 +395,7 @@ fn infer_definition(
}?;
let typed_f = infer_function(
f.into(),
&f.into(),
module_name,
hydrators,
environment,
@ -635,8 +624,8 @@ fn infer_definition(
value,
tipo: _,
}) => {
let typed_expr =
ExprTyper::new(environment, lines, tracing).infer_const(&annotation, *value)?;
let typed_expr = ExprTyper::new(environment, hydrators, lines, tracing)
.infer_const(&annotation, *value)?;
let tipo = typed_expr.tipo();
@ -671,106 +660,6 @@ fn infer_definition(
}
}
/// Infers a top-level function, taking it by value, and returns its typed
/// counterpart.
///
/// Unlike the borrowing variant, this version does not consult or update any
/// cache of inferred functions: every call infers the body again.
///
/// NOTE(review): this looks like the pre-change implementation shown by the
/// diff (the module now imports `tipo::expr::infer_function`) -- confirm.
///
/// # Errors
///
/// Propagates errors from inferring the body and from unifying the inferred
/// type with the preregistered one.
///
/// # Panics
///
/// Panics when no preregistered variable exists for the function, when its
/// preregistered type is not a function type, when no hydrator is registered
/// under the function name, or when the final type has no return type.
fn infer_function(
f: Function<(), UntypedExpr, UntypedArg>,
module_name: &String,
hydrators: &mut HashMap<String, Hydrator>,
environment: &mut Environment<'_>,
lines: &LineNumbers,
tracing: Tracing,
) -> Result<Function<Rc<Type>, TypedExpr, TypedArg>, Error> {
// `f` is owned, so the fields are moved out rather than borrowed.
let Function {
doc,
location,
name,
public,
arguments,
body,
return_annotation,
end_position,
can_error,
return_type: _,
} = f;
let preregistered_fn = environment
.get_variable(&name)
.expect("Could not find preregistered type for function");
let field_map = preregistered_fn.field_map().cloned();
let preregistered_type = preregistered_fn.tipo.clone();
let (args_types, return_type) = preregistered_type
.function_types()
.expect("Preregistered type for fn was not a fn");
// Infer the type using the preregistered args + return types as a starting point
let (tipo, arguments, body, safe_to_generalise) = environment.in_new_scope(|environment| {
// Pair each untyped argument with its preregistered type.
let args = arguments
.into_iter()
.zip(&args_types)
.map(|(arg_name, tipo)| arg_name.set_type(tipo.clone()))
.collect();
let mut expr_typer = ExprTyper::new(environment, lines, tracing);
// The function's own hydrator is removed from the map and installed on the typer.
expr_typer.hydrator = hydrators
.remove(&name)
.expect("Could not find hydrator for fn");
let (args, body, return_type) =
expr_typer.infer_fn_with_known_types(args, body, Some(return_type))?;
let args_types = args.iter().map(|a| a.tipo.clone()).collect();
let tipo = function(args_types, return_type);
// Generalising is only considered safe when the body used no ungeneralised function.
let safe_to_generalise = !expr_typer.ungeneralised_function_used;
Ok::<_, Error>((tipo, args, body, safe_to_generalise))
})?;
// Assert that the inferred type matches the type of any recursive call
environment.unify(preregistered_type, tipo.clone(), location, false)?;
// Generalise the function if safe to do so
let tipo = if safe_to_generalise {
environment.ungeneralised_functions.remove(&name);
let tipo = generalise(tipo, 0);
let module_fn = ValueConstructorVariant::ModuleFn {
name: name.clone(),
field_map,
module: module_name.to_owned(),
arity: arguments.len(),
location,
builtin: None,
};
// Replace the preregistered variable with the generalised type.
environment.insert_variable(name.clone(), module_fn, tipo.clone());
tipo
} else {
tipo
};
Ok(Function {
doc,
location,
name,
public,
arguments,
return_annotation,
return_type: tipo
.return_type()
.expect("Could not find return type for fn"),
body,
can_error,
end_position,
})
}
fn infer_fuzzer(
environment: &mut Environment<'_>,
expected_inner_type: Option<Rc<Type>>,