Skip to content

Commit

Permalink
Introduce AOT run-time support in the TestContext framework
Browse files Browse the repository at this point in the history
This commit introduces initial AOT run-time support in the Spring
TestContext Framework.

- DefaultCacheAwareContextLoaderDelegate: when running in AOT mode, now
  loads a test's ApplicationContext via the AotContextLoader SPI
  instead of via the standard SmartContextLoader and ContextLoader SPIs.

- DependencyInjectionTestExecutionListener: when running in AOT mode,
  now injects dependencies into a test instance using a local instance
  of AutowiredAnnotationBeanPostProcessor instead of relying on
  AutowireCapableBeanFactory support.

Closes gh-28205
  • Loading branch information
sbrannen committed Aug 23, 2022
1 parent ada0880 commit 8a6c1ba
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,21 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.aot.AotDetector;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.log.LogMessage;
import org.springframework.lang.Nullable;
import org.springframework.test.annotation.DirtiesContext.HierarchyMode;
import org.springframework.test.context.CacheAwareContextLoaderDelegate;
import org.springframework.test.context.ContextLoader;
import org.springframework.test.context.MergedContextConfiguration;
import org.springframework.test.context.SmartContextLoader;
import org.springframework.test.context.aot.AotContextLoader;
import org.springframework.test.context.aot.AotTestMappings;
import org.springframework.test.context.aot.TestContextAotException;
import org.springframework.util.Assert;

/**
Expand All @@ -48,6 +56,8 @@ public class DefaultCacheAwareContextLoaderDelegate implements CacheAwareContext
*/
static final ContextCache defaultContextCache = new DefaultContextCache();

private final AotTestMappings aotTestMappings = getAotTestMappings();

private final ContextCache contextCache;


