Skip to content

Commit

Permalink
Introduce initial support for processing test contexts ahead-of-time
Browse files Browse the repository at this point in the history
This commit introduces TestContextAotGenerator for processing Spring
integration test classes and generating AOT artifacts. Specifically,
this class performs the following.

- bootstraps the TCF for a given test class
- builds the MergedContextConfiguration for each test class and tracks
  all test classes that share the same MergedContextConfiguration
- loads each test ApplicationContext without refreshing it
- passes the test ApplicationContext to ApplicationContextAotGenerator
  to generate the AOT optimized ApplicationContextInitializer
  - The GenerationContext passed to ApplicationContextAotGenerator uses
    a feature name of the form "TestContext###_", where "###" is a
    3-digit sequence ID left padded with zeros.

This commit also includes tests using the TestCompiler to verify that
each generated ApplicationContextInitializer can be used to populate a
GenericApplicationContext as expected.

See spring-projectsgh-28204
  • Loading branch information
sbrannen committed Aug 23, 2022
1 parent 0973348 commit c1a6bfc
Show file tree
Hide file tree
Showing 12 changed files with 618 additions and 27 deletions.
1 change: 1 addition & 0 deletions spring-test/spring-test.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies {
optional("io.projectreactor:reactor-test")
optional("org.jetbrains.kotlinx:kotlinx-coroutines-core")
optional("org.jetbrains.kotlinx:kotlinx-coroutines-reactor")
testImplementation(project(":spring-core-test"))
testImplementation(project(":spring-context-support"))
testImplementation(project(":spring-oxm"))
testImplementation(testFixtures(project(":spring-beans")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.lang.annotation.Annotation;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
Expand Down Expand Up @@ -156,7 +157,8 @@ Stream<Class<?>> scan(String... packageNames) {
.map(this::getJavaClass)
.flatMap(Optional::stream)
.filter(this::isSpringTestClass)
.distinct();
.distinct()
.sorted(Comparator.comparing(Class::getName));
}

private Optional<Class<?>> getJavaClass(ClassSource classSource) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.test.context.aot;

/**
* Thrown if an error occurs during AOT build-time processing or AOT run-time
* execution in the <em>Spring TestContext Framework</em>.
*
* @author Sam Brannen
* @since 6.0
*/
@SuppressWarnings("serial")
public class TestContextAotException extends RuntimeException {

/**
* Create a new {@code TestContextAotException}.
* @param message the detail message
*/
public TestContextAotException(String message) {
super(message);
}

/**
* Create a new {@code TestContextAotException}.
* @param message the detail message
* @param cause the root cause
*/
public TestContextAotException(String message, Throwable cause) {
super(message, cause);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.test.context.aot;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.aot.generate.ClassNameGenerator;
import org.springframework.aot.generate.DefaultGenerationContext;
import org.springframework.aot.generate.GeneratedFiles;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.aot.ApplicationContextAotGenerator;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.javapoet.ClassName;
import org.springframework.test.context.BootstrapUtils;
import org.springframework.test.context.ContextLoader;
import org.springframework.test.context.MergedContextConfiguration;
import org.springframework.test.context.SmartContextLoader;
import org.springframework.test.context.TestContextBootstrapper;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

/**
* {@code TestContextAotGenerator} generates AOT artifacts for integration tests
* that depend on support from the <em>Spring TestContext Framework</em>.
*
* @author Sam Brannen
* @since 6.0
* @see ApplicationContextAotGenerator
*/
class TestContextAotGenerator {

private static final Log logger = LogFactory.getLog(TestClassScanner.class);

private final ApplicationContextAotGenerator aotGenerator = new ApplicationContextAotGenerator();

private final AtomicInteger sequence = new AtomicInteger();

private final GeneratedFiles generatedFiles;

private final RuntimeHints runtimeHints;


/**
* Create a new {@link TestContextAotGenerator} that uses the supplied
* {@link GeneratedFiles}.
* @param generatedFiles the {@code GeneratedFiles} to use
*/
public TestContextAotGenerator(GeneratedFiles generatedFiles) {
this(generatedFiles, new RuntimeHints());
}

/**
* Create a new {@link TestContextAotGenerator} that uses the supplied
* {@link GeneratedFiles} and {@link RuntimeHints}.
* @param generatedFiles the {@code GeneratedFiles} to use
* @param runtimeHints the {@code RuntimeHints} to use
*/
public TestContextAotGenerator(GeneratedFiles generatedFiles, RuntimeHints runtimeHints) {
this.generatedFiles = generatedFiles;
this.runtimeHints = runtimeHints;
}


/**
* Get the {@link RuntimeHints} gathered during {@linkplain #processAheadOfTime(Stream)
* AOT processing}.
*/
public final RuntimeHints getRuntimeHints() {
return this.runtimeHints;
}

/**
* Process each of the supplied Spring integration test classes and generate
* AOT artifacts.
* @throws TestContextAotException if an error occurs during AOT processing
*/
public void processAheadOfTime(Stream<Class<?>> testClasses) throws TestContextAotException {
MultiValueMap<MergedContextConfiguration, Class<?>> map = new LinkedMultiValueMap<>();
testClasses.forEach(testClass -> map.add(buildMergedContextConfiguration(testClass), testClass));

map.forEach((mergedConfig, classes) -> {
// System.err.println(mergedConfig + " -> " + classes);
if (logger.isDebugEnabled()) {
logger.debug("Generating AOT artifacts for test classes [%s]"
.formatted(classes.stream().map(Class::getCanonicalName).toList()));
}
try {
// Use first test class discovered for a given unique MergedContextConfiguration.
Class<?> testClass = classes.get(0);
DefaultGenerationContext generationContext = createGenerationContext(testClass);
ClassName className = processAheadOfTime(mergedConfig, generationContext);
// TODO Store ClassName in a map analogous to TestContextAotProcessor in Spring Native.
generationContext.writeGeneratedContent();
}
catch (Exception ex) {
if (logger.isWarnEnabled()) {
logger.warn("Failed to generate AOT artifacts for test classes [%s]"
.formatted(classes.stream().map(Class::getCanonicalName).toList()), ex);
}
}
});
}

/**
* Process the specified {@link MergedContextConfiguration} ahead-of-time
* using the specified {@link GenerationContext}.
* <p>Return the {@link ClassName} of the {@link ApplicationContextInitializer}
* to use to restore an optimized state of the test application context for
* the given {@code MergedContextConfiguration}.
* @param mergedConfig the {@code MergedContextConfiguration} to process
* @param generationContext the generation context to use
* @return the {@link ClassName} for the generated {@code ApplicationContextInitializer}
* @throws TestContextAotException if an error occurs during AOT processing
*/
ClassName processAheadOfTime(MergedContextConfiguration mergedConfig,
GenerationContext generationContext) throws TestContextAotException {

GenericApplicationContext gac = loadContextForAotProcessing(mergedConfig);
try {
return this.aotGenerator.processAheadOfTime(gac, generationContext);
}
catch (Throwable ex) {
throw new TestContextAotException("Failed to process test class [%s] for AOT"
.formatted(mergedConfig.getTestClass().getCanonicalName()), ex);
}
}

/**
* Load the {@code GenericApplicationContext} for the supplied merged context
* configuration for AOT processing.
* <p>Only supports {@link SmartContextLoader SmartContextLoaders} that
* create {@link GenericApplicationContext GenericApplicationContexts}.
* @throws TestContextAotException if an error occurs while loading the application
* context or if one of the prerequisites is not met
* @see SmartContextLoader#loadContextForAotProcessing(MergedContextConfiguration)
*/
private GenericApplicationContext loadContextForAotProcessing(
MergedContextConfiguration mergedConfig) throws TestContextAotException {

Class<?> testClass = mergedConfig.getTestClass();
ContextLoader contextLoader = mergedConfig.getContextLoader();
Assert.notNull(contextLoader, """
Cannot load an ApplicationContext with a NULL 'contextLoader'. \
Consider annotating test class [%s] with @ContextConfiguration or \
@ContextHierarchy.""".formatted(testClass.getCanonicalName()));

if (contextLoader instanceof SmartContextLoader smartContextLoader) {
try {
ApplicationContext context = smartContextLoader.loadContextForAotProcessing(mergedConfig);
if (context instanceof GenericApplicationContext gac) {
return gac;
}
}
catch (Exception ex) {
throw new TestContextAotException(
"Failed to load ApplicationContext for AOT processing for test class [%s]"
.formatted(testClass.getCanonicalName()), ex);
}
}
throw new TestContextAotException("""
Cannot generate AOT artifacts for test class [%s]. The configured \
ContextLoader [%s] must be a SmartContextLoader and must create a \
GenericApplicationContext.""".formatted(testClass.getCanonicalName(),
contextLoader.getClass().getName()));
}

MergedContextConfiguration buildMergedContextConfiguration(Class<?> testClass) {
TestContextBootstrapper testContextBootstrapper =
BootstrapUtils.resolveTestContextBootstrapper(testClass);
return testContextBootstrapper.buildMergedContextConfiguration();
}

DefaultGenerationContext createGenerationContext(Class<?> testClass) {
ClassNameGenerator classNameGenerator = new ClassNameGenerator(testClass);
DefaultGenerationContext generationContext =
new DefaultGenerationContext(classNameGenerator, this.generatedFiles, this.runtimeHints);
return generationContext.withName(nextTestContextId());
}

private String nextTestContextId() {
return "TestContext%03d_".formatted(this.sequence.incrementAndGet());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.test.context.aot;

import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Set;
import java.util.stream.Stream;

/**
* @author Sam Brannen
* @since 6.0
*/
abstract class AbstractAotTests {

static final String[] expectedSourceFilesForBasicSpringTests = {
// BasicSpringJupiterSharedConfigTests
"org/springframework/context/event/DefaultEventListenerFactory__TestContext001_BeanDefinitions.java",
"org/springframework/context/event/EventListenerMethodProcessor__TestContext001_BeanDefinitions.java",
"org/springframework/test/context/aot/samples/basic/BasicSpringJupiterSharedConfigTests__TestContext001_ApplicationContextInitializer.java",
"org/springframework/test/context/aot/samples/basic/BasicSpringJupiterSharedConfigTests__TestContext001_BeanFactoryRegistrations.java",
"org/springframework/test/context/aot/samples/basic/BasicTestConfiguration__TestContext001_BeanDefinitions.java",
// BasicSpringJupiterTests -- not generated b/c already generated for BasicSpringJupiterSharedConfigTests.
// "org/springframework/context/event/DefaultEventListenerFactory__TestContext00?_BeanDefinitions.java",
// "org/springframework/context/event/EventListenerMethodProcessor__TestContext00?_BeanDefinitions.java",
// "org/springframework/test/context/aot/samples/basic/BasicSpringJupiterTests__TestContext00?_ApplicationContextInitializer.java",
// "org/springframework/test/context/aot/samples/basic/BasicSpringJupiterTests__TestContext00?_BeanFactoryRegistrations.java",
// "org/springframework/test/context/aot/samples/basic/BasicTestConfiguration__TestContext00?_BeanDefinitions.java",
// BasicSpringJupiterTests.NestedTests
"org/springframework/context/event/DefaultEventListenerFactory__TestContext002_BeanDefinitions.java",
"org/springframework/context/event/EventListenerMethodProcessor__TestContext002_BeanDefinitions.java",
"org/springframework/test/context/aot/samples/basic/BasicSpringJupiterTests_NestedTests__TestContext002_ApplicationContextInitializer.java",
"org/springframework/test/context/aot/samples/basic/BasicSpringJupiterTests_NestedTests__TestContext002_BeanFactoryRegistrations.java",
"org/springframework/test/context/aot/samples/basic/BasicTestConfiguration__TestContext002_BeanDefinitions.java",
// BasicSpringTestNGTests
"org/springframework/context/event/DefaultEventListenerFactory__TestContext003_BeanDefinitions.java",
"org/springframework/context/event/EventListenerMethodProcessor__TestContext003_BeanDefinitions.java",
"org/springframework/test/context/aot/samples/basic/BasicSpringTestNGTests__TestContext003_ApplicationContextInitializer.java",
"org/springframework/test/context/aot/samples/basic/BasicSpringTestNGTests__TestContext003_BeanFactoryRegistrations.java",
"org/springframework/test/context/aot/samples/basic/BasicTestConfiguration__TestContext003_BeanDefinitions.java",
// BasicSpringVintageTests
"org/springframework/context/event/DefaultEventListenerFactory__TestContext004_BeanDefinitions.java",
"org/springframework/context/event/EventListenerMethodProcessor__TestContext004_BeanDefinitions.java",
"org/springframework/test/context/aot/samples/basic/BasicSpringVintageTests__TestContext004_ApplicationContextInitializer.java",
"org/springframework/test/context/aot/samples/basic/BasicSpringVintageTests__TestContext004_BeanFactoryRegistrations.java",
"org/springframework/test/context/aot/samples/basic/BasicTestConfiguration__TestContext004_BeanDefinitions.java"
};

Stream<Class<?>> scan() {
return new TestClassScanner(classpathRoots()).scan();
}

Stream<Class<?>> scan(String... packageNames) {
return new TestClassScanner(classpathRoots()).scan(packageNames);
}

Set<Path> classpathRoots() {
try {
return Set.of(Paths.get(getClass().getProtectionDomain().getCodeSource().getLocation().toURI()));
}
catch (Exception ex) {
throw new RuntimeException(ex);
}
}

}
Loading

0 comments on commit c1a6bfc

Please sign in to comment.