diff --git a/crates/aiken-lang/src/tipo/environment.rs b/crates/aiken-lang/src/tipo/environment.rs index d02f0cba..972df9ad 100644 --- a/crates/aiken-lang/src/tipo/environment.rs +++ b/crates/aiken-lang/src/tipo/environment.rs @@ -8,8 +8,9 @@ use super::{ use crate::{ ast::{ Annotation, CallArg, DataType, Definition, Function, ModuleConstant, ModuleKind, - RecordConstructor, RecordConstructorArg, Span, TypeAlias, TypedDefinition, TypedPattern, - UnqualifiedImport, UntypedArg, UntypedDefinition, Use, Validator, PIPE_VARIABLE, + RecordConstructor, RecordConstructorArg, Span, TypeAlias, TypedDefinition, TypedFunction, + TypedPattern, UnqualifiedImport, UntypedArg, UntypedDefinition, UntypedFunction, Use, + Validator, PIPE_VARIABLE, }, builtins::{function, generic_var, pair, tuple, unbound_var}, tipo::{fields::FieldMap, TypeAliasAnnotation}, @@ -54,6 +55,12 @@ pub struct Environment<'a> { /// Values defined in the current module (or the prelude) pub module_values: HashMap, + /// Top-level function definitions from the module + pub module_functions: HashMap, + + /// Top-level functions that have been inferred + pub inferred_functions: HashMap, + previous_id: u64, /// Values defined in the current function (or the prelude) @@ -707,9 +714,11 @@ impl<'a> Environment<'a> { previous_id: id_gen.next(), id_gen, ungeneralised_functions: HashSet::new(), + inferred_functions: HashMap::new(), module_types: prelude.types.clone(), module_types_constructors: prelude.types_constructors.clone(), module_values: HashMap::new(), + module_functions: HashMap::new(), imported_modules: HashMap::new(), unused_modules: HashMap::new(), unqualified_imported_names: HashMap::new(), @@ -1201,6 +1210,8 @@ impl<'a> Environment<'a> { &fun.location, )?; + self.module_functions.insert(fun.name.clone(), fun); + if !fun.public { self.init_usage(fun.name.clone(), EntityKind::PrivateFunction, fun.location); } diff --git a/crates/aiken-lang/src/tipo/expr.rs b/crates/aiken-lang/src/tipo/expr.rs index 3579a2ed..c6b4a634 100644 --- a/crates/aiken-lang/src/tipo/expr.rs +++ b/crates/aiken-lang/src/tipo/expr.rs @@ -1,5 +1,7 @@ use super::{ - environment::{assert_no_labeled_arguments, collapse_links, EntityKind, Environment}, + environment::{ + assert_no_labeled_arguments, collapse_links, generalise, EntityKind, Environment, + }, error::{Error, Warning}, hydrator::Hydrator, pattern::PatternTyper, @@ -9,11 +11,12 @@ use super::{ use crate::{ ast::{ self, Annotation, Arg, ArgName, AssignmentKind, AssignmentPattern, BinOp, Bls12_381Point, - ByteArrayFormatPreference, CallArg, ClauseGuard, Constant, Curve, IfBranch, + ByteArrayFormatPreference, CallArg, ClauseGuard, Constant, Curve, Function, IfBranch, LogicalOpChainKind, Pattern, RecordUpdateSpread, Span, TraceKind, TraceLevel, Tracing, TypedArg, TypedCallArg, TypedClause, TypedClauseGuard, TypedIfBranch, TypedPattern, TypedRecordUpdateArg, UnOp, UntypedArg, UntypedAssignmentKind, UntypedClause, - UntypedClauseGuard, UntypedIfBranch, UntypedPattern, UntypedRecordUpdateArg, + UntypedClauseGuard, UntypedFunction, UntypedIfBranch, UntypedPattern, + UntypedRecordUpdateArg, }, builtins::{ bool, byte_array, function, g1_element, g2_element, int, list, pair, string, tuple, void, @@ -26,12 +29,126 @@ use crate::{ use std::{cmp::Ordering, collections::HashMap, ops::Deref, rc::Rc}; use vec1::Vec1; +pub(crate) fn infer_function( + fun: &UntypedFunction, + module_name: &str, + hydrators: &mut HashMap, + environment: &mut Environment<'_>, + lines: &LineNumbers, + tracing: Tracing, +) -> Result, TypedExpr, TypedArg>, Error> { + if let Some(typed_fun) = environment.inferred_functions.get(&fun.name) { + return Ok(typed_fun.clone()); + }; + + let Function { + doc, + location, + name, + public, + arguments, + body, + return_annotation, + end_position, + can_error, + return_type: _, + } = fun; + + let preregistered_fn = environment + .get_variable(name) + .expect("Could not find preregistered type for function"); + + let field_map = preregistered_fn.field_map().cloned(); + + let preregistered_type = preregistered_fn.tipo.clone(); + + let (args_types, return_type) = preregistered_type + .function_types() + .unwrap_or_else(|| panic!("Preregistered type for fn {name} was not a fn")); + + // Infer the type using the preregistered args + return types as a starting point + let (tipo, arguments, body, safe_to_generalise) = environment.in_new_scope(|environment| { + let args = arguments + .iter() + .zip(&args_types) + .map(|(arg_name, tipo)| arg_name.to_owned().set_type(tipo.clone())) + .collect(); + + let hydrator = hydrators + .remove(name) + .unwrap_or_else(|| panic!("Could not find hydrator for fn {name}")); + + let mut expr_typer = ExprTyper::new(environment, hydrators, lines, tracing); + + expr_typer.hydrator = hydrator; + + let (args, body, return_type) = + expr_typer.infer_fn_with_known_types(args, body.to_owned(), Some(return_type))?; + + let args_types = args.iter().map(|a| a.tipo.clone()).collect(); + + let tipo = function(args_types, return_type); + + let safe_to_generalise = !expr_typer.ungeneralised_function_used; + + Ok::<_, Error>((tipo, args, body, safe_to_generalise)) + })?; + + // Assert that the inferred type matches the type of any recursive call + environment.unify(preregistered_type, tipo.clone(), *location, false)?; + + // Generalise the function if safe to do so + let tipo = if safe_to_generalise { + environment.ungeneralised_functions.remove(name); + + let tipo = generalise(tipo, 0); + + let module_fn = ValueConstructorVariant::ModuleFn { + name: name.clone(), + field_map, + module: module_name.to_owned(), + arity: arguments.len(), + location: *location, + builtin: None, + }; + + environment.insert_variable(name.clone(), module_fn, tipo.clone()); + + tipo + } else { + tipo + }; + + let inferred_fn = Function { + doc: doc.clone(), + location: *location, + name: name.clone(), + public: *public, + arguments, + return_annotation: return_annotation.clone(), + return_type: tipo + .return_type() + .expect("Could not find return type for fn"), + body, + can_error: *can_error, + end_position: *end_position, + }; + + environment + .inferred_functions + .insert(name.to_string(), inferred_fn.clone()); + + Ok(inferred_fn) +} + #[derive(Debug)] pub(crate) struct ExprTyper<'a, 'b> { pub(crate) lines: &'a LineNumbers, pub(crate) environment: &'a mut Environment<'b>, + pub(crate) hydrators: &'a mut HashMap, + // We tweak the tracing behavior during type-check. Traces are either kept or left out of the // typed AST depending on this setting. pub(crate) tracing: Tracing, @@ -46,6 +163,22 @@ pub(crate) struct ExprTyper<'a, 'b> { } impl<'a, 'b> ExprTyper<'a, 'b> { + pub fn new( + environment: &'a mut Environment<'b>, + hydrators: &'a mut HashMap, + lines: &'a LineNumbers, + tracing: Tracing, + ) -> Self { + Self { + hydrator: Hydrator::new(), + environment, + hydrators, + tracing, + ungeneralised_function_used: false, + lines, + } + } + fn check_when_exhaustiveness( &mut self, typed_clauses: &[TypedClause], @@ -2184,17 +2317,40 @@ impl<'a, 'b> ExprTyper<'a, 'b> { variables: self.environment.local_value_names(), })?; - // Note whether we are using an ungeneralised function so that we can - // tell if it is safe to generalise this function after inference has - // completed. - if matches!( - &constructor.variant, - ValueConstructorVariant::ModuleFn { .. } - ) { + if let ValueConstructorVariant::ModuleFn { name: fn_name, .. } = + &constructor.variant + { + // Note whether we are using an ungeneralised function so that we can + // tell if it is safe to generalise this function after inference has + // completed. let is_ungeneralised = self.environment.ungeneralised_functions.contains(name); self.ungeneralised_function_used = self.ungeneralised_function_used || is_ungeneralised; + + // In case we use another function, infer it first before going further. + // This ensures we have as much information possible about the function + // when we start inferring expressions using it (i.e. calls). + // + // In a way, this achieves a cheap topological processing of definitions + // where we infer used definitions first. And as a consequence, it solves + // issues where expressions would be wrongly assigned generic variables + // from other definitions. + if let Some(fun) = self.environment.module_functions.remove(fn_name) { + // NOTE: Recursive functions should not run into this multiple time. + // If we have no hydrator for this function, it means that we have already + // encountered it. + if self.hydrators.get(&fun.name).is_some() { + infer_function( + fun, + self.environment.current_module, + self.hydrators, + self.environment, + self.lines, + self.tracing, + )?; + } + } } // Register the value as seen for detection of unused values @@ -2323,20 +2479,6 @@ impl<'a, 'b> ExprTyper<'a, 'b> { self.environment.instantiate(t, ids, &self.hydrator) } - pub fn new( - environment: &'a mut Environment<'b>, - lines: &'a LineNumbers, - tracing: Tracing, - ) -> Self { - Self { - hydrator: Hydrator::new(), - environment, - tracing, - ungeneralised_function_used: false, - lines, - } - } - pub fn new_unbound_var(&mut self) -> Rc { self.environment.new_unbound_var() } diff --git a/crates/aiken-lang/src/tipo/infer.rs b/crates/aiken-lang/src/tipo/infer.rs index f4f6cd85..f43d061a 100644 --- a/crates/aiken-lang/src/tipo/infer.rs +++ b/crates/aiken-lang/src/tipo/infer.rs @@ -1,5 +1,5 @@ use super::{ - environment::{generalise, EntityKind, Environment}, + environment::{EntityKind, Environment}, error::{Error, UnifyErrorSituation, Warning}, expr::ExprTyper, hydrator::Hydrator, @@ -8,15 +8,13 @@ use super::{ use crate::{ ast::{ Annotation, Arg, ArgName, ArgVia, DataType, Definition, Function, ModuleConstant, - ModuleKind, RecordConstructor, RecordConstructorArg, Tracing, TypeAlias, TypedArg, - TypedDefinition, TypedFunction, TypedModule, UntypedArg, UntypedDefinition, UntypedModule, - Use, Validator, + ModuleKind, RecordConstructor, RecordConstructorArg, Tracing, TypeAlias, TypedDefinition, + TypedFunction, TypedModule, UntypedDefinition, UntypedModule, Use, Validator, }, builtins, - builtins::{function, fuzzer, generic_var}, - expr::{TypedExpr, UntypedExpr}, + builtins::{fuzzer, generic_var}, line_numbers::LineNumbers, - tipo::{Span, Type, TypeVar}, + tipo::{expr::infer_function, Span, Type, TypeVar}, IdGenerator, }; use std::{borrow::Borrow, collections::HashMap, ops::Deref, rc::Rc}; @@ -31,9 +29,10 @@ impl UntypedModule { tracing: Tracing, warnings: &mut Vec, ) -> Result { - let name = self.name.clone(); + let module_name = self.name.clone(); let docs = std::mem::take(&mut self.docs); - let mut environment = Environment::new(id_gen.clone(), &name, &kind, modules, warnings); + let mut environment = + Environment::new(id_gen.clone(), &module_name, &kind, modules, warnings); let mut type_names = HashMap::with_capacity(self.definitions.len()); let mut value_names = HashMap::with_capacity(self.definitions.len()); @@ -50,14 +49,20 @@ impl UntypedModule { // earlier in the module. environment.register_types( self.definitions.iter().collect(), - &name, + &module_name, &mut hydrators, &mut type_names, )?; // Register values so they can be used in functions earlier in the module. for def in self.definitions() { - environment.register_values(def, &name, &mut hydrators, &mut value_names, kind)?; + environment.register_values( + def, + &module_name, + &mut hydrators, + &mut value_names, + kind, + )?; } // Infer the types of each definition in the module @@ -83,7 +88,7 @@ impl UntypedModule { for def in consts.into_iter().chain(not_consts) { let definition = infer_definition( def, - &name, + &module_name, &mut hydrators, &mut environment, &self.lines, @@ -96,7 +101,7 @@ impl UntypedModule { // Generalise functions now that the entire module has been inferred let definitions = definitions .into_iter() - .map(|def| environment.generalise_definition(def, &name)) + .map(|def| environment.generalise_definition(def, &module_name)) .collect(); // Generate warnings for unused items @@ -105,7 +110,7 @@ impl UntypedModule { // Remove private and imported types and values to create the public interface environment .module_types - .retain(|_, info| info.public && info.module == name); + .retain(|_, info| info.public && info.module == module_name); environment.module_values.retain(|_, info| info.public); @@ -134,12 +139,12 @@ impl UntypedModule { Ok(TypedModule { docs, - name: name.clone(), + name: module_name.clone(), definitions, kind, lines: self.lines, type_info: TypeInfo { - name, + name: module_name, types, types_constructors, values, @@ -162,7 +167,7 @@ fn infer_definition( ) -> Result { match def { Definition::Fn(f) => Ok(Definition::Fn(infer_function( - f, + &f, module_name, hydrators, environment, @@ -219,19 +224,8 @@ fn infer_definition( }; } - let Definition::Fn(mut typed_fun) = infer_definition( - Definition::Fn(fun), - module_name, - hydrators, - environment, - lines, - tracing, - )? - else { - unreachable!( - "validator definition inferred as something other than a function?" - ) - }; + let mut typed_fun = + infer_function(&fun, module_name, hydrators, environment, lines, tracing)?; if !typed_fun.return_type.is_bool() { return Err(Error::ValidatorMustReturnBool { @@ -270,19 +264,14 @@ fn infer_definition( let params = params.into_iter().chain(other.arguments); other.arguments = params.collect(); - let Definition::Fn(mut other_typed_fun) = infer_definition( - Definition::Fn(other), + let mut other_typed_fun = infer_function( + &other, module_name, hydrators, environment, lines, tracing, - )? - else { - unreachable!( - "validator definition inferred as something other than a function?" - ) - }; + )?; if !other_typed_fun.return_type.is_bool() { return Err(Error::ValidatorMustReturnBool { @@ -341,8 +330,8 @@ fn infer_definition( }); } - let typed_via = - ExprTyper::new(environment, lines, tracing).infer(arg.via.clone())?; + let typed_via = ExprTyper::new(environment, hydrators, lines, tracing) + .infer(arg.via.clone())?; let hydrator: &mut Hydrator = hydrators.get_mut(&f.name).unwrap(); @@ -406,7 +395,7 @@ fn infer_definition( }?; let typed_f = infer_function( - f.into(), + &f.into(), module_name, hydrators, environment, @@ -635,8 +624,8 @@ fn infer_definition( value, tipo: _, }) => { - let typed_expr = - ExprTyper::new(environment, lines, tracing).infer_const(&annotation, *value)?; + let typed_expr = ExprTyper::new(environment, hydrators, lines, tracing) + .infer_const(&annotation, *value)?; let tipo = typed_expr.tipo(); @@ -671,106 +660,6 @@ fn infer_definition( } } -fn infer_function( - f: Function<(), UntypedExpr, UntypedArg>, - module_name: &String, - hydrators: &mut HashMap, - environment: &mut Environment<'_>, - lines: &LineNumbers, - tracing: Tracing, -) -> Result, TypedExpr, TypedArg>, Error> { - let Function { - doc, - location, - name, - public, - arguments, - body, - return_annotation, - end_position, - can_error, - return_type: _, - } = f; - - let preregistered_fn = environment - .get_variable(&name) - .expect("Could not find preregistered type for function"); - - let field_map = preregistered_fn.field_map().cloned(); - - let preregistered_type = preregistered_fn.tipo.clone(); - - let (args_types, return_type) = preregistered_type - .function_types() - .expect("Preregistered type for fn was not a fn"); - - // Infer the type using the preregistered args + return types as a starting point - let (tipo, arguments, body, safe_to_generalise) = environment.in_new_scope(|environment| { - let args = arguments - .into_iter() - .zip(&args_types) - .map(|(arg_name, tipo)| arg_name.set_type(tipo.clone())) - .collect(); - - let mut expr_typer = ExprTyper::new(environment, lines, tracing); - - expr_typer.hydrator = hydrators - .remove(&name) - .expect("Could not find hydrator for fn"); - - let (args, body, return_type) = - expr_typer.infer_fn_with_known_types(args, body, Some(return_type))?; - - let args_types = args.iter().map(|a| a.tipo.clone()).collect(); - - let tipo = function(args_types, return_type); - - let safe_to_generalise = !expr_typer.ungeneralised_function_used; - - Ok::<_, Error>((tipo, args, body, safe_to_generalise)) - })?; - - // Assert that the inferred type matches the type of any recursive call - environment.unify(preregistered_type, tipo.clone(), location, false)?; - - // Generalise the function if safe to do so - let tipo = if safe_to_generalise { - environment.ungeneralised_functions.remove(&name); - - let tipo = generalise(tipo, 0); - - let module_fn = ValueConstructorVariant::ModuleFn { - name: name.clone(), - field_map, - module: module_name.to_owned(), - arity: arguments.len(), - location, - builtin: None, - }; - - environment.insert_variable(name.clone(), module_fn, tipo.clone()); - - tipo - } else { - tipo - }; - - Ok(Function { - doc, - location, - name, - public, - arguments, - return_annotation, - return_type: tipo - .return_type() - .expect("Could not find return type for fn"), - body, - can_error, - end_position, - }) -} - fn infer_fuzzer( environment: &mut Environment<'_>, expected_inner_type: Option>,