Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow multiple trait impls for the same trait as long as one is in scope #6987

Merged
merged 3 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{borrow::Cow, rc::Rc};

use im::HashSet;
use iter_extended::vecmap;
use noirc_errors::{Location, Span};
use rustc_hash::FxHashMap as HashMap;
Expand Down Expand Up @@ -1422,9 +1423,14 @@ impl<'context> Elaborator<'context> {
let module_id = self.module_id();
let module_data = self.get_module(module_id);

let trait_methods_in_scope: Vec<_> = trait_methods
// Only keep unique trait IDs: multiple trait methods might come from the same trait
// but implemented with different generics (like `Convert<Field>` and `Convert<i32>`).
let traits: HashSet<TraitId> =
trait_methods.into_iter().map(|(_, trait_id)| trait_id).collect();

let traits_in_scope: Vec<_> = traits
.iter()
.filter_map(|(func_id, trait_id)| {
.filter_map(|trait_id| {
let trait_ = self.interner.get_trait(*trait_id);
let trait_name = &trait_.name;
let Some(map) = module_data.scope().types().get(trait_name) else {
Expand All @@ -1434,30 +1440,34 @@ impl<'context> Elaborator<'context> {
return None;
};
if imported_item.0 == ModuleDefId::TraitId(*trait_id) {
Some((*func_id, *trait_id, trait_name))
Some((*trait_id, trait_name))
} else {
None
}
})
.collect();

for (_, _, trait_name) in &trait_methods_in_scope {
for (_, trait_name) in &traits_in_scope {
self.usage_tracker.mark_as_used(module_id, trait_name);
}

if trait_methods_in_scope.is_empty() {
if trait_methods.len() == 1 {
// This is the backwards-compatible case where there's a single trait method but it's not in scope
let (func_id, trait_id) = trait_methods[0];
if traits_in_scope.is_empty() {
if traits.len() == 1 {
// This is the backwards-compatible case where there's a single trait but it's not in scope
let trait_id = *traits.iter().next().unwrap();
let trait_ = self.interner.get_trait(trait_id);
let trait_name = self.fully_qualified_trait_path(trait_);
let generics = trait_.as_constraint(span).trait_bound.trait_generics;
let trait_method_id = trait_.find_method(method_name).unwrap();

self.push_err(PathResolutionError::TraitMethodNotInScope {
ident: Ident::new(method_name.into(), span),
trait_name,
});
return Some(HirMethodReference::FuncId(func_id));

return Some(HirMethodReference::TraitMethodId(trait_method_id, generics, false));
} else {
let traits = vecmap(trait_methods, |(_, trait_id)| {
let traits = vecmap(traits, |trait_id| {
let trait_ = self.interner.get_trait(trait_id);
self.fully_qualified_trait_path(trait_)
});
Expand All @@ -1469,8 +1479,8 @@ impl<'context> Elaborator<'context> {
}
}

if trait_methods_in_scope.len() > 1 {
let traits = vecmap(trait_methods, |(_, trait_id)| {
if traits_in_scope.len() > 1 {
let traits = vecmap(traits, |trait_id| {
let trait_ = self.interner.get_trait(trait_id);
self.fully_qualified_trait_path(trait_)
});
Expand All @@ -1481,8 +1491,12 @@ impl<'context> Elaborator<'context> {
return None;
}

let func_id = trait_methods_in_scope[0].0;
Some(HirMethodReference::FuncId(func_id))
// Return a TraitMethodId with unbound generics. These will later be bound by the type-checker.
let trait_id = traits_in_scope[0].0;
let trait_ = self.interner.get_trait(trait_id);
let generics = trait_.as_constraint(span).trait_bound.trait_generics;
let trait_method_id = trait_.find_method(method_name).unwrap();
Some(HirMethodReference::TraitMethodId(trait_method_id, generics, false))
}

fn lookup_method_in_trait_constraints(
Expand Down
8 changes: 8 additions & 0 deletions compiler/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,12 @@
interned_statement_kinds: noirc_arena::Arena<StatementKind>,

// Interned `UnresolvedTypeData`s during comptime code.
interned_unresolved_type_datas: noirc_arena::Arena<UnresolvedTypeData>,

Check warning on line 220 in compiler/noirc_frontend/src/node_interner.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (datas)

// Interned `Pattern`s during comptime code.
interned_patterns: noirc_arena::Arena<Pattern>,

/// Determins whether to run in LSP mode. In LSP mode references are tracked.

Check warning on line 225 in compiler/noirc_frontend/src/node_interner.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (Determins)
pub(crate) lsp_mode: bool,

/// Store the location of the references in the graph.
Expand Down Expand Up @@ -673,7 +673,7 @@
quoted_types: Default::default(),
interned_expression_kinds: Default::default(),
interned_statement_kinds: Default::default(),
interned_unresolved_type_datas: Default::default(),

Check warning on line 676 in compiler/noirc_frontend/src/node_interner.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (datas)
interned_patterns: Default::default(),
lsp_mode: false,
location_indices: LocationIndices::default(),
Expand Down Expand Up @@ -2199,11 +2199,11 @@
&mut self,
typ: UnresolvedTypeData,
) -> InternedUnresolvedTypeData {
InternedUnresolvedTypeData(self.interned_unresolved_type_datas.insert(typ))

Check warning on line 2202 in compiler/noirc_frontend/src/node_interner.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (datas)
}

pub fn get_unresolved_type_data(&self, id: InternedUnresolvedTypeData) -> &UnresolvedTypeData {
&self.interned_unresolved_type_datas[id.0]

Check warning on line 2206 in compiler/noirc_frontend/src/node_interner.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (datas)
}

/// Returns the type of an operator (which is always a function), along with its return type.
Expand Down Expand Up @@ -2313,6 +2313,14 @@
pub fn doc_comments(&self, id: ReferenceId) -> Option<&Vec<String>> {
self.doc_comments.get(&id)
}

pub fn get_expr_id_from_index(&self, index: impl Into<Index>) -> Option<ExprId> {
let index = index.into();
match self.nodes.get(index) {
Some(Node::Expression(_)) => Some(ExprId(index)),
_ => None,
}
}
}

impl Methods {
Expand Down
32 changes: 32 additions & 0 deletions compiler/noirc_frontend/src/tests/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,38 @@ fn errors_if_multiple_trait_methods_are_in_scope_for_method_call() {
assert_eq!(traits, vec!["private_mod::Foo", "private_mod::Foo2"]);
}

#[test]
fn calls_trait_method_if_it_is_in_scope_with_multiple_candidates_but_only_one_decided_by_generics()
{
let src = r#"
struct Foo {
inner: Field,
}

trait Converter<N> {
fn convert(self) -> N;
}

impl Converter<Field> for Foo {
fn convert(self) -> Field {
self.inner
}
}

impl Converter<u32> for Foo {
fn convert(self) -> u32 {
self.inner as u32
}
}

fn main() {
let foo = Foo { inner: 42 };
let _: u32 = foo.convert();
}
"#;
assert_no_errors(src);
}

#[test]
fn type_checks_trait_default_method_and_errors() {
let src = r#"
Expand Down
34 changes: 33 additions & 1 deletion tooling/lsp/src/requests/hover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
hir::def_map::ModuleId,
hir_def::{
expr::{HirArrayLiteral, HirExpression, HirLiteral},
function::FuncMeta,
stmt::HirPattern,
traits::Trait,
},
node_interner::{
DefinitionId, DefinitionKind, ExprId, FuncId, GlobalId, NodeInterner, ReferenceId,
StructId, TraitId, TypeAliasId,
StructId, TraitId, TraitImplKind, TypeAliasId,
},
Generics, Shared, StructType, Type, TypeAlias, TypeBinding, TypeVariable,
};
Expand Down Expand Up @@ -296,6 +297,12 @@

fn format_function(id: FuncId, args: &ProcessRequestCallbackArgs) -> String {
let func_meta = args.interner.function_meta(&id);

// If this points to a trait method, see if we can figure out what's the concrete trait impl method
if let Some(func_id) = get_trait_impl_func_id(id, args, func_meta) {
return format_function(func_id, args);
}

let func_modifiers = args.interner.function_modifiers(&id);

let func_name_definition_id = args.interner.definition(func_meta.name.id);
Expand Down Expand Up @@ -440,6 +447,31 @@
string
}

fn get_trait_impl_func_id(
id: FuncId,
args: &ProcessRequestCallbackArgs,
func_meta: &FuncMeta,
) -> Option<FuncId> {
func_meta.trait_id?;

let index = args.interner.find_location_index(args.location)?;
let expr_id = args.interner.get_expr_id_from_index(index)?;
let Some(TraitImplKind::Normal(trait_impl_id)) =
args.interner.get_selected_impl_for_expression(expr_id)
else {
return None;
};

let trait_impl = args.interner.get_trait_implementation(trait_impl_id);
let trait_impl = trait_impl.borrow();

let function_name = args.interner.function_name(&id);
let mut trait_impl_methods = trait_impl.methods.iter();
let func_id =
trait_impl_methods.find(|func_id| args.interner.function_name(func_id) == function_name)?;
Some(*func_id)
}

fn format_alias(id: TypeAliasId, args: &ProcessRequestCallbackArgs) -> String {
let type_alias = args.interner.get_type_alias(id);
let type_alias = type_alias.borrow();
Expand Down Expand Up @@ -806,7 +838,7 @@
"two/src/lib.nr",
Position { line: 6, character: 9 },
r#" one
mod subone"#,

Check warning on line 841 in tooling/lsp/src/requests/hover.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (subone)
)
.await;
}
Expand All @@ -817,7 +849,7 @@
"workspace",
"two/src/lib.nr",
Position { line: 9, character: 20 },
r#" one::subone

Check warning on line 852 in tooling/lsp/src/requests/hover.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (subone)
struct SubOneStruct {
some_field: i32,
some_other_field: Field,
Expand All @@ -832,7 +864,7 @@
"workspace",
"two/src/lib.nr",
Position { line: 46, character: 17 },
r#" one::subone

Check warning on line 867 in tooling/lsp/src/requests/hover.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (subone)
struct GenericStruct<A, B> {
}"#,
)
Expand All @@ -845,7 +877,7 @@
"workspace",
"two/src/lib.nr",
Position { line: 9, character: 35 },
r#" one::subone::SubOneStruct

Check warning on line 880 in tooling/lsp/src/requests/hover.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (subone)
some_field: i32"#,
)
.await;
Expand All @@ -857,7 +889,7 @@
"workspace",
"two/src/lib.nr",
Position { line: 12, character: 17 },
r#" one::subone

Check warning on line 892 in tooling/lsp/src/requests/hover.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (subone)
trait SomeTrait"#,
)
.await;
Expand Down
Loading