diff --git a/crates/lang/src/tipo/expr.rs b/crates/lang/src/tipo/expr.rs index 733c567f..767da8ef 100644 --- a/crates/lang/src/tipo/expr.rs +++ b/crates/lang/src/tipo/expr.rs @@ -771,20 +771,36 @@ impl<'a, 'b> ExprTyper<'a, 'b> { location: Span, ) -> Result { let value = self.in_new_scope(|value_typer| value_typer.infer(value))?; - let value_typ = value.tipo(); - - // Ensure the pattern matches the type of the value - let pattern = PatternTyper::new(self.environment, &self.hydrator) - .unify(pattern, value_typ.clone())?; + let mut value_typ = value.tipo(); // Check that any type annotation is accurate. - if let Some(ann) = annotation { + let pattern = if let Some(ann) = annotation { let ann_typ = self .type_from_annotation(ann) .map(|t| self.instantiate(t, &mut HashMap::new()))?; - self.unify(ann_typ, value_typ.clone(), value.type_defining_location())?; - } + self.unify( + ann_typ.clone(), + value_typ.clone(), + value.type_defining_location(), + )?; + + value_typ = ann_typ.clone(); + + // Ensure the pattern matches the type of the value + PatternTyper::new(self.environment, &self.hydrator).unify( + pattern, + value_typ.clone(), + Some(ann_typ), + )? + } else { + // Ensure the pattern matches the type of the value + PatternTyper::new(self.environment, &self.hydrator).unify( + pattern, + value_typ.clone(), + None, + )? + }; // We currently only do limited exhaustiveness checking of custom types // at the top level of patterns. @@ -1688,8 +1704,11 @@ impl<'a, 'b> ExprTyper<'a, 'b> { }; // Ensure the pattern matches the type of the value - let pattern = PatternTyper::new(self.environment, &self.hydrator) - .unify(pattern, value_type.clone())?; + let pattern = PatternTyper::new(self.environment, &self.hydrator).unify( + pattern, + value_type.clone(), + None, + )?; // Check the type of the following code let then = self.infer(then)?; diff --git a/crates/lang/src/tipo/pattern.rs b/crates/lang/src/tipo/pattern.rs index 148b3c2e..a13abbab 100644 --- a/crates/lang/src/tipo/pattern.rs +++ b/crates/lang/src/tipo/pattern.rs @@ -144,7 +144,7 @@ impl<'a, 'b> PatternTyper<'a, 'b> { // Unify each pattern in the multi-pattern with the corresponding subject let mut typed_multi = Vec::with_capacity(multi_pattern.len()); for (pattern, subject_type) in multi_pattern.into_iter().zip(subjects) { - let pattern = self.unify(pattern, subject_type.clone())?; + let pattern = self.unify(pattern, subject_type.clone(), None)?; typed_multi.push(pattern); } Ok(typed_multi) @@ -224,12 +224,13 @@ impl<'a, 'b> PatternTyper<'a, 'b> { &mut self, pattern: UntypedPattern, tipo: Arc, + ann_type: Option>, ) -> Result { match pattern { Pattern::Discard { name, location } => Ok(Pattern::Discard { name, location }), - Pattern::Var { name, location, .. } => { - self.insert_variable(&name, tipo, location, location)?; + Pattern::Var { name, location } => { + self.insert_variable(&name, ann_type.unwrap_or(tipo), location, location)?; Ok(Pattern::Var { name, location }) } @@ -288,9 +289,14 @@ impl<'a, 'b> PatternTyper<'a, 'b> { pattern, location, } => { - self.insert_variable(&name, tipo.clone(), location, pattern.location())?; + self.insert_variable( + &name, + ann_type.clone().unwrap_or_else(|| tipo.clone()), + location, + pattern.location(), + )?; - let pattern = self.unify(*pattern, tipo)?; + let pattern = self.unify(*pattern, tipo, ann_type)?; Ok(Pattern::Assign { name, @@ -324,11 +330,11 @@ impl<'a, 'b> PatternTyper<'a, 'b> { let elements = elements .into_iter() - .map(|element| self.unify(element, tipo.clone())) + .map(|element| self.unify(element, tipo.clone(), None)) .try_collect()?; let tail = match tail { - Some(tail) => Some(Box::new(self.unify(*tail, list(tipo))?)), + Some(tail) => Some(Box::new(self.unify(*tail, list(tipo), None)?)), None => None, }; @@ -397,12 +403,6 @@ impl<'a, 'b> PatternTyper<'a, 'b> { // }) // } // }, - // Pattern::BitString { location, segments } => { - // self.environment - // .unify(type_, bit_string()) - // .map_err(|e| convert_unify_error(e, location))?; - // self.infer_pattern_bit_string(segments, location) - // } Pattern::Constructor { location, module, @@ -506,7 +506,7 @@ impl<'a, 'b> PatternTyper<'a, 'b> { label, } = arg; - let value = self.unify(value, typ.clone())?; + let value = self.unify(value, typ.clone(), None)?; Ok(CallArg { value,