From 32a9ed9ad19cf81275c31ca77e4970bc1598c112 Mon Sep 17 00:00:00 2001 From: jfecher Date: Tue, 19 Nov 2024 11:15:32 -0600 Subject: [PATCH] fix: Do a shallow follow_bindings before unification (#6558) --- compiler/noirc_frontend/src/hir_def/types.rs | 22 +++++++++++++++++-- .../src/monomorphization/mod.rs | 6 +++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index ef8a697966c..659fafbbcbb 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -1547,12 +1547,15 @@ impl Type { ) -> Result<(), UnificationError> { use Type::*; - let lhs = match self { + let lhs = self.follow_bindings_shallow(); + let rhs = other.follow_bindings_shallow(); + + let lhs = match lhs.as_ref() { Type::InfixExpr(..) => Cow::Owned(self.canonicalize()), other => Cow::Borrowed(other), }; - let rhs = match other { + let rhs = match rhs.as_ref() { Type::InfixExpr(..) => Cow::Owned(other.canonicalize()), other => Cow::Borrowed(other), }; @@ -2386,6 +2389,21 @@ impl Type { } } + /// Follow bindings if this is a type variable or generic to the first non-typevariable + /// type. Unlike `follow_bindings`, this won't recursively follow any bindings on any + /// fields or arguments of this type. + pub fn follow_bindings_shallow(&self) -> Cow { + match self { + Type::TypeVariable(var) | Type::NamedGeneric(var, _) => { + if let TypeBinding::Bound(typ) = &*var.borrow() { + return Cow::Owned(typ.follow_bindings_shallow().into_owned()); + } + Cow::Borrowed(self) + } + other => Cow::Borrowed(other), + } + } + pub fn from_generics(generics: &GenericTypeVars) -> Vec { vecmap(generics, |var| Type::TypeVariable(var.clone())) } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index ce2c58e71c1..dd72437ccd7 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -959,7 +959,8 @@ impl<'interner> Monomorphizer<'interner> { /// Convert a non-tuple/struct type to a monomorphized type fn convert_type(typ: &HirType, location: Location) -> Result { - Ok(match typ { + let typ = typ.follow_bindings_shallow(); + Ok(match typ.as_ref() { HirType::FieldElement => ast::Type::Field, HirType::Integer(sign, bits) => ast::Type::Integer(*sign, *bits), HirType::Bool => ast::Type::Bool, @@ -1125,7 +1126,8 @@ impl<'interner> Monomorphizer<'interner> { // Similar to `convert_type` but returns an error if any type variable can't be defaulted. fn check_type(typ: &HirType, location: Location) -> Result<(), MonomorphizationError> { - match typ { + let typ = typ.follow_bindings_shallow(); + match typ.as_ref() { HirType::FieldElement | HirType::Integer(..) | HirType::Bool