Skip to content

Commit

Permalink
fix: allow multiple trait impls for the same trait as long as one is …
Browse files Browse the repository at this point in the history
…in scope (#6987)
  • Loading branch information
asterite authored Jan 8, 2025
1 parent bb8dd5c commit 7328f0b
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 15 deletions.
42 changes: 28 additions & 14 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{borrow::Cow, rc::Rc};

use im::HashSet;
use iter_extended::vecmap;
use noirc_errors::{Location, Span};
use rustc_hash::FxHashMap as HashMap;
Expand Down Expand Up @@ -1422,9 +1423,14 @@ impl<'context> Elaborator<'context> {
let module_id = self.module_id();
let module_data = self.get_module(module_id);

let trait_methods_in_scope: Vec<_> = trait_methods
// Only keep unique trait IDs: multiple trait methods might come from the same trait
// but implemented with different generics (like `Convert<Field>` and `Convert<i32>`).
let traits: HashSet<TraitId> =
trait_methods.into_iter().map(|(_, trait_id)| trait_id).collect();

let traits_in_scope: Vec<_> = traits
.iter()
.filter_map(|(func_id, trait_id)| {
.filter_map(|trait_id| {
let trait_ = self.interner.get_trait(*trait_id);
let trait_name = &trait_.name;
let Some(map) = module_data.scope().types().get(trait_name) else {
Expand All @@ -1434,30 +1440,34 @@ impl<'context> Elaborator<'context> {
return None;
};
if imported_item.0 == ModuleDefId::TraitId(*trait_id) {
Some((*func_id, *trait_id, trait_name))
Some((*trait_id, trait_name))
} else {
None
}
})
.collect();

for (_, _, trait_name) in &trait_methods_in_scope {
for (_, trait_name) in &traits_in_scope {
self.usage_tracker.mark_as_used(module_id, trait_name);
}

if trait_methods_in_scope.is_empty() {
if trait_methods.len() == 1 {
// This is the backwards-compatible case where there's a single trait method but it's not in scope
let (func_id, trait_id) = trait_methods[0];
if traits_in_scope.is_empty() {
if traits.len() == 1 {
// This is the backwards-compatible case where there's a single trait but it's not in scope
let trait_id = *traits.iter().next().unwrap();
let trait_ = self.interner.get_trait(trait_id);
let trait_name = self.fully_qualified_trait_path(trait_);
let generics = trait_.as_constraint(span).trait_bound.trait_generics;
let trait_method_id = trait_.find_method(method_name).unwrap();

self.push_err(PathResolutionError::TraitMethodNotInScope {
ident: Ident::new(method_name.into(), span),
trait_name,
});
return Some(HirMethodReference::FuncId(func_id));

return Some(HirMethodReference::TraitMethodId(trait_method_id, generics, false));
} else {
let traits = vecmap(trait_methods, |(_, trait_id)| {
let traits = vecmap(traits, |trait_id| {
let trait_ = self.interner.get_trait(trait_id);
self.fully_qualified_trait_path(trait_)
});
Expand All @@ -1469,8 +1479,8 @@ impl<'context> Elaborator<'context> {
}
}

if trait_methods_in_scope.len() > 1 {
let traits = vecmap(trait_methods, |(_, trait_id)| {
if traits_in_scope.len() > 1 {
let traits = vecmap(traits, |trait_id| {
let trait_ = self.interner.get_trait(trait_id);
self.fully_qualified_trait_path(trait_)
});
Expand All @@ -1481,8 +1491,12 @@ impl<'context> Elaborator<'context> {
return None;
}

let func_id = trait_methods_in_scope[0].0;
Some(HirMethodReference::FuncId(func_id))
// Return a TraitMethodId with unbound generics. These will later be bound by the type-checker.
let trait_id = traits_in_scope[0].0;
let trait_ = self.interner.get_trait(trait_id);
let generics = trait_.as_constraint(span).trait_bound.trait_generics;
let trait_method_id = trait_.find_method(method_name).unwrap();
Some(HirMethodReference::TraitMethodId(trait_method_id, generics, false))
}

fn lookup_method_in_trait_constraints(
Expand Down
8 changes: 8 additions & 0 deletions compiler/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2313,6 +2313,14 @@ impl NodeInterner {
pub fn doc_comments(&self, id: ReferenceId) -> Option<&Vec<String>> {
self.doc_comments.get(&id)
}

pub fn get_expr_id_from_index(&self, index: impl Into<Index>) -> Option<ExprId> {
let index = index.into();
match self.nodes.get(index) {
Some(Node::Expression(_)) => Some(ExprId(index)),
_ => None,
}
}
}

impl Methods {
Expand Down
32 changes: 32 additions & 0 deletions compiler/noirc_frontend/src/tests/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,38 @@ fn errors_if_multiple_trait_methods_are_in_scope_for_method_call() {
assert_eq!(traits, vec!["private_mod::Foo", "private_mod::Foo2"]);
}

#[test]
fn calls_trait_method_if_it_is_in_scope_with_multiple_candidates_but_only_one_decided_by_generics()
{
let src = r#"
struct Foo {
inner: Field,
}
trait Converter<N> {
fn convert(self) -> N;
}
impl Converter<Field> for Foo {
fn convert(self) -> Field {
self.inner
}
}
impl Converter<u32> for Foo {
fn convert(self) -> u32 {
self.inner as u32
}
}
fn main() {
let foo = Foo { inner: 42 };
let _: u32 = foo.convert();
}
"#;
assert_no_errors(src);
}

#[test]
fn type_checks_trait_default_method_and_errors() {
let src = r#"
Expand Down
34 changes: 33 additions & 1 deletion tooling/lsp/src/requests/hover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ use noirc_frontend::{
hir::def_map::ModuleId,
hir_def::{
expr::{HirArrayLiteral, HirExpression, HirLiteral},
function::FuncMeta,
stmt::HirPattern,
traits::Trait,
},
node_interner::{
DefinitionId, DefinitionKind, ExprId, FuncId, GlobalId, NodeInterner, ReferenceId,
StructId, TraitId, TypeAliasId,
StructId, TraitId, TraitImplKind, TypeAliasId,
},
Generics, Shared, StructType, Type, TypeAlias, TypeBinding, TypeVariable,
};
Expand Down Expand Up @@ -296,6 +297,12 @@ fn get_exprs_global_value(interner: &NodeInterner, exprs: &[ExprId]) -> Option<S

fn format_function(id: FuncId, args: &ProcessRequestCallbackArgs) -> String {
let func_meta = args.interner.function_meta(&id);

// If this points to a trait method, see if we can figure out what's the concrete trait impl method
if let Some(func_id) = get_trait_impl_func_id(id, args, func_meta) {
return format_function(func_id, args);
}

let func_modifiers = args.interner.function_modifiers(&id);

let func_name_definition_id = args.interner.definition(func_meta.name.id);
Expand Down Expand Up @@ -440,6 +447,31 @@ fn format_function(id: FuncId, args: &ProcessRequestCallbackArgs) -> String {
string
}

fn get_trait_impl_func_id(
id: FuncId,
args: &ProcessRequestCallbackArgs,
func_meta: &FuncMeta,
) -> Option<FuncId> {
func_meta.trait_id?;

let index = args.interner.find_location_index(args.location)?;
let expr_id = args.interner.get_expr_id_from_index(index)?;
let Some(TraitImplKind::Normal(trait_impl_id)) =
args.interner.get_selected_impl_for_expression(expr_id)
else {
return None;
};

let trait_impl = args.interner.get_trait_implementation(trait_impl_id);
let trait_impl = trait_impl.borrow();

let function_name = args.interner.function_name(&id);
let mut trait_impl_methods = trait_impl.methods.iter();
let func_id =
trait_impl_methods.find(|func_id| args.interner.function_name(func_id) == function_name)?;
Some(*func_id)
}

fn format_alias(id: TypeAliasId, args: &ProcessRequestCallbackArgs) -> String {
let type_alias = args.interner.get_type_alias(id);
let type_alias = type_alias.borrow();
Expand Down

0 comments on commit 7328f0b

Please sign in to comment.