Skip to content

Commit

Permalink
Fix @InjectMock and @InjectSpy handling of @nested tests
Browse files Browse the repository at this point in the history
Fixes: #19391
  • Loading branch information
geoand committed Aug 24, 2021
1 parent 778fb13 commit 663e040
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package io.quarkus.it.mockbean;

import static io.restassured.RestAssured.given;
import static org.hamcrest.Matchers.is;

import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import io.quarkus.test.junit.QuarkusTest;
import io.quarkus.test.junit.mockito.InjectMock;

@QuarkusTest
public class NestedTest {

@InjectMock
MessageService messageService;

@Nested
public class ActualTest {

@InjectMock
SuffixService suffixService;

@Test
public void testGreet() {
Mockito.when(messageService.getMessage()).thenReturn("hi");
Mockito.when(suffixService.getSuffix()).thenReturn("!");

given()
.when().get("/greeting")
.then()
.statusCode(200)
.body(is("HI!"));
}

@Test
public void testGreetAgain() {
Mockito.when(messageService.getMessage()).thenReturn("yolo");
Mockito.when(suffixService.getSuffix()).thenReturn("!!!");

given()
.when().get("/greeting")
.then()
.statusCode(200)
.body(is("YOLO!!!"));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package io.quarkus.test.junit.mockito.internal;

import io.quarkus.test.junit.callback.QuarkusTestAfterAllCallback;
import io.quarkus.test.junit.callback.QuarkusTestContext;

public class ResetOuterMockitoMocksCallback implements QuarkusTestAfterAllCallback {

@Override
public void afterAll(QuarkusTestContext context) {
if (context.getOuterInstance() != null) {
MockitoMocksTracker.reset(context.getOuterInstance());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ public class SetMockitoMockAsBeanMockCallback implements QuarkusTestBeforeEachCa
@Override
public void beforeEach(QuarkusTestMethodContext context) {
MockitoMocksTracker.getMocks(context.getTestInstance()).forEach(this::installMock);
if (context.getOuterInstance() != null) {
MockitoMocksTracker.getMocks(context.getOuterInstance()).forEach(this::installMock);
}
}

private void installMock(MockitoMocksTracker.Mocked mocked) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
io.quarkus.test.junit.mockito.internal.ResetOuterMockitoMocksCallback
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,12 @@
import io.quarkus.test.common.http.TestHTTPEndpoint;
import io.quarkus.test.common.http.TestHTTPResourceManager;
import io.quarkus.test.junit.buildchain.TestBuildChainCustomizerProducer;
import io.quarkus.test.junit.callback.QuarkusTestAfterAllCallback;
import io.quarkus.test.junit.callback.QuarkusTestAfterConstructCallback;
import io.quarkus.test.junit.callback.QuarkusTestAfterEachCallback;
import io.quarkus.test.junit.callback.QuarkusTestBeforeClassCallback;
import io.quarkus.test.junit.callback.QuarkusTestBeforeEachCallback;
import io.quarkus.test.junit.callback.QuarkusTestContext;
import io.quarkus.test.junit.callback.QuarkusTestMethodContext;
import io.quarkus.test.junit.internal.DeepClone;
import io.quarkus.test.junit.internal.SerializationWithXStreamFallbackDeepClone;
Expand All @@ -131,6 +133,8 @@ public class QuarkusTestExtension

private static Class<?> actualTestClass;
private static Object actualTestInstance;
// needed for @Nested
private static Object outerInstance;
private static ClassLoader originalCl;
private static RunningQuarkusApplication runningQuarkusApplication;
private static Pattern clonePattern;
Expand All @@ -140,6 +144,7 @@ public class QuarkusTestExtension
private static List<Object> afterConstructCallbacks;
private static List<Object> beforeEachCallbacks;
private static List<Object> afterEachCallbacks;
private static List<Object> afterAllCallbacks;
private static Class<?> quarkusTestMethodContextClass;
private static Class<? extends QuarkusTestProfile> quarkusTestProfile;
private static boolean hasPerTestResources;
Expand Down Expand Up @@ -474,6 +479,7 @@ private void populateCallbacks(ClassLoader classLoader) throws ClassNotFoundExce
afterConstructCallbacks = new ArrayList<>();
beforeEachCallbacks = new ArrayList<>();
afterEachCallbacks = new ArrayList<>();
afterAllCallbacks = new ArrayList<>();

ServiceLoader<?> quarkusTestBeforeClassLoader = ServiceLoader
.load(Class.forName(QuarkusTestBeforeClassCallback.class.getName(), false, classLoader), classLoader);
Expand All @@ -495,6 +501,11 @@ private void populateCallbacks(ClassLoader classLoader) throws ClassNotFoundExce
for (Object quarkusTestAfterEach : quarkusTestAfterEachLoader) {
afterEachCallbacks.add(quarkusTestAfterEach);
}
ServiceLoader<?> quarkusTestAfterAllLoader = ServiceLoader
.load(Class.forName(QuarkusTestAfterAllCallback.class.getName(), false, classLoader), classLoader);
for (Object quarkusTestAfterAll : quarkusTestAfterAllLoader) {
afterAllCallbacks.add(quarkusTestAfterAll);
}
}

private void populateTestMethodInvokers(ClassLoader quarkusClassLoader) {
Expand Down Expand Up @@ -640,9 +651,9 @@ public void afterEach(ExtensionContext context) throws Exception {
throw new RuntimeException("Could not find method " + originalTestMethod + " on test class");
}

Constructor<?> constructor = quarkusTestMethodContextClass.getConstructor(Object.class, Method.class);
Constructor<?> constructor = quarkusTestMethodContextClass.getConstructor(Object.class, Object.class, Method.class);
return new AbstractMap.SimpleEntry<>(quarkusTestMethodContextClass,
constructor.newInstance(actualTestInstance, actualTestMethod));
constructor.newInstance(actualTestInstance, outerInstance, actualTestMethod));
}

private boolean isNativeOrIntegrationTest(Class<?> clazz) {
Expand Down Expand Up @@ -849,12 +860,13 @@ private void initTestState(ExtensionContext extensionContext, ExtensionState sta
Class<?> previousActualTestClass = actualTestClass;
actualTestClass = Class.forName(extensionContext.getRequiredTestClass().getName(), true,
Thread.currentThread().getContextClassLoader());
outerInstance = null;
if (extensionContext.getRequiredTestClass().isAnnotationPresent(Nested.class)) {
Class<?> parent = actualTestClass.getEnclosingClass();
Object parentInstance = runningQuarkusApplication.instance(parent);
Constructor<?> declaredConstructor = actualTestClass.getDeclaredConstructor(parent);
Class<?> outerClass = actualTestClass.getEnclosingClass();
outerInstance = runningQuarkusApplication.instance(outerClass);
Constructor<?> declaredConstructor = actualTestClass.getDeclaredConstructor(outerClass);
declaredConstructor.setAccessible(true);
actualTestInstance = declaredConstructor.newInstance(parentInstance);
actualTestInstance = declaredConstructor.newInstance(outerInstance);
} else {
actualTestInstance = runningQuarkusApplication.instance(actualTestClass);
}
Expand All @@ -868,6 +880,12 @@ private void initTestState(ExtensionContext extensionContext, ExtensionState sta
afterConstructCallback.getClass().getMethod("afterConstruct", Object.class).invoke(afterConstructCallback,
actualTestInstance);
}
if (outerInstance != null) {
for (Object afterConstructCallback : afterConstructCallbacks) {
afterConstructCallback.getClass().getMethod("afterConstruct", Object.class).invoke(afterConstructCallback,
outerInstance);
}
}
} catch (Exception e) {
throw new TestInstantiationException("Failed to create test instance", e);
}
Expand Down Expand Up @@ -1102,6 +1120,7 @@ private Method determineTCCLExtensionMethod(ReflectiveInvocationContext<Method>
@Override
public void afterAll(ExtensionContext context) throws Exception {
resetHangTimeout();
runAfterAllCallbacks(context);
try {
if (!isNativeOrIntegrationTest(context.getRequiredTestClass()) && (runningQuarkusApplication != null)) {
popMockContext();
Expand All @@ -1111,6 +1130,31 @@ public void afterAll(ExtensionContext context) throws Exception {
}
} finally {
currentTestClassStack.pop();
outerInstance = null;
}
}

private void runAfterAllCallbacks(ExtensionContext context) throws Exception {
if (isNativeOrIntegrationTest(context.getRequiredTestClass())) {
return;
}
if (afterAllCallbacks.isEmpty()) {
return;
}

Class<?> quarkusTestContextClass = Class.forName(QuarkusTestContext.class.getName(), true,
runningQuarkusApplication.getClassLoader());
Object quarkusTestContextInstance = quarkusTestContextClass.getConstructor(Object.class, Object.class)
.newInstance(actualTestInstance, outerInstance);

ClassLoader original = setCCL(runningQuarkusApplication.getClassLoader());
try {
for (Object afterAllCallback : afterAllCallbacks) {
afterAllCallback.getClass().getMethod("afterAll", quarkusTestContextClass)
.invoke(afterAllCallback, quarkusTestContextInstance);
}
} finally {
setCCL(original);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package io.quarkus.test.junit.callback;

/**
* Can be implemented by classes that shall be called after all test methods in a {@code @QuarkusTest} have been run.
* <p>
* The implementing class has to be {@linkplain java.util.ServiceLoader deployed as service provider on the class path}.
*/
public interface QuarkusTestAfterAllCallback {

void afterAll(QuarkusTestContext context);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.quarkus.test.junit.callback;

/**
* Context object passed to {@link QuarkusTestAfterAllCallback}
*/
public class QuarkusTestContext {

private final Object testInstance;
private final Object outerInstance;

public QuarkusTestContext(Object testInstance, Object outerInstance) {
this.testInstance = testInstance;
this.outerInstance = outerInstance;
}

public Object getTestInstance() {
return testInstance;
}

public Object getOuterInstance() {
return outerInstance;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,15 @@
/**
* Context object passed to {@link QuarkusTestBeforeEachCallback} and {@link QuarkusTestAfterEachCallback}
*/
public final class QuarkusTestMethodContext {
public final class QuarkusTestMethodContext extends QuarkusTestContext {

private final Object testInstance;
private final Method testMethod;

public QuarkusTestMethodContext(Object testInstance, Method testMethod) {
this.testInstance = testInstance;
public QuarkusTestMethodContext(Object testInstance, Object outerInstance, Method testMethod) {
super(testInstance, outerInstance);
this.testMethod = testMethod;
}

public Object getTestInstance() {
return testInstance;
}

public Method getTestMethod() {
return testMethod;
}
Expand Down

0 comments on commit 663e040

Please sign in to comment.