Skip to content

Commit

Permalink
Merge pull request quarkusio#32949 from mkouba/issue-32944
Browse files Browse the repository at this point in the history
InjectMock should not create a new contextual instance
  • Loading branch information
geoand authored Apr 28, 2023
2 parents 0b600c3 + c5a8c83 commit 69a794f
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ Collection<Resource> generateSyntheticBean(BeanInfo bean) {
implementGetStereotypes(bean, beanCreator, stereotypes.getFieldDescriptor());
}
implementGetBeanClass(bean, beanCreator);
implementGetImplementationClass(bean, beanCreator);
implementGetName(bean, beanCreator);
if (bean.isDefaultBean()) {
implementIsDefaultBean(bean, beanCreator);
Expand Down Expand Up @@ -487,6 +488,7 @@ Collection<Resource> generateProducerMethodBean(BeanInfo bean, MethodInfo produc
implementGetStereotypes(bean, beanCreator, stereotypes.getFieldDescriptor());
}
implementGetBeanClass(bean, beanCreator);
implementGetImplementationClass(bean, beanCreator);
implementGetName(bean, beanCreator);
if (bean.isDefaultBean()) {
implementIsDefaultBean(bean, beanCreator);
Expand Down Expand Up @@ -567,6 +569,7 @@ Collection<Resource> generateProducerFieldBean(BeanInfo bean, FieldInfo producer
implementGetStereotypes(bean, beanCreator, stereotypes.getFieldDescriptor());
}
implementGetBeanClass(bean, beanCreator);
implementGetImplementationClass(bean, beanCreator);
implementGetName(bean, beanCreator);
if (bean.isDefaultBean()) {
implementIsDefaultBean(bean, beanCreator);
Expand Down Expand Up @@ -2068,6 +2071,13 @@ protected void implementGetBeanClass(BeanInfo bean, ClassCreator beanCreator) {
getBeanClass.returnValue(getBeanClass.loadClass(bean.getBeanClass().toString()));
}

protected void implementGetImplementationClass(BeanInfo bean, ClassCreator beanCreator) {
MethodCreator getImplementationClass = beanCreator.getMethodCreator("getImplementationClass", Class.class)
.setModifiers(ACC_PUBLIC);
getImplementationClass.returnValue(bean.getImplClazz() != null ? getImplementationClass.loadClass(bean.getImplClazz())
: getImplementationClass.loadNull());
}

protected void implementGetName(BeanInfo bean, ClassCreator beanCreator) {
if (bean.getName() != null) {
MethodCreator getName = beanCreator.getMethodCreator("getName", String.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,23 @@ default int getPriority() {
return 0;
}

/**
* The return value depends on the {@link #getKind()}.
*
* <ul>
* <li>For managed beans, interceptors, decorators and built-in beans, the bean class is returned.</li>
* <li>For a producer method, the class of the return type is returned.</li>
* <li>For a producer field, the class of the field is returned.</li>
* <li>For a synthetic bean, the implementation class defined by the registrar is returned.
* </ul>
*
* @return the implementation class, or null in case of a producer of a primitive type or an array
* @see Kind
*/
default Class<?> getImplementationClass() {
return getBeanClass();
}

enum Kind {

CLASS,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package io.quarkus.it.mockbean;

import java.util.concurrent.atomic.AtomicBoolean;

import jakarta.annotation.PostConstruct;
import jakarta.enterprise.context.RequestScoped;

import io.quarkus.arc.Unremovable;

@Unremovable
@RequestScoped
public class RequestScopedFoo {

static final AtomicBoolean CONSTRUCTED = new AtomicBoolean();

public String ping() {
return "bar";
}

@PostConstruct
void init() {
CONSTRUCTED.set(true);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package io.quarkus.it.mockbean;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.mockito.Mockito.when;

import org.junit.jupiter.api.Test;

import io.quarkus.test.junit.QuarkusTest;
import io.quarkus.test.junit.mockito.InjectMock;

@QuarkusTest
class RequestScopedFooMockTest {

@InjectMock
RequestScopedFoo foo;

@Test
void testMock() {
when(foo.ping()).thenReturn("pong");
assertEquals("pong", foo.ping());
assertFalse(RequestScopedFoo.CONSTRUCTED.get());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@

import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.ClientProxy;
import io.quarkus.arc.InstanceHandle;
import io.quarkus.arc.Subclass;
import io.quarkus.test.junit.callback.QuarkusTestAfterConstructCallback;
import io.quarkus.test.junit.mockito.InjectMock;

Expand All @@ -29,24 +27,25 @@ public void afterConstruct(Object testInstance) {
InjectMock injectMockAnnotation = field.getAnnotation(InjectMock.class);
if (injectMockAnnotation != null) {
boolean returnsDeepMocks = injectMockAnnotation.returnsDeepMocks();
Object contextualReference = getContextualReference(testInstance, field, InjectMock.class);
Optional<Object> result = createMockAndSetTestField(testInstance, field, contextualReference,
InstanceHandle<?> beanHandle = getBeanHandle(testInstance, field, InjectMock.class);
Optional<Object> result = createMockAndSetTestField(testInstance, field, beanHandle,
new MockConfiguration(returnsDeepMocks));
if (result.isPresent()) {
MockitoMocksTracker.track(testInstance, result.get(), contextualReference);
MockitoMocksTracker.track(testInstance, result.get(), beanHandle.get());
}
}
}
current = current.getSuperclass();
}
}

private Optional<Object> createMockAndSetTestField(Object testInstance, Field field, Object contextualReference,
private Optional<Object> createMockAndSetTestField(Object testInstance, Field field, InstanceHandle<?> beanHandle,
MockConfiguration mockConfiguration) {
Class<?> implementationClass = getImplementationClass(contextualReference);
Class<?> implementationClass = beanHandle.getBean().getImplementationClass();
Object mock;
boolean isNew;
Optional<Object> currentMock = MockitoMocksTracker.currentMock(testInstance, contextualReference);
// Note that beanHandle.get() returns a client proxy for normal scoped beans; i.e. the contextual instance is not created
Optional<Object> currentMock = MockitoMocksTracker.currentMock(testInstance, beanHandle.get());
if (currentMock.isPresent()) {
mock = currentMock.get();
isNew = false;
Expand All @@ -71,15 +70,7 @@ private Optional<Object> createMockAndSetTestField(Object testInstance, Field fi
}
}

/**
* Contextual reference of a normal scoped bean is a client proxy.
*
* @param testInstance
* @param field
* @param annotationType
* @return a contextual reference of a bean
*/
static Object getContextualReference(Object testInstance, Field field, Class<? extends Annotation> annotationType) {
static InstanceHandle<?> getBeanHandle(Object testInstance, Field field, Class<? extends Annotation> annotationType) {
Type fieldType = field.getGenericType();
ArcContainer container = Arc.container();
BeanManager beanManager = container.beanManager();
Expand All @@ -100,15 +91,7 @@ static Object getContextualReference(Object testInstance, Field field, Class<? e
+ ". Offending field is " + field.getName() + " of test class "
+ testInstance.getClass());
}
return handle.get();
}

static Class<?> getImplementationClass(Object contextualReference) {
// Unwrap the client proxy if needed
Object contextualInstance = ClientProxy.unwrap(contextualReference);
// If the contextual instance is an intercepted subclass then mock the extended implementation class
return contextualInstance instanceof Subclass ? contextualInstance.getClass().getSuperclass()
: contextualInstance.getClass();
return handle;
}

static Annotation[] getQualifiers(Field fieldToMock, BeanManager beanManager) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.mockito.Mockito;

import io.quarkus.arc.ClientProxy;
import io.quarkus.arc.InstanceHandle;
import io.quarkus.test.junit.callback.QuarkusTestAfterConstructCallback;
import io.quarkus.test.junit.mockito.InjectSpy;

Expand All @@ -18,22 +19,22 @@ public void afterConstruct(Object testInstance) {
for (Field field : current.getDeclaredFields()) {
InjectSpy injectSpyAnnotation = field.getAnnotation(InjectSpy.class);
if (injectSpyAnnotation != null) {
Object contextualReference = CreateMockitoMocksCallback.getContextualReference(testInstance, field,
InstanceHandle<?> beanHandle = CreateMockitoMocksCallback.getBeanHandle(testInstance, field,
InjectSpy.class);
Object spy = createSpyAndSetTestField(testInstance, field, contextualReference,
Object spy = createSpyAndSetTestField(testInstance, field, beanHandle,
injectSpyAnnotation.delegate());
MockitoMocksTracker.track(testInstance, spy, contextualReference);
MockitoMocksTracker.track(testInstance, spy, beanHandle.get());
}
}
current = current.getSuperclass();
}
}

private Object createSpyAndSetTestField(Object testInstance, Field field, Object contextualReference, boolean delegate) {
private Object createSpyAndSetTestField(Object testInstance, Field field, InstanceHandle<?> beanHandle, boolean delegate) {
Object spy;
Object contextualInstance = ClientProxy.unwrap(contextualReference);
Object contextualInstance = ClientProxy.unwrap(beanHandle.get());
if (delegate) {
spy = Mockito.mock(CreateMockitoMocksCallback.getImplementationClass(contextualReference),
spy = Mockito.mock(beanHandle.getBean().getImplementationClass(),
AdditionalAnswers.delegatesTo(contextualInstance));
} else {
// Unwrap the client proxy if needed
Expand Down

0 comments on commit 69a794f

Please sign in to comment.