Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ArC: improve validation of interceptor method signatures #38796

Merged
merged 1 commit into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import jakarta.enterprise.inject.UnsatisfiedResolutionException;
import jakarta.enterprise.inject.spi.DefinitionException;
import jakarta.enterprise.inject.spi.DeploymentException;
import jakarta.enterprise.inject.spi.InterceptionType;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
Expand Down Expand Up @@ -710,8 +711,17 @@ static void addImplicitQualifiers(Set<AnnotationInstance> qualifiers) {
}

static List<MethodInfo> getCallbacks(ClassInfo beanClass, DotName annotation, IndexView index) {
InterceptionType interceptionType = null;
if (DotNames.POST_CONSTRUCT.equals(annotation)) {
interceptionType = InterceptionType.POST_CONSTRUCT;
} else if (DotNames.PRE_DESTROY.equals(annotation)) {
interceptionType = InterceptionType.PRE_DESTROY;
} else {
throw new IllegalArgumentException("Unexpected callback annotation: " + annotation);
}

List<MethodInfo> callbacks = new ArrayList<>();
collectCallbacks(beanClass, callbacks, annotation, index, new HashSet<>());
collectCallbacks(beanClass, callbacks, annotation, index, new HashSet<>(), interceptionType);
Collections.reverse(callbacks);
return callbacks;
}
Expand All @@ -729,7 +739,8 @@ static List<MethodInfo> getAroundInvokes(ClassInfo beanClass, BeanDeployment dep
continue;
}
if (store.hasAnnotation(method, DotNames.AROUND_INVOKE)) {
InterceptorInfo.addInterceptorMethod(allMethods, methods, method);
InterceptorInfo.addInterceptorMethod(allMethods, methods, method, InterceptionType.AROUND_INVOKE,
InterceptorPlacement.TARGET_CLASS);
if (++aroundInvokesFound > 1) {
throw new DefinitionException(
"Multiple @AroundInvoke interceptor methods declared on class: " + aClass);
Expand Down Expand Up @@ -1042,24 +1053,18 @@ private static void fetchType(Type type, BeanDeployment beanDeployment) {
}

private static void collectCallbacks(ClassInfo clazz, List<MethodInfo> callbacks, DotName annotation, IndexView index,
Set<String> knownMethods) {
Set<String> knownMethods, InterceptionType interceptionType) {
for (MethodInfo method : clazz.methods()) {
if (method.hasAnnotation(annotation) && !knownMethods.contains(method.name())) {
if (method.returnType().kind() == Kind.VOID && method.parameterTypes().isEmpty()) {
callbacks.add(method);
} else {
// invalid signature - build a meaningful message.
throw new DefinitionException("Invalid signature for the method `" + method + "` from class `"
+ method.declaringClass() + "`. Methods annotated with `" + annotation + "` must return" +
" `void` and cannot have parameters.");
}
InterceptorInfo.validateSignature(method, interceptionType, InterceptorPlacement.TARGET_CLASS);
callbacks.add(method);
}
knownMethods.add(method.name());
}
if (clazz.superName() != null) {
ClassInfo superClass = getClassByName(index, clazz.superName());
if (superClass != null) {
collectCallbacks(superClass, callbacks, annotation, index, knownMethods);
collectCallbacks(superClass, callbacks, annotation, index, knownMethods, interceptionType);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -106,28 +107,32 @@ public class InterceptorInfo extends BeanInfo implements Comparable<InterceptorI
+ aClass);
}
if (store.hasAnnotation(method, DotNames.AROUND_INVOKE)) {
addInterceptorMethod(allMethods, aroundInvokes, method);
addInterceptorMethod(allMethods, aroundInvokes, method, InterceptionType.AROUND_INVOKE,
InterceptorPlacement.INTERCEPTOR_CLASS);
if (++aroundInvokesFound > 1) {
throw new DefinitionException(
"Multiple @AroundInvoke interceptor methods declared on class: " + aClass);
}
}
if (store.hasAnnotation(method, DotNames.AROUND_CONSTRUCT)) {
addInterceptorMethod(allMethods, aroundConstructs, method);
addInterceptorMethod(allMethods, aroundConstructs, method, InterceptionType.AROUND_CONSTRUCT,
InterceptorPlacement.INTERCEPTOR_CLASS);
if (++aroundConstructsFound > 1) {
throw new DefinitionException(
"Multiple @AroundConstruct interceptor methods declared on class: " + aClass);
}
}
if (store.hasAnnotation(method, DotNames.POST_CONSTRUCT)) {
addInterceptorMethod(allMethods, postConstructs, method);
addInterceptorMethod(allMethods, postConstructs, method, InterceptionType.POST_CONSTRUCT,
InterceptorPlacement.INTERCEPTOR_CLASS);
if (++postConstructsFound > 1) {
throw new DefinitionException(
"Multiple @PostConstruct interceptor methods declared on class: " + aClass);
}
}
if (store.hasAnnotation(method, DotNames.PRE_DESTROY)) {
addInterceptorMethod(allMethods, preDestroys, method);
addInterceptorMethod(allMethods, preDestroys, method, InterceptionType.PRE_DESTROY,
InterceptorPlacement.INTERCEPTOR_CLASS);
if (++preDestroysFound > 1) {
throw new DefinitionException(
"Multiple @PreDestroy interceptor methods declared on class: " + aClass);
Expand Down Expand Up @@ -297,8 +302,9 @@ public int compareTo(InterceptorInfo other) {
return getTarget().toString().compareTo(other.getTarget().toString());
}

static void addInterceptorMethod(List<MethodInfo> allMethods, List<MethodInfo> interceptorMethods, MethodInfo method) {
validateSignature(method);
static void addInterceptorMethod(List<MethodInfo> allMethods, List<MethodInfo> interceptorMethods, MethodInfo method,
InterceptionType interceptionType, InterceptorPlacement interceptorPlacement) {
validateSignature(method, interceptionType, interceptorPlacement);
if (!isInterceptorMethodOverriden(allMethods, method)) {
interceptorMethods.add(method);
}
Expand All @@ -319,19 +325,105 @@ static boolean hasInterceptorMethodParameter(MethodInfo method) {
|| method.parameterType(0).name().equals(DotNames.ARC_INVOCATION_CONTEXT));
}

private static MethodInfo validateSignature(MethodInfo method) {
if (!hasInterceptorMethodParameter(method)) {
throw new IllegalStateException(
"An interceptor method must accept exactly one parameter of type jakarta.interceptor.InvocationContext: "
+ method + " declared on " + method.declaringClass());
private enum InterceptorMethodError {
MUST_HAVE_PARAMETER,
MUST_NOT_HAVE_PARAMETER,
WRONG_RETURN_TYPE,
}

static void validateSignature(MethodInfo method, InterceptionType interceptionType,
InterceptorPlacement interceptorPlacement) {
boolean isLifecycleCallback = interceptionType == InterceptionType.AROUND_CONSTRUCT
|| interceptionType == InterceptionType.POST_CONSTRUCT
|| interceptionType == InterceptionType.PRE_DESTROY;

boolean mustHaveParameter = !isLifecycleCallback || interceptorPlacement == InterceptorPlacement.INTERCEPTOR_CLASS;
boolean mustNotHaveParameter = isLifecycleCallback && interceptorPlacement == InterceptorPlacement.TARGET_CLASS;
boolean mayReturnVoid = isLifecycleCallback;
boolean mayReturnObject = !isLifecycleCallback || interceptorPlacement == InterceptorPlacement.INTERCEPTOR_CLASS;

Set<InterceptorMethodError> errors = EnumSet.noneOf(InterceptorMethodError.class);
if (mustHaveParameter && !hasInterceptorMethodParameter(method)) {
errors.add(InterceptorMethodError.MUST_HAVE_PARAMETER);
}
if (mustNotHaveParameter && method.parametersCount() > 0) {
errors.add(InterceptorMethodError.MUST_NOT_HAVE_PARAMETER);
}

boolean wrongReturnType = true;
if (mayReturnVoid && method.returnType().kind().equals(Kind.VOID)) {
wrongReturnType = false;
}
if (mayReturnObject && method.returnType().name().equals(DotNames.OBJECT)) {
wrongReturnType = false;
}
if (wrongReturnType) {
errors.add(InterceptorMethodError.WRONG_RETURN_TYPE);
}
if (!method.returnType().kind().equals(Type.Kind.VOID) &&
!method.returnType().name().equals(DotNames.OBJECT)) {
throw new IllegalStateException(
"The return type of an interceptor method must be java.lang.Object or void: "
+ method + " declared on " + method.declaringClass());

if (!errors.isEmpty()) {
StringBuilder msg = new StringBuilder();
switch (interceptionType) {
case AROUND_CONSTRUCT:
msg.append("@AroundConstruct");
break;
case AROUND_INVOKE:
msg.append("@AroundInvoke");
break;
case POST_CONSTRUCT:
msg.append("@PostConstruct");
break;
case PRE_DESTROY:
msg.append("@PreDestroy");
break;
default:
throw new IllegalArgumentException("Unknown interception type: " + interceptionType);
}
if (isLifecycleCallback) {
msg.append(" lifecycle callback method");
} else {
msg.append(" interceptor method");
}
msg.append(" declared in ");
switch (interceptorPlacement) {
case INTERCEPTOR_CLASS:
msg.append("an interceptor class");
break;
case TARGET_CLASS:
msg.append("a target class");
break;
default:
throw new IllegalArgumentException("Unknown interceptor placement: " + interceptorPlacement);
}
msg.append(" must ");

if (errors.contains(InterceptorMethodError.MUST_HAVE_PARAMETER)) {
msg.append("have exactly one parameter of type jakarta.interceptor.InvocationContext");
} else if (errors.contains(InterceptorMethodError.MUST_NOT_HAVE_PARAMETER)) {
msg.append("have zero parameters");
}

if (errors.contains(InterceptorMethodError.WRONG_RETURN_TYPE)) {
if (errors.contains(InterceptorMethodError.MUST_HAVE_PARAMETER)
|| errors.contains(InterceptorMethodError.MUST_NOT_HAVE_PARAMETER)) {
msg.append(" and must ");
}
msg.append("have a return type of ");
if (mayReturnVoid) {
msg.append("void");
}
if (mayReturnVoid && mayReturnObject) {
msg.append(" or ");
}
if (mayReturnObject) {
msg.append("java.lang.Object");
}
}

msg.append(": ").append(method).append(" declared in ").append(method.declaringClass().name());

throw new DefinitionException(msg.toString());
}
return method;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package io.quarkus.arc.processor;

enum InterceptorPlacement {
INTERCEPTOR_CLASS,
TARGET_CLASS,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package io.quarkus.arc.test.interceptors.illegal;

import static java.lang.annotation.ElementType.FIELD;
import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.ElementType.PARAMETER;
import static java.lang.annotation.ElementType.TYPE;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.lang.annotation.Retention;
import java.lang.annotation.Target;

import jakarta.annotation.Priority;
import jakarta.enterprise.inject.spi.DefinitionException;
import jakarta.interceptor.AroundInvoke;
import jakarta.interceptor.Interceptor;
import jakarta.interceptor.InterceptorBinding;
import jakarta.interceptor.InvocationContext;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.arc.test.ArcTestContainer;

public class InterceptorReturningVoidTest {
@RegisterExtension
public ArcTestContainer container = ArcTestContainer.builder()
.beanClasses(MyInterceptor.class, MyInterceptorBinding.class)
.shouldFail()
.build();

@Test
public void trigger() {
Throwable error = container.getFailure();
assertNotNull(error);
assertInstanceOf(DefinitionException.class, error);
assertTrue(error.getMessage().contains(
"@AroundInvoke interceptor method declared in an interceptor class must have a return type of java.lang.Object"));
assertTrue(error.getMessage().contains("intercept(jakarta.interceptor.InvocationContext ctx)"));
assertTrue(error.getMessage().contains("InterceptorReturningVoidTest$MyInterceptor"));
}

@Target({ TYPE, METHOD, FIELD, PARAMETER })
@Retention(RUNTIME)
@InterceptorBinding
@interface MyInterceptorBinding {
}

@MyInterceptorBinding
@Interceptor
@Priority(1)
static class MyInterceptor {
@AroundInvoke
void intercept(InvocationContext ctx) throws Exception {
ctx.proceed();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package io.quarkus.arc.test.interceptors.illegal;

import static java.lang.annotation.ElementType.FIELD;
import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.ElementType.PARAMETER;
import static java.lang.annotation.ElementType.TYPE;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.lang.annotation.Retention;
import java.lang.annotation.Target;

import jakarta.annotation.Priority;
import jakarta.enterprise.inject.spi.DefinitionException;
import jakarta.interceptor.AroundInvoke;
import jakarta.interceptor.Interceptor;
import jakarta.interceptor.InterceptorBinding;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.arc.test.ArcTestContainer;

public class InterceptorWithoutParameterTest {
@RegisterExtension
public ArcTestContainer container = ArcTestContainer.builder()
.beanClasses(MyInterceptor.class, MyInterceptorBinding.class)
.shouldFail()
.build();

@Test
public void trigger() {
Throwable error = container.getFailure();
assertNotNull(error);
assertInstanceOf(DefinitionException.class, error);
assertTrue(error.getMessage().contains(
"@AroundInvoke interceptor method declared in an interceptor class must have exactly one parameter"));
assertTrue(error.getMessage().contains("intercept()"));
assertTrue(error.getMessage().contains("InterceptorWithoutParameterTest$MyInterceptor"));
}

@Target({ TYPE, METHOD, FIELD, PARAMETER })
@Retention(RUNTIME)
@InterceptorBinding
@interface MyInterceptorBinding {
}

@MyInterceptorBinding
@Interceptor
@Priority(1)
static class MyInterceptor {
@AroundInvoke
Object intercept() throws Exception {
return null;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.quarkus.arc.test.validation;

import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

Expand All @@ -23,7 +24,11 @@ public class InvalidPostConstructTest {
public void testFailure() {
Throwable error = container.getFailure();
assertNotNull(error);
assertTrue(error instanceof DefinitionException);
assertInstanceOf(DefinitionException.class, error);
assertTrue(error.getMessage().contains(
"@PostConstruct lifecycle callback method declared in a target class must have a return type of void"));
assertTrue(error.getMessage().contains("invalid()"));
assertTrue(error.getMessage().contains("InvalidPostConstructTest$InvalidBean"));
}

@ApplicationScoped
Expand Down
Loading
Loading