diff --git a/crates/aiken-lang/src/tipo/environment.rs b/crates/aiken-lang/src/tipo/environment.rs index 0efcfc5f..8b991f86 100644 --- a/crates/aiken-lang/src/tipo/environment.rs +++ b/crates/aiken-lang/src/tipo/environment.rs @@ -1205,19 +1205,30 @@ impl<'a> Environment<'a> { /// /// It two types are found to not be the same an error is returned. #[allow(clippy::only_used_in_recursion)] - pub fn unify(&mut self, t1: Arc, t2: Arc, location: Span) -> Result<(), Error> { + pub fn unify( + &mut self, + t1: Arc, + t2: Arc, + location: Span, + allow_cast: bool, + ) -> Result<(), Error> { if t1 == t2 { return Ok(()); } - if (t1.is_data() || t2.is_data()) && !(t1.is_unbound() || t2.is_unbound()) { + if allow_cast + && (t1.is_data() || t2.is_data()) + && !(t1.is_unbound() || t2.is_unbound()) + && !(t1.is_function() || t2.is_function()) + && !(t1.is_generic() || t2.is_generic()) + { return Ok(()); } // Collapse right hand side type links. Left hand side will be collapsed in the next block. if let Type::Var { tipo } = t2.deref() { if let TypeVar::Link { tipo } = tipo.borrow().deref() { - return self.unify(t1, tipo.clone(), location); + return self.unify(t1, tipo.clone(), location, allow_cast); } } @@ -1253,7 +1264,7 @@ impl<'a> Environment<'a> { Ok(()) } - Action::Unify(t) => self.unify(t, t2, location), + Action::Unify(t) => self.unify(t, t2, location, allow_cast), Action::CouldNotUnify => Err(Error::CouldNotUnify { location, @@ -1266,7 +1277,9 @@ impl<'a> Environment<'a> { } if let Type::Var { .. } = t2.deref() { - return self.unify(t2, t1, location).map_err(|e| e.flip_unify()); + return self + .unify(t2, t1, location, allow_cast) + .map_err(|e| e.flip_unify()); } match (t1.deref(), t2.deref()) { @@ -1288,7 +1301,7 @@ impl<'a> Environment<'a> { unify_enclosed_type( t1.clone(), t2.clone(), - self.unify(a.clone(), b.clone(), location), + self.unify(a.clone(), b.clone(), location, allow_cast), )?; } Ok(()) @@ -1301,7 +1314,7 @@ impl<'a> Environment<'a> { unify_enclosed_type( t1.clone(), t2.clone(), - self.unify(a.clone(), b.clone(), location), + self.unify(a.clone(), b.clone(), location, allow_cast), )?; } Ok(()) @@ -1320,17 +1333,16 @@ impl<'a> Environment<'a> { }, ) if args1.len() == args2.len() => { for (a, b) in args1.iter().zip(args2) { - self.unify(a.clone(), b.clone(), location).map_err(|_| { - Error::CouldNotUnify { + self.unify(a.clone(), b.clone(), location, allow_cast) + .map_err(|_| Error::CouldNotUnify { location, expected: t1.clone(), given: t2.clone(), situation: None, rigid_type_names: HashMap::new(), - } - })?; + })?; } - self.unify(retrn1.clone(), retrn2.clone(), location) + self.unify(retrn1.clone(), retrn2.clone(), location, allow_cast) .map_err(|_| Error::CouldNotUnify { location, expected: t1.clone(), diff --git a/crates/aiken-lang/src/tipo/error.rs b/crates/aiken-lang/src/tipo/error.rs index a36c6e9e..2c0fed74 100644 --- a/crates/aiken-lang/src/tipo/error.rs +++ b/crates/aiken-lang/src/tipo/error.rs @@ -951,19 +951,6 @@ fn suggest_unify( expected.green(), given.red() }, - Some(UnifyErrorSituation::UnsafeCast) => formatdoc! { - r#"I am inferring the following type: - - {} - - but I found an expression with a different type: - - {} - - It is unsafe to cast Data without using assert"#, - expected.green(), - given.red() - }, None => formatdoc! { r#"I am inferring the following type: @@ -1200,9 +1187,6 @@ pub enum UnifyErrorSituation { /// The operands of a binary operator were incorrect. Operator(BinOp), - - /// Called a function with something of type Data but something else was expected - UnsafeCast, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/crates/aiken-lang/src/tipo/expr.rs b/crates/aiken-lang/src/tipo/expr.rs index 6d2f6c5d..3787b991 100644 --- a/crates/aiken-lang/src/tipo/expr.rs +++ b/crates/aiken-lang/src/tipo/expr.rs @@ -17,7 +17,7 @@ use crate::{ use super::{ environment::{assert_no_labeled_arguments, collapse_links, EntityKind, Environment}, - error::{Error, UnifyErrorSituation, Warning}, + error::{Error, Warning}, hydrator::Hydrator, pattern::PatternTyper, pipe::PipeTyper, @@ -155,7 +155,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { let mut arguments = Vec::new(); for (i, arg) in args.into_iter().enumerate() { - let arg = self.infer_arg(arg, expected_args.get(i).cloned())?; + let arg = self.infer_param(arg, expected_args.get(i).cloned())?; arguments.push(arg); } @@ -394,7 +394,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { let right = self.infer(right)?; - self.unify(left.tipo(), right.tipo(), right.location())?; + self.unify(left.tipo(), right.tipo(), right.location(), false)?; return Ok(TypedExpr::BinOp { location, @@ -423,13 +423,19 @@ impl<'a, 'b> ExprTyper<'a, 'b> { input_type.clone(), left.tipo(), left.type_defining_location(), + false, ) .map_err(|e| e.operator_situation(name))?; let right = self.infer(right)?; - self.unify(input_type, right.tipo(), right.type_defining_location()) - .map_err(|e| e.operator_situation(name))?; + self.unify( + input_type, + right.tipo(), + right.type_defining_location(), + false, + ) + .map_err(|e| e.operator_situation(name))?; Ok(TypedExpr::BinOp { location, @@ -505,7 +511,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { let return_type = self.instantiate(ret.clone(), &mut HashMap::new()); // Check that the spread variable unifies with the return type of the constructor - self.unify(return_type, spread.tipo(), spread.location())?; + self.unify(return_type, spread.tipo(), spread.location(), false)?; let mut arguments = Vec::new(); @@ -523,7 +529,12 @@ impl<'a, 'b> ExprTyper<'a, 'b> { // field in the record contained within the spread variable. We // need to check the spread, and not the constructor, in order // to handle polymorphic types. - self.unify(spread_field.tipo(), value.tipo(), value.location())?; + self.unify( + spread_field.tipo(), + value.tipo(), + value.location(), + spread_field.tipo().is_data(), + )?; match field_map.fields.get(&label) { None => { @@ -571,7 +582,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { UnOp::Negate => int(), }; - self.unify(tipo.clone(), value.tipo(), value.location())?; + self.unify(tipo.clone(), value.tipo(), value.location(), false)?; Ok(TypedExpr::UnOp { location, @@ -754,7 +765,12 @@ impl<'a, 'b> ExprTyper<'a, 'b> { let tipo = self.instantiate(tipo, &mut type_vars); - self.unify(accessor_record_type, record.tipo(), record.location())?; + self.unify( + accessor_record_type, + record.tipo(), + record.location(), + false, + )?; Ok(TypedExpr::RecordAccess { record, @@ -765,7 +781,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { }) } - fn infer_arg( + fn infer_param( &mut self, arg: UntypedArg, expected: Option>, @@ -788,7 +804,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { // function being type checked, resulting in better type errors and the // record field access syntax working. if let Some(expected) = expected { - self.unify(expected, tipo.clone(), location)?; + self.unify(expected, tipo.clone(), location, false)?; } Ok(Arg { @@ -820,18 +836,9 @@ impl<'a, 'b> ExprTyper<'a, 'b> { ann_typ.clone(), value_typ.clone(), typed_value.type_defining_location(), + (kind.is_let() && ann_typ.is_data()) || (kind.is_assert() && value_typ.is_data()), )?; - if value_typ.is_data() && kind.is_let() && !ann_typ.is_data() { - return Err(Error::CouldNotUnify { - location, - expected: ann_typ, - given: value_typ, - situation: Some(UnifyErrorSituation::UnsafeCast), - rigid_type_names: HashMap::new(), - }); - } - value_typ = ann_typ.clone(); // Ensure the pattern matches the type of the value @@ -951,17 +958,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { (_, value) => self.infer(value), }?; - self.unify(tipo.clone(), value.tipo(), value.location())?; - - if value.tipo().is_data() && !tipo.is_data() { - return Err(Error::CouldNotUnify { - location: value.location(), - expected: tipo, - given: value.tipo(), - situation: Some(UnifyErrorSituation::UnsafeCast), - rigid_type_names: HashMap::new(), - }); - } + self.unify(tipo.clone(), value.tipo(), value.location(), tipo.is_data())?; Ok(value) } @@ -1035,7 +1032,9 @@ impl<'a, 'b> ExprTyper<'a, 'b> { location, value, .. } => { let value = self.infer_clause_guard(*value)?; - self.unify(bool(), value.tipo(), value.location())?; + + self.unify(bool(), value.tipo(), value.location(), false)?; + Ok(ClauseGuard::Not { location, value: Box::new(value), @@ -1050,11 +1049,11 @@ impl<'a, 'b> ExprTyper<'a, 'b> { } => { let left = self.infer_clause_guard(*left)?; - self.unify(bool(), left.tipo(), left.location())?; + self.unify(bool(), left.tipo(), left.location(), false)?; let right = self.infer_clause_guard(*right)?; - self.unify(bool(), right.tipo(), right.location())?; + self.unify(bool(), right.tipo(), right.location(), false)?; Ok(ClauseGuard::And { location, @@ -1071,11 +1070,11 @@ impl<'a, 'b> ExprTyper<'a, 'b> { } => { let left = self.infer_clause_guard(*left)?; - self.unify(bool(), left.tipo(), left.location())?; + self.unify(bool(), left.tipo(), left.location(), false)?; let right = self.infer_clause_guard(*right)?; - self.unify(bool(), right.tipo(), right.location())?; + self.unify(bool(), right.tipo(), right.location(), false)?; Ok(ClauseGuard::Or { location, @@ -1093,7 +1092,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { let left = self.infer_clause_guard(*left)?; let right = self.infer_clause_guard(*right)?; - self.unify(left.tipo(), right.tipo(), location)?; + self.unify(left.tipo(), right.tipo(), location, false)?; Ok(ClauseGuard::Equals { location, @@ -1111,7 +1110,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { let left = self.infer_clause_guard(*left)?; let right = self.infer_clause_guard(*right)?; - self.unify(left.tipo(), right.tipo(), location)?; + self.unify(left.tipo(), right.tipo(), location, false)?; Ok(ClauseGuard::NotEquals { location, @@ -1128,11 +1127,11 @@ impl<'a, 'b> ExprTyper<'a, 'b> { } => { let left = self.infer_clause_guard(*left)?; - self.unify(int(), left.tipo(), left.location())?; + self.unify(int(), left.tipo(), left.location(), false)?; let right = self.infer_clause_guard(*right)?; - self.unify(int(), right.tipo(), right.location())?; + self.unify(int(), right.tipo(), right.location(), false)?; Ok(ClauseGuard::GtInt { location, @@ -1149,11 +1148,11 @@ impl<'a, 'b> ExprTyper<'a, 'b> { } => { let left = self.infer_clause_guard(*left)?; - self.unify(int(), left.tipo(), left.location())?; + self.unify(int(), left.tipo(), left.location(), false)?; let right = self.infer_clause_guard(*right)?; - self.unify(int(), right.tipo(), right.location())?; + self.unify(int(), right.tipo(), right.location(), false)?; Ok(ClauseGuard::GtEqInt { location, @@ -1170,11 +1169,11 @@ impl<'a, 'b> ExprTyper<'a, 'b> { } => { let left = self.infer_clause_guard(*left)?; - self.unify(int(), left.tipo(), left.location())?; + self.unify(int(), left.tipo(), left.location(), false)?; let right = self.infer_clause_guard(*right)?; - self.unify(int(), right.tipo(), right.location())?; + self.unify(int(), right.tipo(), right.location(), false)?; Ok(ClauseGuard::LtInt { location, @@ -1191,11 +1190,11 @@ impl<'a, 'b> ExprTyper<'a, 'b> { } => { let left = self.infer_clause_guard(*left)?; - self.unify(int(), left.tipo(), left.location())?; + self.unify(int(), left.tipo(), left.location(), false)?; let right = self.infer_clause_guard(*right)?; - self.unify(int(), right.tipo(), right.location())?; + self.unify(int(), right.tipo(), right.location(), false)?; Ok(ClauseGuard::LtEqInt { location, @@ -1427,7 +1426,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { let value = self.infer_const(&None, value)?; - self.unify(tipo.clone(), value.tipo(), value.location())?; + self.unify(tipo.clone(), value.tipo(), value.location(), tipo.is_data())?; typed_args.push(CallArg { label, @@ -1482,7 +1481,12 @@ impl<'a, 'b> ExprTyper<'a, 'b> { if let Some(ann) = annotation { let const_ann = self.type_from_annotation(ann)?; - self.unify(const_ann, inferred.tipo(), inferred.location())?; + self.unify( + const_ann.clone(), + inferred.tipo(), + inferred.location(), + const_ann.is_data(), + )?; }; Ok(inferred) @@ -1500,7 +1504,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { for element in untyped_elements { let element = self.infer_const(&None, element)?; - self.unify(tipo.clone(), element.tipo(), element.location())?; + self.unify(tipo.clone(), element.tipo(), element.location(), false)?; elements.push(element); } @@ -1522,7 +1526,12 @@ impl<'a, 'b> ExprTyper<'a, 'b> { let condition = self.infer(first.condition.clone())?; - self.unify(bool(), condition.tipo(), condition.type_defining_location())?; + self.unify( + bool(), + condition.tipo(), + condition.type_defining_location(), + false, + )?; let body = self.infer(first.body.clone())?; @@ -1537,11 +1546,21 @@ impl<'a, 'b> ExprTyper<'a, 'b> { for branch in &branches[1..] { let condition = self.infer(branch.condition.clone())?; - self.unify(bool(), condition.tipo(), condition.type_defining_location())?; + self.unify( + bool(), + condition.tipo(), + condition.type_defining_location(), + false, + )?; let body = self.infer(branch.body.clone())?; - self.unify(tipo.clone(), body.tipo(), body.type_defining_location())?; + self.unify( + tipo.clone(), + body.tipo(), + body.type_defining_location(), + false, + )?; typed_branches.push(TypedIfBranch { body, @@ -1556,6 +1575,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { tipo.clone(), typed_final_else.tipo(), typed_final_else.type_defining_location(), + false, )?; Ok(TypedExpr::If { @@ -1628,11 +1648,16 @@ impl<'a, 'b> ExprTyper<'a, 'b> { // Check that any return type is accurate. if let Some(return_type) = return_type { - self.unify(return_type, body.tipo(), body.type_defining_location()) - .map_err(|e| { - e.return_annotation_mismatch() - .with_unify_error_rigid_names(&body_rigid_names) - })?; + self.unify( + return_type.clone(), + body.tipo(), + body.type_defining_location(), + return_type.is_data(), + ) + .map_err(|e| { + e.return_annotation_mismatch() + .with_unify_error_rigid_names(&body_rigid_names) + })?; } Ok((args, body)) @@ -1660,7 +1685,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { let element = self.infer(elem)?; // Ensure they all have the same type - self.unify(tipo.clone(), element.tipo(), location)?; + self.unify(tipo.clone(), element.tipo(), location, false)?; elems.push(element) } @@ -1673,7 +1698,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { let tail = self.infer(*tail)?; // Ensure the tail has the same type as the preceeding elements - self.unify(tipo.clone(), tail.tipo(), location)?; + self.unify(tipo.clone(), tail.tipo(), location, false)?; Some(Box::new(tail)) } @@ -1700,7 +1725,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { Some(guard) => { let guard = self.infer_clause_guard(guard)?; - self.unify(bool(), guard.tipo(), guard.location())?; + self.unify(bool(), guard.tipo(), guard.location(), false)?; Ok(Some(guard)) } @@ -1986,6 +2011,7 @@ impl<'a, 'b> ExprTyper<'a, 'b> { return_type.clone(), typed_clause.then.tipo(), typed_clause.location(), + false, ) .map_err(|e| e.case_clause_mismatch())?; @@ -2034,7 +2060,13 @@ impl<'a, 'b> ExprTyper<'a, 'b> { .type_from_annotation(annotation, self.environment) } - fn unify(&mut self, t1: Arc, t2: Arc, location: Span) -> Result<(), Error> { - self.environment.unify(t1, t2, location) + fn unify( + &mut self, + t1: Arc, + t2: Arc, + location: Span, + allow_cast: bool, + ) -> Result<(), Error> { + self.environment.unify(t1, t2, location, allow_cast) } } diff --git a/crates/aiken-lang/src/tipo/hydrator.rs b/crates/aiken-lang/src/tipo/hydrator.rs index 6f066034..9a4088e8 100644 --- a/crates/aiken-lang/src/tipo/hydrator.rs +++ b/crates/aiken-lang/src/tipo/hydrator.rs @@ -160,7 +160,7 @@ impl Hydrator { for (parameter, (location, argument)) in parameter_types.into_iter().zip(argument_types) { - environment.unify(parameter, argument, location)?; + environment.unify(parameter, argument, location, false)?; } Ok(return_type) diff --git a/crates/aiken-lang/src/tipo/infer.rs b/crates/aiken-lang/src/tipo/infer.rs index fb2c50fe..54037cfd 100644 --- a/crates/aiken-lang/src/tipo/infer.rs +++ b/crates/aiken-lang/src/tipo/infer.rs @@ -207,7 +207,7 @@ fn infer_definition( })?; // Assert that the inferred type matches the type of any recursive call - environment.unify(preregistered_type, tipo.clone(), location)?; + environment.unify(preregistered_type, tipo.clone(), location, false)?; // Generalise the function if safe to do so let tipo = if safe_to_generalise { @@ -250,7 +250,7 @@ fn infer_definition( if let Definition::Fn(f) = infer_definition(Definition::Fn(f), module_name, hydrators, environment, kind)? { - environment.unify(f.return_type.clone(), builtins::bool(), f.location)?; + environment.unify(f.return_type.clone(), builtins::bool(), f.location, false)?; Ok(Definition::Test(f)) } else { unreachable!("test defintion inferred as something else than a function?") diff --git a/crates/aiken-lang/src/tipo/pattern.rs b/crates/aiken-lang/src/tipo/pattern.rs index e7e83dd0..f77bf88d 100644 --- a/crates/aiken-lang/src/tipo/pattern.rs +++ b/crates/aiken-lang/src/tipo/pattern.rs @@ -82,7 +82,8 @@ impl<'a, 'b> PatternTyper<'a, 'b> { Some(initial) if self.initial_pattern_vars.contains(name) => { assigned.push(name.to_string()); let initial_typ = initial.tipo.clone(); - self.environment.unify(initial_typ, typ, err_location) + self.environment + .unify(initial_typ, typ, err_location, false) } // This variable was not defined in the Initial multi-pattern @@ -280,13 +281,13 @@ impl<'a, 'b> PatternTyper<'a, 'b> { } Pattern::Int { location, value } => { - self.environment.unify(tipo, int(), location)?; + self.environment.unify(tipo, int(), location, false)?; Ok(Pattern::Int { location, value }) } Pattern::String { location, value } => { - self.environment.unify(tipo, string(), location)?; + self.environment.unify(tipo, string(), location, false)?; Ok(Pattern::String { location, value }) } @@ -358,7 +359,7 @@ impl<'a, 'b> PatternTyper<'a, 'b> { .collect(); self.environment - .unify(tuple(elems_types.clone()), tipo, location)?; + .unify(tuple(elems_types.clone()), tipo, location, false)?; let mut patterns = vec![]; @@ -513,7 +514,7 @@ impl<'a, 'b> PatternTyper<'a, 'b> { }) .try_collect()?; - self.environment.unify(tipo, ret.clone(), location)?; + self.environment.unify(tipo, ret.clone(), location, false)?; Ok(Pattern::Constructor { location, @@ -543,6 +544,7 @@ impl<'a, 'b> PatternTyper<'a, 'b> { tipo, instantiated_constructor_type.clone(), location, + false, )?; Ok(Pattern::Constructor { diff --git a/crates/aiken-lang/src/tipo/pipe.rs b/crates/aiken-lang/src/tipo/pipe.rs index 95715a7a..d0b0e481 100644 --- a/crates/aiken-lang/src/tipo/pipe.rs +++ b/crates/aiken-lang/src/tipo/pipe.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{ops::Deref, sync::Arc}; use vec1::Vec1; @@ -268,6 +268,15 @@ impl<'a, 'b, 'c> PipeTyper<'a, 'b, 'c> { func.tipo(), function(vec![self.argument_type.clone()], return_type.clone()), func.location(), + if let Type::Fn { args, .. } = func.tipo().deref() { + if let Some(typ) = args.get(0) { + typ.is_data() + } else { + false + } + } else { + false + }, ) .map_err(|e| { let is_pipe_mismatch = self.check_if_pipe_type_mismatch(&e, func.location()); @@ -301,7 +310,7 @@ impl<'a, 'b, 'c> PipeTyper<'a, 'b, 'c> { (Some(a), Some(b)) => self .expr_typer .environment - .unify(a.clone(), b.clone(), location) + .unify(a.clone(), b.clone(), location, a.is_data()) .is_err(), _ => false, }