Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow multiple trait impls for the same trait as long as one is in scope #6987

Merged
merged 3 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{borrow::Cow, rc::Rc};

use im::HashSet;
use iter_extended::vecmap;
use noirc_errors::{Location, Span};
use rustc_hash::FxHashMap as HashMap;
Expand Down Expand Up @@ -1422,9 +1423,14 @@ impl<'context> Elaborator<'context> {
let module_id = self.module_id();
let module_data = self.get_module(module_id);

let trait_methods_in_scope: Vec<_> = trait_methods
// Only keep unique trait IDs: multiple trait methods might come from the same trait
// but implemented with different generics (like `Convert<Field>` and `Convert<i32>`).
let traits: HashSet<TraitId> =
trait_methods.into_iter().map(|(_, trait_id)| trait_id).collect();

let traits_in_scope: Vec<_> = traits
.iter()
.filter_map(|(func_id, trait_id)| {
.filter_map(|trait_id| {
let trait_ = self.interner.get_trait(*trait_id);
let trait_name = &trait_.name;
let Some(map) = module_data.scope().types().get(trait_name) else {
Expand All @@ -1434,30 +1440,34 @@ impl<'context> Elaborator<'context> {
return None;
};
if imported_item.0 == ModuleDefId::TraitId(*trait_id) {
Some((*func_id, *trait_id, trait_name))
Some((*trait_id, trait_name))
} else {
None
}
})
.collect();

for (_, _, trait_name) in &trait_methods_in_scope {
for (_, trait_name) in &traits_in_scope {
self.usage_tracker.mark_as_used(module_id, trait_name);
}

if trait_methods_in_scope.is_empty() {
if trait_methods.len() == 1 {
// This is the backwards-compatible case where there's a single trait method but it's not in scope
let (func_id, trait_id) = trait_methods[0];
if traits_in_scope.is_empty() {
if traits.len() == 1 {
// This is the backwards-compatible case where there's a single trait but it's not in scope
let trait_id = *traits.iter().next().unwrap();
let trait_ = self.interner.get_trait(trait_id);
let trait_name = self.fully_qualified_trait_path(trait_);
let generics = trait_.as_constraint(span).trait_bound.trait_generics;
let trait_method_id = trait_.find_method(method_name).unwrap();

self.push_err(PathResolutionError::TraitMethodNotInScope {
ident: Ident::new(method_name.into(), span),
trait_name,
});
return Some(HirMethodReference::FuncId(func_id));

return Some(HirMethodReference::TraitMethodId(trait_method_id, generics, false));
} else {
let traits = vecmap(trait_methods, |(_, trait_id)| {
let traits = vecmap(traits, |trait_id| {
let trait_ = self.interner.get_trait(trait_id);
self.fully_qualified_trait_path(trait_)
});
Expand All @@ -1469,8 +1479,8 @@ impl<'context> Elaborator<'context> {
}
}

if trait_methods_in_scope.len() > 1 {
let traits = vecmap(trait_methods, |(_, trait_id)| {
if traits_in_scope.len() > 1 {
let traits = vecmap(traits, |trait_id| {
let trait_ = self.interner.get_trait(trait_id);
self.fully_qualified_trait_path(trait_)
});
Expand All @@ -1481,8 +1491,12 @@ impl<'context> Elaborator<'context> {
return None;
}

let func_id = trait_methods_in_scope[0].0;
Some(HirMethodReference::FuncId(func_id))
// Return a TraitMethodId with unbound generics. These will later be bound by the type-checker.
let trait_id = traits_in_scope[0].0;
let trait_ = self.interner.get_trait(trait_id);
let generics = trait_.as_constraint(span).trait_bound.trait_generics;
let trait_method_id = trait_.find_method(method_name).unwrap();
Some(HirMethodReference::TraitMethodId(trait_method_id, generics, false))
}

fn lookup_method_in_trait_constraints(
Expand Down
32 changes: 32 additions & 0 deletions compiler/noirc_frontend/src/tests/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,38 @@ fn errors_if_multiple_trait_methods_are_in_scope_for_method_call() {
assert_eq!(traits, vec!["private_mod::Foo", "private_mod::Foo2"]);
}

#[test]
fn calls_trait_method_if_it_is_in_scope_with_multiple_candidates_but_only_one_decided_by_generics()
{
let src = r#"
struct Foo {
inner: Field,
}

trait Converter<N> {
fn convert(self) -> N;
}

impl Converter<Field> for Foo {
fn convert(self) -> Field {
self.inner
}
}

impl Converter<u32> for Foo {
fn convert(self) -> u32 {
self.inner as u32
}
}

fn main() {
let foo = Foo { inner: 42 };
let _: u32 = foo.convert();
}
"#;
assert_no_errors(src);
}

#[test]
fn type_checks_trait_default_method_and_errors() {
let src = r#"
Expand Down
Loading