diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs index a0fea3aa774..462c21f1e1d 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs @@ -1655,12 +1655,15 @@ impl Type { ) -> Result<(), UnificationError> { use Type::*; - let lhs = match self { + let lhs = self.follow_bindings_shallow(); + let rhs = other.follow_bindings_shallow(); + + let lhs = match lhs.as_ref() { Type::InfixExpr(..) => Cow::Owned(self.canonicalize()), other => Cow::Borrowed(other), }; - let rhs = match other { + let rhs = match rhs.as_ref() { Type::InfixExpr(..) => Cow::Owned(other.canonicalize()), other => Cow::Borrowed(other), }; @@ -2494,6 +2497,21 @@ impl Type { } } + /// Follow bindings if this is a type variable or generic to the first non-typevariable + /// type. Unlike `follow_bindings`, this won't recursively follow any bindings on any + /// fields or arguments of this type. + pub fn follow_bindings_shallow(&self) -> Cow { + match self { + Type::TypeVariable(var) | Type::NamedGeneric(var, _) => { + if let TypeBinding::Bound(typ) = &*var.borrow() { + return Cow::Owned(typ.follow_bindings_shallow().into_owned()); + } + Cow::Borrowed(self) + } + other => Cow::Borrowed(other), + } + } + pub fn from_generics(generics: &GenericTypeVars) -> Vec { vecmap(generics, |var| Type::TypeVariable(var.clone())) } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/mod.rs index 0ec26a5ca83..050f844146a 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -953,7 +953,8 @@ impl<'interner> Monomorphizer<'interner> { /// Convert a non-tuple/struct type to a monomorphized type fn convert_type(typ: &HirType, location: Location) -> Result { - Ok(match typ { + let typ = typ.follow_bindings_shallow(); + Ok(match typ.as_ref() { HirType::FieldElement => ast::Type::Field, HirType::Integer(sign, bits) => ast::Type::Integer(*sign, *bits), HirType::Bool => ast::Type::Bool, @@ -1119,7 +1120,8 @@ impl<'interner> Monomorphizer<'interner> { // Similar to `convert_type` but returns an error if any type variable can't be defaulted. fn check_type(typ: &HirType, location: Location) -> Result<(), MonomorphizationError> { - match typ { + let typ = typ.follow_bindings_shallow(); + match typ.as_ref() { HirType::FieldElement | HirType::Integer(..) | HirType::Bool