diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 3aeac115ad5..9845166afae 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -1418,22 +1418,13 @@ impl<'context> Elaborator<'context> { trait_impl.resolved_generics = self.generics.clone(); // Fetch trait constraints here - let trait_generics = if let Some(trait_id) = trait_impl.trait_id { - let trait_def = self.interner.get_trait(trait_id); - let resolved_generics = trait_def.generics.clone(); - assert_eq!(resolved_generics.len(), trait_impl.trait_generics.len()); - trait_impl - .trait_generics - .iter() - .enumerate() - .map(|(i, generic)| { - self.resolve_type_inner(generic.clone(), &resolved_generics[i].kind) - }) - .collect() - } else { - // We still resolve as to continue type checking - vecmap(&trait_impl.trait_generics, |generic| self.resolve_type(generic.clone())) - }; + let trait_generics = trait_impl + .trait_id + .and_then(|trait_id| self.resolve_trait_impl_generics(trait_impl, trait_id)) + .unwrap_or_else(|| { + // We still resolve as to continue type checking + vecmap(&trait_impl.trait_generics, |generic| self.resolve_type(generic.clone())) + }); trait_impl.resolved_trait_generics = trait_generics; diff --git a/compiler/noirc_frontend/src/elaborator/traits.rs b/compiler/noirc_frontend/src/elaborator/traits.rs index 8520885bcdc..a00e770218e 100644 --- a/compiler/noirc_frontend/src/elaborator/traits.rs +++ b/compiler/noirc_frontend/src/elaborator/traits.rs @@ -8,7 +8,9 @@ use crate::{ FunctionKind, TraitItem, UnresolvedGeneric, UnresolvedGenerics, UnresolvedTraitConstraint, }, hir::{ - def_collector::dc_crate::{CollectedItems, UnresolvedTrait}, + def_collector::dc_crate::{ + CollectedItems, CompilationError, UnresolvedTrait, UnresolvedTraitImpl, + }, type_check::TypeCheckError, }, hir_def::{ @@ -215,6 +217,31 @@ impl<'context> Elaborator<'context> { // Don't check the scope tree for unused variables, they can't be used in a declaration anyway. self.generics.truncate(old_generic_count); } + + pub fn resolve_trait_impl_generics( + &mut self, + trait_impl: &UnresolvedTraitImpl, + trait_id: TraitId, + ) -> Option> { + let trait_def = self.interner.get_trait(trait_id); + let resolved_generics = trait_def.generics.clone(); + if resolved_generics.len() != trait_impl.trait_generics.len() { + self.push_err(CompilationError::TypeError(TypeCheckError::GenericCountMismatch { + item: trait_def.name.to_string(), + expected: resolved_generics.len(), + found: trait_impl.trait_generics.len(), + span: trait_impl.trait_path.span(), + })); + + return None; + } + + let generics = trait_impl.trait_generics.iter().zip(resolved_generics.iter()); + let mapped = generics.map(|(generic, resolved_generic)| { + self.resolve_type_inner(generic.clone(), &resolved_generic.kind) + }); + Some(mapped.collect()) + } } /// Checks that the type of a function in a trait impl matches the type diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index b4f17489ff7..3767dad103c 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -2459,3 +2459,29 @@ fn no_super() { assert_eq!(span.start(), 4); assert_eq!(span.end(), 9); } + +#[test] +fn trait_impl_generics_count_mismatch() { + let src = r#" + trait Foo {} + + impl Foo<()> for Field {} + + fn main() {}"#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + let CompilationError::TypeError(TypeCheckError::GenericCountMismatch { + item, + expected, + found, + .. + }) = &errors[0].0 + else { + panic!("Expected a generic count mismatch error, got {:?}", errors[0].0); + }; + + assert_eq!(item, "Foo"); + assert_eq!(*expected, 0); + assert_eq!(*found, 1); +}