Skip to content

Commit

Permalink
Generate direct mappings from AOT initializers to test classes
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrannen committed Oct 10, 2022
1 parent 3e33912 commit bca35dc
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
*
* <p>If we are not running in {@linkplain AotDetector#useGeneratedArtifacts()
* AOT mode} or if a test class is not {@linkplain #isSupportedTestClass(Class)
* supported} in AOT mode, {@link #getContextInitializer(Class)} will return
* {@code null}.
* supported} in AOT mode, {@link #getContextInitializer(Class)} and
* {@link #getContextInitializerClass(Class)} will return {@code null}.
*
* @author Sam Brannen
* @since 6.0
Expand All @@ -42,13 +42,20 @@ public class AotTestContextInitializers {

private final Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>> contextInitializers;

private final Map<String, Class<ApplicationContextInitializer<?>>> contextInitializerClasses;


public AotTestContextInitializers() {
this(AotTestContextInitializersFactory.getContextInitializers());
this(AotTestContextInitializersFactory.getContextInitializers(),
AotTestContextInitializersFactory.getContextInitializerClasses());
}

AotTestContextInitializers(Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>> contextInitializers) {
AotTestContextInitializers(
Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>> contextInitializers,
Map<String, Class<ApplicationContextInitializer<?>>> contextInitializerClasses) {

this.contextInitializers = contextInitializers;
this.contextInitializerClasses = contextInitializerClasses;
}


Expand All @@ -67,6 +74,7 @@ public boolean isSupportedTestClass(Class<?> testClass) {
* @return the AOT context initializer, or {@code null} if there is no AOT context
* initializer for the specified test class
* @see #isSupportedTestClass(Class)
* @see #getContextInitializerClass(Class)
*/
@Nullable
public ApplicationContextInitializer<ConfigurableApplicationContext> getContextInitializer(Class<?> testClass) {
Expand All @@ -75,4 +83,17 @@ public ApplicationContextInitializer<ConfigurableApplicationContext> getContextI
return (supplier != null ? supplier.get() : null);
}

/**
* Get the AOT {@link ApplicationContextInitializer} {@link Class} for the
* specified test class.
* @return the AOT context initializer class, or {@code null} if there is no
* AOT context initializer for the specified test class
* @see #isSupportedTestClass(Class)
* @see #getContextInitializer(Class)
*/
@Nullable
public Class<ApplicationContextInitializer<?>> getContextInitializerClass(Class<?> testClass) {
return this.contextInitializerClasses.get(testClass.getName());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -50,24 +50,38 @@ class AotTestContextInitializersCodeGenerator {

private static final Log logger = LogFactory.getLog(AotTestContextInitializersCodeGenerator.class);

// ApplicationContextInitializer<? extends ConfigurableApplicationContext>
private static final ParameterizedTypeName CONTEXT_INITIALIZER = ParameterizedTypeName.get(
ClassName.get(ApplicationContextInitializer.class),
WildcardTypeName.subtypeOf(ConfigurableApplicationContext.class));

// Supplier<ApplicationContextInitializer<? extends ConfigurableApplicationContext>>
private static final ParameterizedTypeName CONTEXT_INITIALIZER_SUPPLIER = ParameterizedTypeName
.get(ClassName.get(Supplier.class), CONTEXT_INITIALIZER);

// Map<String, Supplier<ApplicationContextInitializer<? extends ConfigurableApplicationContext>>>
private static final TypeName CONTEXT_SUPPLIER_MAP = ParameterizedTypeName
private static final TypeName CONTEXT_INITIALIZER_SUPPLIER_MAP = ParameterizedTypeName
.get(ClassName.get(Map.class), ClassName.get(String.class), CONTEXT_INITIALIZER_SUPPLIER);

// Class<ApplicationContextInitializer<?>
private static final ParameterizedTypeName CONTEXT_INITIALIZER_CLASS = ParameterizedTypeName
.get(ClassName.get(Class.class), WildcardTypeName.subtypeOf(
ParameterizedTypeName.get(ClassName.get(ApplicationContextInitializer.class),
WildcardTypeName.subtypeOf(Object.class))));

// Map<String, Class<ApplicationContextInitializer<?>>>
private static final TypeName CONTEXT_INITIALIZER_CLASS_MAP = ParameterizedTypeName
.get(ClassName.get(Map.class), ClassName.get(String.class), CONTEXT_INITIALIZER_CLASS);

private static final String GENERATED_SUFFIX = "Generated";

// TODO Consider an alternative means for specifying the name of the generated class.
// Ideally we would generate a class named: org.springframework.test.context.aot.GeneratedAotTestContextInitializers
static final String GENERATED_MAPPINGS_CLASS_NAME = AotTestContextInitializers.class.getName() + "__" + GENERATED_SUFFIX;

static final String GENERATED_MAPPINGS_METHOD_NAME = "getContextInitializers";
static final String GET_CONTEXT_INITIALIZERS_METHOD_NAME = "getContextInitializers";

static final String GET_CONTEXT_INITIALIZER_CLASSES_METHOD_NAME = "getContextInitializerClasses";


private final MultiValueMap<ClassName, Class<?>> initializerClassMappings;
Expand All @@ -92,24 +106,25 @@ private void generateType(TypeSpec.Builder type) {
this.generatedClass.getName().reflectionName()));
type.addJavadoc("Generated mappings for {@link $T}.", AotTestContextInitializers.class);
type.addModifiers(Modifier.PUBLIC);
type.addMethod(generateMappingMethod());
type.addMethod(contextInitializersMappingMethod());
type.addMethod(contextInitializerClassesMappingMethod());
}

private MethodSpec generateMappingMethod() {
MethodSpec.Builder method = MethodSpec.methodBuilder(GENERATED_MAPPINGS_METHOD_NAME);
private MethodSpec contextInitializersMappingMethod() {
MethodSpec.Builder method = MethodSpec.methodBuilder(GET_CONTEXT_INITIALIZERS_METHOD_NAME);
method.addModifiers(Modifier.PUBLIC, Modifier.STATIC);
method.returns(CONTEXT_SUPPLIER_MAP);
method.addCode(generateMappingCode());
method.returns(CONTEXT_INITIALIZER_SUPPLIER_MAP);
method.addCode(generateContextInitializersMappingCode());
return method.build();
}

private CodeBlock generateMappingCode() {
private CodeBlock generateContextInitializersMappingCode() {
CodeBlock.Builder code = CodeBlock.builder();
code.addStatement("$T map = new $T<>()", CONTEXT_SUPPLIER_MAP, HashMap.class);
code.addStatement("$T map = new $T<>()", CONTEXT_INITIALIZER_SUPPLIER_MAP, HashMap.class);
this.initializerClassMappings.forEach((className, testClasses) -> {
List<String> testClassNames = testClasses.stream().map(Class::getName).toList();
logger.debug(LogMessage.format(
"Generating mapping from AOT context initializer [%s] to test classes %s",
"Generating mapping from AOT context initializer supplier [%s] to test classes %s",
className.reflectionName(), testClassNames));
testClassNames.forEach(testClassName ->
code.addStatement("map.put($S, () -> new $T())", testClassName, className));
Expand All @@ -118,4 +133,27 @@ private CodeBlock generateMappingCode() {
return code.build();
}

private MethodSpec contextInitializerClassesMappingMethod() {
MethodSpec.Builder method = MethodSpec.methodBuilder(GET_CONTEXT_INITIALIZER_CLASSES_METHOD_NAME);
method.addModifiers(Modifier.PUBLIC, Modifier.STATIC);
method.returns(CONTEXT_INITIALIZER_CLASS_MAP);
method.addCode(generateContextInitializerClassesMappingCode());
return method.build();
}

private CodeBlock generateContextInitializerClassesMappingCode() {
CodeBlock.Builder code = CodeBlock.builder();
code.addStatement("$T map = new $T<>()", CONTEXT_INITIALIZER_CLASS_MAP, HashMap.class);
this.initializerClassMappings.forEach((className, testClasses) -> {
List<String> testClassNames = testClasses.stream().map(Class::getName).toList();
logger.debug(LogMessage.format(
"Generating mapping from AOT context initializer class [%s] to test classes %s",
className.reflectionName(), testClassNames));
testClassNames.forEach(testClassName ->
code.addStatement("map.put($S, $T.class)", testClassName, className));
});
code.addStatement("return map");
return code.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ final class AotTestContextInitializersFactory {
@Nullable
private static volatile Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>> contextInitializers;

@Nullable
private static volatile Map<String, Class<ApplicationContextInitializer<?>>> contextInitializerClasses;


private AotTestContextInitializersFactory() {
}
Expand All @@ -59,20 +62,42 @@ static Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicatio
return initializers;
}

static Map<String, Class<ApplicationContextInitializer<?>>> getContextInitializerClasses() {
Map<String, Class<ApplicationContextInitializer<?>>> initializerClasses = contextInitializerClasses;
if (initializerClasses == null) {
synchronized (AotTestContextInitializersFactory.class) {
initializerClasses = contextInitializerClasses;
if (initializerClasses == null) {
initializerClasses = (AotDetector.useGeneratedArtifacts() ? loadContextInitializerClassesMap() : Map.of());
contextInitializerClasses = initializerClasses;
}
}
}
return initializerClasses;
}

/**
* Reset the factory.
* <p>Only for internal use.
*/
static void reset() {
synchronized (AotTestContextInitializersFactory.class) {
contextInitializers = null;
contextInitializerClasses = null;
}
}

@SuppressWarnings("unchecked")
private static Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>> loadContextInitializersMap() {
String className = AotTestContextInitializersCodeGenerator.GENERATED_MAPPINGS_CLASS_NAME;
String methodName = AotTestContextInitializersCodeGenerator.GENERATED_MAPPINGS_METHOD_NAME;
String methodName = AotTestContextInitializersCodeGenerator.GET_CONTEXT_INITIALIZERS_METHOD_NAME;
return GeneratedMapUtils.loadMap(className, methodName);
}

@SuppressWarnings("unchecked")
private static Map<String, Class<ApplicationContextInitializer<?>>> loadContextInitializerClassesMap() {
String className = AotTestContextInitializersCodeGenerator.GENERATED_MAPPINGS_CLASS_NAME;
String methodName = AotTestContextInitializersCodeGenerator.GET_CONTEXT_INITIALIZER_CLASSES_METHOD_NAME;
return GeneratedMapUtils.loadMap(className, methodName);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,7 @@ private MergedContextConfiguration replaceIfNecessary(MergedContextConfiguration
Class<?> testClass = mergedConfig.getTestClass();
if (this.aotTestContextInitializers.isSupportedTestClass(testClass)) {
Class<? extends ApplicationContextInitializer<?>> contextInitializerClass =
(Class<? extends ApplicationContextInitializer<?>>)
this.aotTestContextInitializers.getContextInitializer(testClass).getClass();
this.aotTestContextInitializers.getContextInitializerClass(testClass);
return new AotMergedContextConfiguration(testClass, contextInitializerClass, mergedConfig, this);
}
return mergedConfig;
Expand Down

0 comments on commit bca35dc

Please sign in to comment.