diff --git a/compiler/noirc_frontend/src/elaborator/comptime.rs b/compiler/noirc_frontend/src/elaborator/comptime.rs index a27e2bf0163..962356d6dd9 100644 --- a/compiler/noirc_frontend/src/elaborator/comptime.rs +++ b/compiler/noirc_frontend/src/elaborator/comptime.rs @@ -329,8 +329,6 @@ impl<'context> Elaborator<'context> { push_arg(Value::TraitDefinition(trait_id)); } else { let (expr_id, expr_type) = interpreter.elaborator.elaborate_expression(arg); - push_arg(interpreter.evaluate(expr_id)?); - if let Err(UnificationError) = expr_type.unify(param_type) { return Err(InterpreterError::TypeMismatch { expected: param_type.clone(), @@ -338,6 +336,7 @@ impl<'context> Elaborator<'context> { location: arg_location, }); } + push_arg(interpreter.evaluate(expr_id)?); }; } diff --git a/compiler/noirc_frontend/src/tests/metaprogramming.rs b/compiler/noirc_frontend/src/tests/metaprogramming.rs index 82c40203244..89a049ebc9d 100644 --- a/compiler/noirc_frontend/src/tests/metaprogramming.rs +++ b/compiler/noirc_frontend/src/tests/metaprogramming.rs @@ -141,3 +141,23 @@ fn errors_if_macros_inject_functions_with_name_collisions() { ) if contents == "foo" )); } + +#[test] +fn uses_correct_type_for_attribute_arguments() { + let src = r#" + #[foo(32)] + comptime fn foo(_f: FunctionDefinition, i: u32) { + let y: u32 = 1; + let _ = y == i; + } + + #[bar([0; 2])] + comptime fn bar(_f: FunctionDefinition, i: [u32; 2]) { + let y: u32 = 1; + let _ = y == i[0]; + } + + fn main() {} + "#; + assert_no_errors(src); +}