From 8a6c1ba1987e4fe1bee0ee9bcc8ef4b3b2764313 Mon Sep 17 00:00:00 2001 From: Sam Brannen Date: Sat, 20 Aug 2022 15:01:53 +0200 Subject: [PATCH] Introduce AOT run-time support in the TestContext framework This commit introduces initial AOT run-time support in the Spring TestContext Framework. - DefaultCacheAwareContextLoaderDelegate: when running in AOT mode, now loads a test's ApplicationContext via the AotContextLoader SPI instead of via the standard SmartContextLoader and ContextLoader SPIs. - DependencyInjectionTestExecutionListener: when running in AOT mode, now injects dependencies into a test instance using a local instance of AutowiredAnnotationBeanPostProcessor instead of relying on AutowireCapableBeanFactory support. Closes gh-28205 --- ...efaultCacheAwareContextLoaderDelegate.java | 58 ++++++++- ...endencyInjectionTestExecutionListener.java | 62 +++++++++- .../test/context/aot/AotSmokeTests.java | 115 ++++++++++++++++-- .../src/test/resources/log4j2-test.xml | 1 + 4 files changed, 221 insertions(+), 15 deletions(-) diff --git a/spring-test/src/main/java/org/springframework/test/context/cache/DefaultCacheAwareContextLoaderDelegate.java b/spring-test/src/main/java/org/springframework/test/context/cache/DefaultCacheAwareContextLoaderDelegate.java index 84b53186c2d4..2bf58b244ccb 100644 --- a/spring-test/src/main/java/org/springframework/test/context/cache/DefaultCacheAwareContextLoaderDelegate.java +++ b/spring-test/src/main/java/org/springframework/test/context/cache/DefaultCacheAwareContextLoaderDelegate.java @@ -19,13 +19,21 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.aot.AotDetector; import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.core.log.LogMessage; import org.springframework.lang.Nullable; import org.springframework.test.annotation.DirtiesContext.HierarchyMode; import org.springframework.test.context.CacheAwareContextLoaderDelegate; import org.springframework.test.context.ContextLoader; import org.springframework.test.context.MergedContextConfiguration; import org.springframework.test.context.SmartContextLoader; +import org.springframework.test.context.aot.AotContextLoader; +import org.springframework.test.context.aot.AotTestMappings; +import org.springframework.test.context.aot.TestContextAotException; import org.springframework.util.Assert; /** @@ -48,6 +56,8 @@ public class DefaultCacheAwareContextLoaderDelegate implements CacheAwareContext */ static final ContextCache defaultContextCache = new DefaultContextCache(); + private final AotTestMappings aotTestMappings = getAotTestMappings(); + private final ContextCache contextCache; @@ -87,7 +97,12 @@ public ApplicationContext loadContext(MergedContextConfiguration mergedContextCo ApplicationContext context = this.contextCache.get(mergedContextConfiguration); if (context == null) { try { - context = loadContextInternal(mergedContextConfiguration); + if (runningInAotMode(mergedContextConfiguration.getTestClass())) { + context = loadContextInAotMode(mergedContextConfiguration); + } + else { + context = loadContextInternal(mergedContextConfiguration); + } if (logger.isDebugEnabled()) { logger.debug(String.format("Storing ApplicationContext [%s] in cache under key [%s]", System.identityHashCode(context), mergedContextConfiguration)); @@ -149,4 +164,45 @@ protected ApplicationContext loadContextInternal(MergedContextConfiguration merg } } + protected ApplicationContext loadContextInAotMode(MergedContextConfiguration mergedConfig) throws Exception { + Class testClass = mergedConfig.getTestClass(); + ApplicationContextInitializer contextInitializer = + this.aotTestMappings.getContextInitializer(testClass); + Assert.state(contextInitializer != null, + () -> "Failed to load AOT ApplicationContextInitializer for test class [%s]" + .formatted(testClass.getName())); + logger.info(LogMessage.format("Loading ApplicationContext in AOT mode for %s", mergedConfig)); + ContextLoader contextLoader = mergedConfig.getContextLoader(); + if (!((contextLoader instanceof AotContextLoader aotContextLoader) && + (aotContextLoader.loadContextForAotRuntime(mergedConfig, contextInitializer) + instanceof GenericApplicationContext gac))) { + throw new TestContextAotException(""" + Cannot load ApplicationContext for AOT runtime for %s. The configured \ + ContextLoader [%s] must be an AotContextLoader and must create a \ + GenericApplicationContext.""" + .formatted(mergedConfig, contextLoader.getClass().getName())); + } + gac.registerShutdownHook(); + return gac; + } + + /** + * Determine if we are running in AOT mode for the supplied test class. + */ + private boolean runningInAotMode(Class testClass) { + return (this.aotTestMappings != null && this.aotTestMappings.isSupportedTestClass(testClass)); + } + + private static AotTestMappings getAotTestMappings() { + if (AotDetector.useGeneratedArtifacts()) { + try { + return new AotTestMappings(); + } + catch (Exception ex) { + throw new IllegalStateException("Failed to instantiate AotTestMappings", ex); + } + } + return null; + } + } diff --git a/spring-test/src/main/java/org/springframework/test/context/support/DependencyInjectionTestExecutionListener.java b/spring-test/src/main/java/org/springframework/test/context/support/DependencyInjectionTestExecutionListener.java index 43d6e7afa3b2..e440c4c399bc 100644 --- a/spring-test/src/main/java/org/springframework/test/context/support/DependencyInjectionTestExecutionListener.java +++ b/spring-test/src/main/java/org/springframework/test/context/support/DependencyInjectionTestExecutionListener.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,9 +19,15 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.aot.AotDetector; +import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor; import org.springframework.beans.factory.config.AutowireCapableBeanFactory; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.context.ApplicationContext; +import org.springframework.context.support.GenericApplicationContext; import org.springframework.core.Conventions; import org.springframework.test.context.TestContext; +import org.springframework.test.context.aot.AotTestMappings; /** * {@code TestExecutionListener} which provides support for dependency @@ -53,6 +59,8 @@ public class DependencyInjectionTestExecutionListener extends AbstractTestExecut private static final Log logger = LogFactory.getLog(DependencyInjectionTestExecutionListener.class); + private final AotTestMappings aotTestMappings = getAotTestMappings(); + /** * Returns {@code 2000}. @@ -78,9 +86,14 @@ public final int getOrder() { @Override public void prepareTestInstance(TestContext testContext) throws Exception { if (logger.isDebugEnabled()) { - logger.debug("Performing dependency injection for test context [" + testContext + "]."); + logger.debug("Performing dependency injection for test context " + testContext); + } + if (runningInAotMode(testContext.getTestClass())) { + injectDependenciesInAotMode(testContext); + } + else { + injectDependencies(testContext); } - injectDependencies(testContext); } /** @@ -96,7 +109,12 @@ public void beforeTestMethod(TestContext testContext) throws Exception { if (logger.isDebugEnabled()) { logger.debug("Reinjecting dependencies for test context [" + testContext + "]."); } - injectDependencies(testContext); + if (runningInAotMode(testContext.getTestClass())) { + injectDependenciesInAotMode(testContext); + } + else { + injectDependencies(testContext); + } } } @@ -121,4 +139,40 @@ protected void injectDependencies(TestContext testContext) throws Exception { testContext.removeAttribute(REINJECT_DEPENDENCIES_ATTRIBUTE); } + private void injectDependenciesInAotMode(TestContext testContext) throws Exception { + ApplicationContext applicationContext = testContext.getApplicationContext(); + if (!(applicationContext instanceof GenericApplicationContext gac)) { + throw new IllegalStateException("AOT ApplicationContext must be a GenericApplicationContext instead of " + + applicationContext.getClass().getName()); + } + + Object bean = testContext.getTestInstance(); + Class clazz = testContext.getTestClass(); + ConfigurableListableBeanFactory beanFactory = gac.getBeanFactory(); + AutowiredAnnotationBeanPostProcessor beanPostProcessor = new AutowiredAnnotationBeanPostProcessor(); + beanPostProcessor.setBeanFactory(beanFactory); + beanPostProcessor.processInjection(bean); + beanFactory.initializeBean(bean, clazz.getName() + AutowireCapableBeanFactory.ORIGINAL_INSTANCE_SUFFIX); + testContext.removeAttribute(REINJECT_DEPENDENCIES_ATTRIBUTE); + } + + /** + * Determine if we are running in AOT mode for the supplied test class. + */ + private boolean runningInAotMode(Class testClass) { + return (this.aotTestMappings != null && this.aotTestMappings.isSupportedTestClass(testClass)); + } + + private static AotTestMappings getAotTestMappings() { + if (AotDetector.useGeneratedArtifacts()) { + try { + return new AotTestMappings(); + } + catch (Exception ex) { + throw new IllegalStateException("Failed to instantiate AotTestMappings", ex); + } + } + return null; + } + } diff --git a/spring-test/src/test/java/org/springframework/test/context/aot/AotSmokeTests.java b/spring-test/src/test/java/org/springframework/test/context/aot/AotSmokeTests.java index f3e762afc4d6..dfddfbbe8980 100644 --- a/spring-test/src/test/java/org/springframework/test/context/aot/AotSmokeTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/aot/AotSmokeTests.java @@ -16,16 +16,34 @@ package org.springframework.test.context.aot; + +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; import java.util.List; +import java.util.Set; import java.util.stream.Stream; import org.junit.jupiter.api.Test; +import org.junit.platform.launcher.LauncherDiscoveryRequest; +import org.junit.platform.launcher.core.LauncherDiscoveryRequestBuilder; +import org.junit.platform.launcher.core.LauncherFactory; +import org.junit.platform.launcher.listeners.SummaryGeneratingListener; +import org.junit.platform.launcher.listeners.TestExecutionSummary; +import org.junit.platform.launcher.listeners.TestExecutionSummary.Failure; +import org.opentest4j.MultipleFailuresError; +import org.springframework.aot.AotDetector; import org.springframework.aot.generate.GeneratedFiles.Kind; import org.springframework.aot.generate.InMemoryGeneratedFiles; +import org.springframework.aot.test.generator.compile.CompileWithTargetClassAccess; import org.springframework.aot.test.generator.compile.TestCompiler; +import org.springframework.test.context.aot.samples.basic.BasicSpringJupiterSharedConfigTests; +import org.springframework.test.context.aot.samples.basic.BasicSpringJupiterTests; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.platform.engine.discovery.DiscoverySelectors.selectClass; +import static org.junit.platform.launcher.EngineFilter.includeEngines; /** * Smoke tests for AOT support in the TestContext framework. @@ -33,27 +51,104 @@ * @author Sam Brannen * @since 6.0 */ +@CompileWithTargetClassAccess class AotSmokeTests extends AbstractAotTests { + private static final String CLASSPATH_ROOT = "AotSmokeTests.classpath_root"; + + // We have to determine the classpath root and store it in a system property + // since @CompileWithTargetClassAccess uses a custom ClassLoader that does + // not support CodeSource. + // + // The system property will only be set when this class is loaded by the + // original ClassLoader used to launch the JUnit Platform. The attempt to + // access the CodeSource will fail when the tests are executed in the + // nested JUnit Platform launched by the CompileWithTargetClassAccessExtension. + static { + try { + Path classpathRoot = Paths.get(AotSmokeTests.class.getProtectionDomain().getCodeSource().getLocation().toURI()); + System.setProperty(CLASSPATH_ROOT, classpathRoot.toFile().getCanonicalPath()); + } + catch (Exception ex) { + // ignore + } + } + + @Test - // Using @CompileWithTargetClassAccess results in the following exception in classpathRoots(): - // java.lang.NullPointerException: Cannot invoke "java.net.URL.toURI()" because the return - // value of "java.security.CodeSource.getLocation()" is null - void scanClassPathThenGenerateSourceFilesAndCompileThem() { - Stream> testClasses = scan("org.springframework.test.context.aot.samples.basic"); + void endToEndTests() { + // AOT BUILD-TIME: CLASSPATH SCANNING + Stream> testClasses = createTestClassScanner() + .scan("org.springframework.test.context.aot.samples.basic") + // This test focuses solely on JUnit Jupiter tests + .filter(sourceFile -> sourceFile.getName().contains("Jupiter")); + + // AOT BUILD-TIME: PROCESSING InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); TestContextAotGenerator generator = new TestContextAotGenerator(generatedFiles); - generator.processAheadOfTime(testClasses); List sourceFiles = generatedFiles.getGeneratedFiles(Kind.SOURCE).keySet().stream().toList(); - assertThat(sourceFiles).containsExactlyInAnyOrder(expectedSourceFilesForBasicSpringTests); + assertThat(sourceFiles).containsExactlyInAnyOrder(expectedSourceFilesForBasicSpringJupiterTests); + // AOT BUILD-TIME: COMPILATION TestCompiler.forSystem().withFiles(generatedFiles) // .printFiles(System.out) - .compile(compiled -> { - // just make sure compilation completes without errors - }); + .compile(compiled -> + // AOT RUN-TIME: EXECUTION + runTestsInAotMode(BasicSpringJupiterTests.class, BasicSpringJupiterSharedConfigTests.class)); } + + private static void runTestsInAotMode(Class... testClasses) { + try { + System.setProperty(AotDetector.AOT_ENABLED, "true"); + + LauncherDiscoveryRequestBuilder builder = LauncherDiscoveryRequestBuilder.request() + .filters(includeEngines("junit-jupiter")); + Arrays.stream(testClasses).forEach(testClass -> builder.selectors(selectClass(testClass))); + LauncherDiscoveryRequest request = builder.build(); + SummaryGeneratingListener listener = new SummaryGeneratingListener(); + LauncherFactory.create().execute(request, listener); + TestExecutionSummary summary = listener.getSummary(); + if (summary.getTotalFailureCount() > 0) { + List exceptions = summary.getFailures().stream().map(Failure::getException).toList(); + throw new MultipleFailuresError("Test execution failures", exceptions); + } + } + finally { + System.clearProperty(AotDetector.AOT_ENABLED); + } + } + + private static TestClassScanner createTestClassScanner() { + String classpathRoot = System.getProperty(CLASSPATH_ROOT); + assertThat(classpathRoot).as(CLASSPATH_ROOT).isNotNull(); + Set classpathRoots = Set.of(Paths.get(classpathRoot)); + return new TestClassScanner(classpathRoots); + } + + private static final String[] expectedSourceFilesForBasicSpringJupiterTests = { + // Global + "org/springframework/test/context/aot/AotTestMappings__Generated.java", + // BasicSpringJupiterSharedConfigTests + "org/springframework/context/event/DefaultEventListenerFactory__TestContext001_BeanDefinitions.java", + "org/springframework/context/event/EventListenerMethodProcessor__TestContext001_BeanDefinitions.java", + "org/springframework/test/context/aot/samples/basic/BasicSpringJupiterSharedConfigTests__TestContext001_ApplicationContextInitializer.java", + "org/springframework/test/context/aot/samples/basic/BasicSpringJupiterSharedConfigTests__TestContext001_BeanFactoryRegistrations.java", + "org/springframework/test/context/aot/samples/basic/BasicTestConfiguration__TestContext001_BeanDefinitions.java", + // BasicSpringJupiterTests -- not generated b/c already generated for BasicSpringJupiterSharedConfigTests. + // "org/springframework/context/event/DefaultEventListenerFactory__TestContext00?_BeanDefinitions.java", + // "org/springframework/context/event/EventListenerMethodProcessor__TestContext00?_BeanDefinitions.java", + // "org/springframework/test/context/aot/samples/basic/BasicSpringJupiterTests__TestContext00?_ApplicationContextInitializer.java", + // "org/springframework/test/context/aot/samples/basic/BasicSpringJupiterTests__TestContext00?_BeanFactoryRegistrations.java", + // "org/springframework/test/context/aot/samples/basic/BasicTestConfiguration__TestContext00?_BeanDefinitions.java", + // BasicSpringJupiterTests.NestedTests + "org/springframework/context/event/DefaultEventListenerFactory__TestContext002_BeanDefinitions.java", + "org/springframework/context/event/EventListenerMethodProcessor__TestContext002_BeanDefinitions.java", + "org/springframework/test/context/aot/samples/basic/BasicSpringJupiterTests_NestedTests__TestContext002_ApplicationContextInitializer.java", + "org/springframework/test/context/aot/samples/basic/BasicSpringJupiterTests_NestedTests__TestContext002_BeanFactoryRegistrations.java", + "org/springframework/test/context/aot/samples/basic/BasicTestConfiguration__TestContext002_BeanDefinitions.java", + }; + } diff --git a/spring-test/src/test/resources/log4j2-test.xml b/spring-test/src/test/resources/log4j2-test.xml index d7f8b2b1776f..5a5ceae46949 100644 --- a/spring-test/src/test/resources/log4j2-test.xml +++ b/spring-test/src/test/resources/log4j2-test.xml @@ -16,6 +16,7 @@ +