From 7328f0b2a7411e7c38dae0c2bd5fa2cd04c75461 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Wed, 8 Jan 2025 14:35:46 -0300 Subject: [PATCH] fix: allow multiple trait impls for the same trait as long as one is in scope (#6987) --- .../noirc_frontend/src/elaborator/types.rs | 42 ++++++++++++------- compiler/noirc_frontend/src/node_interner.rs | 8 ++++ compiler/noirc_frontend/src/tests/traits.rs | 32 ++++++++++++++ tooling/lsp/src/requests/hover.rs | 34 ++++++++++++++- 4 files changed, 101 insertions(+), 15 deletions(-) diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index e01cda3a2e4..320b85e70cf 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -1,5 +1,6 @@ use std::{borrow::Cow, rc::Rc}; +use im::HashSet; use iter_extended::vecmap; use noirc_errors::{Location, Span}; use rustc_hash::FxHashMap as HashMap; @@ -1422,9 +1423,14 @@ impl<'context> Elaborator<'context> { let module_id = self.module_id(); let module_data = self.get_module(module_id); - let trait_methods_in_scope: Vec<_> = trait_methods + // Only keep unique trait IDs: multiple trait methods might come from the same trait + // but implemented with different generics (like `Convert` and `Convert`). + let traits: HashSet = + trait_methods.into_iter().map(|(_, trait_id)| trait_id).collect(); + + let traits_in_scope: Vec<_> = traits .iter() - .filter_map(|(func_id, trait_id)| { + .filter_map(|trait_id| { let trait_ = self.interner.get_trait(*trait_id); let trait_name = &trait_.name; let Some(map) = module_data.scope().types().get(trait_name) else { @@ -1434,30 +1440,34 @@ impl<'context> Elaborator<'context> { return None; }; if imported_item.0 == ModuleDefId::TraitId(*trait_id) { - Some((*func_id, *trait_id, trait_name)) + Some((*trait_id, trait_name)) } else { None } }) .collect(); - for (_, _, trait_name) in &trait_methods_in_scope { + for (_, trait_name) in &traits_in_scope { self.usage_tracker.mark_as_used(module_id, trait_name); } - if trait_methods_in_scope.is_empty() { - if trait_methods.len() == 1 { - // This is the backwards-compatible case where there's a single trait method but it's not in scope - let (func_id, trait_id) = trait_methods[0]; + if traits_in_scope.is_empty() { + if traits.len() == 1 { + // This is the backwards-compatible case where there's a single trait but it's not in scope + let trait_id = *traits.iter().next().unwrap(); let trait_ = self.interner.get_trait(trait_id); let trait_name = self.fully_qualified_trait_path(trait_); + let generics = trait_.as_constraint(span).trait_bound.trait_generics; + let trait_method_id = trait_.find_method(method_name).unwrap(); + self.push_err(PathResolutionError::TraitMethodNotInScope { ident: Ident::new(method_name.into(), span), trait_name, }); - return Some(HirMethodReference::FuncId(func_id)); + + return Some(HirMethodReference::TraitMethodId(trait_method_id, generics, false)); } else { - let traits = vecmap(trait_methods, |(_, trait_id)| { + let traits = vecmap(traits, |trait_id| { let trait_ = self.interner.get_trait(trait_id); self.fully_qualified_trait_path(trait_) }); @@ -1469,8 +1479,8 @@ impl<'context> Elaborator<'context> { } } - if trait_methods_in_scope.len() > 1 { - let traits = vecmap(trait_methods, |(_, trait_id)| { + if traits_in_scope.len() > 1 { + let traits = vecmap(traits, |trait_id| { let trait_ = self.interner.get_trait(trait_id); self.fully_qualified_trait_path(trait_) }); @@ -1481,8 +1491,12 @@ impl<'context> Elaborator<'context> { return None; } - let func_id = trait_methods_in_scope[0].0; - Some(HirMethodReference::FuncId(func_id)) + // Return a TraitMethodId with unbound generics. These will later be bound by the type-checker. + let trait_id = traits_in_scope[0].0; + let trait_ = self.interner.get_trait(trait_id); + let generics = trait_.as_constraint(span).trait_bound.trait_generics; + let trait_method_id = trait_.find_method(method_name).unwrap(); + Some(HirMethodReference::TraitMethodId(trait_method_id, generics, false)) } fn lookup_method_in_trait_constraints( diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 25bc507da1e..599558bb91a 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -2313,6 +2313,14 @@ impl NodeInterner { pub fn doc_comments(&self, id: ReferenceId) -> Option<&Vec> { self.doc_comments.get(&id) } + + pub fn get_expr_id_from_index(&self, index: impl Into) -> Option { + let index = index.into(); + match self.nodes.get(index) { + Some(Node::Expression(_)) => Some(ExprId(index)), + _ => None, + } + } } impl Methods { diff --git a/compiler/noirc_frontend/src/tests/traits.rs b/compiler/noirc_frontend/src/tests/traits.rs index 82c8460d004..faacd9aeee3 100644 --- a/compiler/noirc_frontend/src/tests/traits.rs +++ b/compiler/noirc_frontend/src/tests/traits.rs @@ -1007,6 +1007,38 @@ fn errors_if_multiple_trait_methods_are_in_scope_for_method_call() { assert_eq!(traits, vec!["private_mod::Foo", "private_mod::Foo2"]); } +#[test] +fn calls_trait_method_if_it_is_in_scope_with_multiple_candidates_but_only_one_decided_by_generics() +{ + let src = r#" + struct Foo { + inner: Field, + } + + trait Converter { + fn convert(self) -> N; + } + + impl Converter for Foo { + fn convert(self) -> Field { + self.inner + } + } + + impl Converter for Foo { + fn convert(self) -> u32 { + self.inner as u32 + } + } + + fn main() { + let foo = Foo { inner: 42 }; + let _: u32 = foo.convert(); + } + "#; + assert_no_errors(src); +} + #[test] fn type_checks_trait_default_method_and_errors() { let src = r#" diff --git a/tooling/lsp/src/requests/hover.rs b/tooling/lsp/src/requests/hover.rs index 78be09653fc..e2e2d2881dc 100644 --- a/tooling/lsp/src/requests/hover.rs +++ b/tooling/lsp/src/requests/hover.rs @@ -8,12 +8,13 @@ use noirc_frontend::{ hir::def_map::ModuleId, hir_def::{ expr::{HirArrayLiteral, HirExpression, HirLiteral}, + function::FuncMeta, stmt::HirPattern, traits::Trait, }, node_interner::{ DefinitionId, DefinitionKind, ExprId, FuncId, GlobalId, NodeInterner, ReferenceId, - StructId, TraitId, TypeAliasId, + StructId, TraitId, TraitImplKind, TypeAliasId, }, Generics, Shared, StructType, Type, TypeAlias, TypeBinding, TypeVariable, }; @@ -296,6 +297,12 @@ fn get_exprs_global_value(interner: &NodeInterner, exprs: &[ExprId]) -> Option String { let func_meta = args.interner.function_meta(&id); + + // If this points to a trait method, see if we can figure out what's the concrete trait impl method + if let Some(func_id) = get_trait_impl_func_id(id, args, func_meta) { + return format_function(func_id, args); + } + let func_modifiers = args.interner.function_modifiers(&id); let func_name_definition_id = args.interner.definition(func_meta.name.id); @@ -440,6 +447,31 @@ fn format_function(id: FuncId, args: &ProcessRequestCallbackArgs) -> String { string } +fn get_trait_impl_func_id( + id: FuncId, + args: &ProcessRequestCallbackArgs, + func_meta: &FuncMeta, +) -> Option { + func_meta.trait_id?; + + let index = args.interner.find_location_index(args.location)?; + let expr_id = args.interner.get_expr_id_from_index(index)?; + let Some(TraitImplKind::Normal(trait_impl_id)) = + args.interner.get_selected_impl_for_expression(expr_id) + else { + return None; + }; + + let trait_impl = args.interner.get_trait_implementation(trait_impl_id); + let trait_impl = trait_impl.borrow(); + + let function_name = args.interner.function_name(&id); + let mut trait_impl_methods = trait_impl.methods.iter(); + let func_id = + trait_impl_methods.find(|func_id| args.interner.function_name(func_id) == function_name)?; + Some(*func_id) +} + fn format_alias(id: TypeAliasId, args: &ProcessRequestCallbackArgs) -> String { let type_alias = args.interner.get_type_alias(id); let type_alias = type_alias.borrow();