Expand Down Expand Up @@ -87,7 +97,12 @@ public ApplicationContext loadContext(MergedContextConfiguration mergedContextCo
ApplicationContext context = this.contextCache.get(mergedContextConfiguration);
if (context == null) {
try {
context = loadContextInternal(mergedContextConfiguration);
if (runningInAotMode(mergedContextConfiguration.getTestClass())) {
context = loadContextInAotMode(mergedContextConfiguration);
}
else {
context = loadContextInternal(mergedContextConfiguration);
}
if (logger.isDebugEnabled()) {
logger.debug(String.format("Storing ApplicationContext [%s] in cache under key [%s]",
System.identityHashCode(context), mergedContextConfiguration));
Expand Down Expand Up @@ -149,4 +164,45 @@ protected ApplicationContext loadContextInternal(MergedContextConfiguration merg
}
}

protected ApplicationContext loadContextInAotMode(MergedContextConfiguration mergedConfig) throws Exception {
Class<?> testClass = mergedConfig.getTestClass();
ApplicationContextInitializer<ConfigurableApplicationContext> contextInitializer =
this.aotTestMappings.getContextInitializer(testClass);
Assert.state(contextInitializer != null,
() -> "Failed to load AOT ApplicationContextInitializer for test class [%s]"
.formatted(testClass.getName()));
logger.info(LogMessage.format("Loading ApplicationContext in AOT mode for %s", mergedConfig));
ContextLoader contextLoader = mergedConfig.getContextLoader();
if (!((contextLoader instanceof AotContextLoader aotContextLoader) &&
(aotContextLoader.loadContextForAotRuntime(mergedConfig, contextInitializer)
instanceof GenericApplicationContext gac))) {
throw new TestContextAotException("""
Cannot load ApplicationContext for AOT runtime for %s. The configured \
ContextLoader [%s] must be an AotContextLoader and must create a \
GenericApplicationContext."""
.formatted(mergedConfig, contextLoader.getClass().getName()));
}
gac.registerShutdownHook();
return gac;
}

/**
* Determine if we are running in AOT mode for the supplied test class.
*/
private boolean runningInAotMode(Class<?> testClass) {
return (this.aotTestMappings != null && this.aotTestMappings.isSupportedTestClass(testClass));
}

private static AotTestMappings getAotTestMappings() {
if (AotDetector.useGeneratedArtifacts()) {
try {
return new AotTestMappings();
}
catch (Exception ex) {
throw new IllegalStateException("Failed to instantiate AotTestMappings", ex);
}
}
return null;
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,9 +19,15 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.aot.AotDetector;
import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor;
import org.springframework.beans.factory.config.AutowireCapableBeanFactory;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.Conventions;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.aot.AotTestMappings;

/**
* {@code TestExecutionListener} which provides support for dependency
Expand Down Expand Up @@ -53,6 +59,8 @@ public class DependencyInjectionTestExecutionListener extends AbstractTestExecut

private static final Log logger = LogFactory.getLog(DependencyInjectionTestExecutionListener.class);

private final AotTestMappings aotTestMappings = getAotTestMappings();


/**
* Returns {@code 2000}.
Expand All @@ -78,9 +86,14 @@ public final int getOrder() {
@Override
public void prepareTestInstance(TestContext testContext) throws Exception {
if (logger.isDebugEnabled()) {
logger.debug("Performing dependency injection for test context [" + testContext + "].");
logger.debug("Performing dependency injection for test context " + testContext);
}
if (runningInAotMode(testContext.getTestClass())) {
injectDependenciesInAotMode(testContext);
}
else {
injectDependencies(testContext);
}
injectDependencies(testContext);
}

/**
Expand All @@ -96,7 +109,12 @@ public void beforeTestMethod(TestContext testContext) throws Exception {
if (logger.isDebugEnabled()) {
logger.debug("Reinjecting dependencies for test context [" + testContext + "].");
}
injectDependencies(testContext);
if (runningInAotMode(testContext.getTestClass())) {
injectDependenciesInAotMode(testContext);
}
else {
injectDependencies(testContext);
}
}
}

Expand All @@ -121,4 +139,40 @@ protected void injectDependencies(TestContext testContext) throws Exception {
testContext.removeAttribute(REINJECT_DEPENDENCIES_ATTRIBUTE);
}

private void injectDependenciesInAotMode(TestContext testContext) throws Exception {
ApplicationContext applicationContext = testContext.getApplicationContext();
if (!(applicationContext instanceof GenericApplicationContext gac)) {
throw new IllegalStateException("AOT ApplicationContext must be a GenericApplicationContext instead of " +
applicationContext.getClass().getName());
}

Object bean = testContext.getTestInstance();
Class<?> clazz = testContext.getTestClass();
ConfigurableListableBeanFactory beanFactory = gac.getBeanFactory();
AutowiredAnnotationBeanPostProcessor beanPostProcessor = new AutowiredAnnotationBeanPostProcessor();
beanPostProcessor.setBeanFactory(beanFactory);
beanPostProcessor.processInjection(bean);
beanFactory.initializeBean(bean, clazz.getName() + AutowireCapableBeanFactory.ORIGINAL_INSTANCE_SUFFIX);
testContext.removeAttribute(REINJECT_DEPENDENCIES_ATTRIBUTE);
}

/**
* Determine if we are running in AOT mode for the supplied test class.
*/
private boolean runningInAotMode(Class<?> testClass) {
return (this.aotTestMappings != null && this.aotTestMappings.isSupportedTestClass(testClass));
}

private static AotTestMappings getAotTestMappings() {
if (AotDetector.useGeneratedArtifacts()) {
try {
return new AotTestMappings();
}
catch (Exception ex) {
throw new IllegalStateException("Failed to instantiate AotTestMappings", ex);
}
}
return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,44 +16,139 @@

package org.springframework.test.context.aot;


import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
import org.junit.platform.launcher.LauncherDiscoveryRequest;
import org.junit.platform.launcher.core.LauncherDiscoveryRequestBuilder;
import org.junit.platform.launcher.core.LauncherFactory;
import org.junit.platform.launcher.listeners.SummaryGeneratingListener;
import org.junit.platform.launcher.listeners.TestExecutionSummary;
import org.junit.platform.launcher.listeners.TestExecutionSummary.Failure;
import org.opentest4j.MultipleFailuresError;

import org.springframework.aot.AotDetector;
import org.springframework.aot.generate.GeneratedFiles.Kind;
import org.springframework.aot.generate.InMemoryGeneratedFiles;
import org.springframework.aot.test.generator.compile.CompileWithTargetClassAccess;
import org.springframework.aot.test.generator.compile.TestCompiler;
import org.springframework.test.context.aot.samples.basic.BasicSpringJupiterSharedConfigTests;
import org.springframework.test.context.aot.samples.basic.BasicSpringJupiterTests;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.platform.engine.discovery.DiscoverySelectors.selectClass;
import static org.junit.platform.launcher.EngineFilter.includeEngines;

/**
* Smoke tests for AOT support in the TestContext framework.
*
* @author Sam Brannen
* @since 6.0
*/
@CompileWithTargetClassAccess
class AotSmokeTests extends AbstractAotTests {

private static final String CLASSPATH_ROOT = "AotSmokeTests.classpath_root";

// We have to determine the classpath root and store it in a system property
// since @CompileWithTargetClassAccess uses a custom ClassLoader that does
// not support CodeSource.
//
// The system property will only be set when this class is loaded by the
// original ClassLoader used to launch the JUnit Platform. The attempt to
// access the CodeSource will fail when the tests are executed in the
// nested JUnit Platform launched by the CompileWithTargetClassAccessExtension.
static {
try {
Path classpathRoot = Paths.get(AotSmokeTests.class.getProtectionDomain().getCodeSource().getLocation().toURI());
System.setProperty(CLASSPATH_ROOT, classpathRoot.toFile().getCanonicalPath());
}
catch (Exception ex) {
// ignore
}
}


@Test
// Using @CompileWithTargetClassAccess results in the following exception in classpathRoots():
// java.lang.NullPointerException: Cannot invoke "java.net.URL.toURI()" because the return
// value of "java.security.CodeSource.getLocation()" is null
void scanClassPathThenGenerateSourceFilesAndCompileThem() {
Stream<Class<?>> testClasses = scan("org.springframework.test.context.aot.samples.basic");
void endToEndTests() {
// AOT BUILD-TIME: CLASSPATH SCANNING
Stream<Class<?>> testClasses = createTestClassScanner()
.scan("org.springframework.test.context.aot.samples.basic")
// This test focuses solely on JUnit Jupiter tests
.filter(sourceFile -> sourceFile.getName().contains("Jupiter"));

// AOT BUILD-TIME: PROCESSING
InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles();
TestContextAotGenerator generator = new TestContextAotGenerator(generatedFiles);

generator.processAheadOfTime(testClasses);

List<String> sourceFiles = generatedFiles.getGeneratedFiles(Kind.SOURCE).keySet().stream().toList();
assertThat(sourceFiles).containsExactlyInAnyOrder(expectedSourceFilesForBasicSpringTests);
assertThat(sourceFiles).containsExactlyInAnyOrder(expectedSourceFilesForBasicSpringJupiterTests);

// AOT BUILD-TIME: COMPILATION
TestCompiler.forSystem().withFiles(generatedFiles)
// .printFiles(System.out)
.compile(compiled -> {
// just make sure compilation completes without errors
});
.compile(compiled ->
// AOT RUN-TIME: EXECUTION
runTestsInAotMode(BasicSpringJupiterTests.class, BasicSpringJupiterSharedConfigTests.class));
}


private static void runTestsInAotMode(Class<?>... testClasses) {
try {
System.setProperty(AotDetector.AOT_ENABLED, "true");

LauncherDiscoveryRequestBuilder builder = LauncherDiscoveryRequestBuilder.request()
.filters(includeEngines("junit-jupiter"));
Arrays.stream(testClasses).forEach(testClass -> builder.selectors(selectClass(testClass)));
LauncherDiscoveryRequest request = builder.build();
SummaryGeneratingListener listener = new SummaryGeneratingListener();
LauncherFactory.create().execute(request, listener);
TestExecutionSummary summary = listener.getSummary();
if (summary.getTotalFailureCount() > 0) {
List<Throwable> exceptions = summary.getFailures().stream().map(Failure::getException).toList();
throw new MultipleFailuresError("Test execution failures", exceptions);
}
}
finally {
System.clearProperty(AotDetector.AOT_ENABLED);
}
}

private static TestClassScanner createTestClassScanner() {
String classpathRoot = System.getProperty(CLASSPATH_ROOT);
assertThat(classpathRoot).as(CLASSPATH_ROOT).isNotNull();
Set<Path> classpathRoots = Set.of(Paths.get(classpathRoot));
return new TestClassScanner(classpathRoots);
}

private static final String[] expectedSourceFilesForBasicSpringJupiterTests = {
// Global
"org/springframework/test/context/aot/AotTestMappings__Generated.java",
// BasicSpringJupiterSharedConfigTests
"org/springframework/context/event/DefaultEventListenerFactory__TestContext001_BeanDefinitions.java",
"org/springframework/context/event/EventListenerMethodProcessor__TestContext001_BeanDefinitions.java",
"org/springframework/test/context/aot/samples/basic/BasicSpringJupiterSharedConfigTests__TestContext001_ApplicationContextInitializer.java",
"org/springframework/test/context/aot/samples/basic/BasicSpringJupiterSharedConfigTests__TestContext001_BeanFactoryRegistrations.java",
"org/springframework/test/context/aot/samples/basic/BasicTestConfiguration__TestContext001_BeanDefinitions.java",
// BasicSpringJupiterTests -- not generated b/c already generated for BasicSpringJupiterSharedConfigTests.
// "org/springframework/context/event/DefaultEventListenerFactory__TestContext00?_BeanDefinitions.java",
// "org/springframework/context/event/EventListenerMethodProcessor__TestContext00?_BeanDefinitions.java",
// "org/springframework/test/context/aot/samples/basic/BasicSpringJupiterTests__TestContext00?_ApplicationContextInitializer.java",
// "org/springframework/test/context/aot/samples/basic/BasicSpringJupiterTests__TestContext00?_BeanFactoryRegistrations.java",
// "org/springframework/test/context/aot/samples/basic/BasicTestConfiguration__TestContext00?_BeanDefinitions.java",
// BasicSpringJupiterTests.NestedTests
"org/springframework/context/event/DefaultEventListenerFactory__TestContext002_BeanDefinitions.java",
"org/springframework/context/event/EventListenerMethodProcessor__TestContext002_BeanDefinitions.java",
"org/springframework/test/context/aot/samples/basic/BasicSpringJupiterTests_NestedTests__TestContext002_ApplicationContextInitializer.java",
"org/springframework/test/context/aot/samples/basic/BasicSpringJupiterTests_NestedTests__TestContext002_BeanFactoryRegistrations.java",
"org/springframework/test/context/aot/samples/basic/BasicTestConfiguration__TestContext002_BeanDefinitions.java",
};

}
1 change: 1 addition & 0 deletions spring-test/src/test/resources/log4j2-test.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
<Logger name="org.springframework.test.context.ContextLoaderUtils" level="warn" />
<Logger name="org.springframework.test.context.aot" level="debug" />
<Logger name="org.springframework.test.context.cache" level="warn" />
<Logger name="org.springframework.test.context.cache.DefaultCacheAwareContextLoaderDelegate" level="info" />
<Logger name="org.springframework.test.context.junit4.rules" level="warn" />
<Logger name="org.springframework.test.context.transaction.TransactionalTestExecutionListener" level="warn" />
<Logger name="org.springframework.test.context.web" level="warn" />
Expand Down

0 comments on commit 8a6c1ba

Please sign in to comment.