From 35408ab303f1018c1e2c38e6ea55430a2c89dc4c Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 14 Nov 2024 08:27:20 -0300 Subject: [PATCH] fix: disallow `#[test]` on associated functions (#6449) Co-authored-by: Michael J Klein --- .../noirc_frontend/src/elaborator/comptime.rs | 9 +++- .../src/hir/def_collector/dc_mod.rs | 38 +++++++++++++-- .../src/hir/def_collector/errors.rs | 8 ++++ compiler/noirc_frontend/src/tests.rs | 48 +++++++++++++++++++ 4 files changed, 98 insertions(+), 5 deletions(-) diff --git a/compiler/noirc_frontend/src/elaborator/comptime.rs b/compiler/noirc_frontend/src/elaborator/comptime.rs index 279adc331ea..a27e2bf0163 100644 --- a/compiler/noirc_frontend/src/elaborator/comptime.rs +++ b/compiler/noirc_frontend/src/elaborator/comptime.rs @@ -452,7 +452,14 @@ impl<'context> Elaborator<'context> { } ItemKind::Impl(r#impl) => { let module = self.module_id(); - dc_mod::collect_impl(self.interner, generated_items, r#impl, self.file, module); + dc_mod::collect_impl( + self.interner, + generated_items, + r#impl, + self.file, + module, + &mut self.errors, + ); } ItemKind::ModuleDecl(_) diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index a373441b4e0..bae57daae15 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -97,9 +97,9 @@ pub fn collect_defs( errors.extend(collector.collect_functions(context, ast.functions, crate_id)); - collector.collect_trait_impls(context, ast.trait_impls, crate_id); + errors.extend(collector.collect_trait_impls(context, ast.trait_impls, crate_id)); - collector.collect_impls(context, ast.impls, crate_id); + errors.extend(collector.collect_impls(context, ast.impls, crate_id)); collector.collect_attributes( ast.inner_attributes, @@ -163,7 +163,13 @@ impl<'a> ModCollector<'a> { errors } - fn collect_impls(&mut self, context: &mut Context, impls: Vec, krate: CrateId) { + fn collect_impls( + &mut self, + context: &mut Context, + impls: Vec, + krate: CrateId, + ) -> Vec<(CompilationError, FileId)> { + let mut errors = Vec::new(); let module_id = ModuleId { krate, local_id: self.module_id }; for r#impl in impls { @@ -173,8 +179,11 @@ impl<'a> ModCollector<'a> { r#impl, self.file_id, module_id, + &mut errors, ); } + + errors } fn collect_trait_impls( @@ -182,7 +191,9 @@ impl<'a> ModCollector<'a> { context: &mut Context, impls: Vec, krate: CrateId, - ) { + ) -> Vec<(CompilationError, FileId)> { + let mut errors = Vec::new(); + for mut trait_impl in impls { let trait_name = trait_impl.trait_name.clone(); @@ -198,6 +209,13 @@ impl<'a> ModCollector<'a> { let module = ModuleId { krate, local_id: self.module_id }; for (_, func_id, noir_function) in &mut unresolved_functions.functions { + if noir_function.def.attributes.is_test_function() { + let error = DefCollectorErrorKind::TestOnAssociatedFunction { + span: noir_function.name_ident().span(), + }; + errors.push((error.into(), self.file_id)); + } + let location = Location::new(noir_function.def.span, self.file_id); context.def_interner.push_function(*func_id, &noir_function.def, module, location); } @@ -224,6 +242,8 @@ impl<'a> ModCollector<'a> { self.def_collector.items.trait_impls.push(unresolved_trait_impl); } + + errors } fn collect_functions( @@ -1051,6 +1071,7 @@ pub fn collect_impl( r#impl: TypeImpl, file_id: FileId, module_id: ModuleId, + errors: &mut Vec<(CompilationError, FileId)>, ) { let mut unresolved_functions = UnresolvedFunctions { file_id, functions: Vec::new(), trait_id: None, self_type: None }; @@ -1058,6 +1079,15 @@ pub fn collect_impl( for (method, _) in r#impl.methods { let doc_comments = method.doc_comments; let mut method = method.item; + + if method.def.attributes.is_test_function() { + let error = DefCollectorErrorKind::TestOnAssociatedFunction { + span: method.name_ident().span(), + }; + errors.push((error.into(), file_id)); + continue; + } + let func_id = interner.push_empty_fn(); method.def.where_clause.extend(r#impl.where_clause.clone()); let location = Location::new(method.span(), file_id); diff --git a/compiler/noirc_frontend/src/hir/def_collector/errors.rs b/compiler/noirc_frontend/src/hir/def_collector/errors.rs index d72f493092d..c08b4ff2062 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/errors.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/errors.rs @@ -82,6 +82,8 @@ pub enum DefCollectorErrorKind { }, #[error("{0}")] UnsupportedNumericGenericType(#[from] UnsupportedNumericGenericType), + #[error("The `#[test]` attribute may only be used on a non-associated function")] + TestOnAssociatedFunction { span: Span }, } impl DefCollectorErrorKind { @@ -291,6 +293,12 @@ impl<'a> From<&'a DefCollectorErrorKind> for Diagnostic { diag } DefCollectorErrorKind::UnsupportedNumericGenericType(err) => err.into(), + DefCollectorErrorKind::TestOnAssociatedFunction { span } => Diagnostic::simple_error( + "The `#[test]` attribute is disallowed on `impl` methods".into(), + String::new(), + *span, + ), + } } } diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index b709436cccc..23f134cc2a3 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -3692,3 +3692,51 @@ fn allows_struct_with_generic_infix_type_as_main_input_3() { "#; assert_no_errors(src); } + +#[test] +fn disallows_test_attribute_on_impl_method() { + let src = r#" + pub struct Foo {} + impl Foo { + #[test] + fn foo() {} + } + + fn main() {} + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + assert!(matches!( + errors[0].0, + CompilationError::DefinitionError(DefCollectorErrorKind::TestOnAssociatedFunction { + span: _ + }) + )); +} + +#[test] +fn disallows_test_attribute_on_trait_impl_method() { + let src = r#" + pub trait Trait { + fn foo() {} + } + + pub struct Foo {} + impl Trait for Foo { + #[test] + fn foo() {} + } + + fn main() {} + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + assert!(matches!( + errors[0].0, + CompilationError::DefinitionError(DefCollectorErrorKind::TestOnAssociatedFunction { + span: _ + }) + )); +}