diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/logging/logback/LogbackLoggingSystem.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/logging/logback/LogbackLoggingSystem.java index 44da84cc5c46..3cd3c241f750 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/logging/logback/LogbackLoggingSystem.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/logging/logback/LogbackLoggingSystem.java @@ -43,6 +43,10 @@ import org.slf4j.Marker; import org.slf4j.bridge.SLF4JBridgeHandler; +import org.springframework.aot.AotDetector; +import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution; +import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.boot.logging.AbstractLoggingSystem; import org.springframework.boot.logging.LogFile; import org.springframework.boot.logging.LogLevel; @@ -69,7 +73,7 @@ * @author Ben Hale * @since 1.0.0 */ -public class LogbackLoggingSystem extends AbstractLoggingSystem { +public class LogbackLoggingSystem extends AbstractLoggingSystem implements BeanFactoryInitializationAotProcessor { private static final String BRIDGE_HANDLER = "org.slf4j.bridge.SLF4JBridgeHandler"; @@ -178,7 +182,9 @@ public void initialize(LoggingInitializationContext initializationContext, Strin if (isAlreadyInitialized(loggerContext)) { return; } - super.initialize(initializationContext, configLocation, logFile); + if (!initializeFromAotGeneratedArtifactsIfPossible(initializationContext, logFile)) { + super.initialize(initializationContext, configLocation, logFile); + } loggerContext.getTurboFilterList().remove(FILTER); markAsInitialized(loggerContext); if (StringUtils.hasText(System.getProperty(CONFIGURATION_FILE_PROPERTY))) { @@ -187,6 +193,21 @@ public void initialize(LoggingInitializationContext initializationContext, Strin } } + private boolean initializeFromAotGeneratedArtifactsIfPossible(LoggingInitializationContext initializationContext, + LogFile logFile) { + if (!AotDetector.useGeneratedArtifacts()) { + return false; + } + if (initializationContext != null) { + applySystemProperties(initializationContext.getEnvironment(), logFile); + } + LoggerContext loggerContext = getLoggerContext(); + stopAndReset(loggerContext); + SpringBootJoranConfigurator configurator = new SpringBootJoranConfigurator(initializationContext); + configurator.setContext(loggerContext); + return configurator.configureUsingAotGeneratedArtifacts(); + } + @Override protected void loadDefaults(LoggingInitializationContext initializationContext, LogFile logFile) { LoggerContext context = getLoggerContext(); @@ -382,6 +403,16 @@ private void markAsUninitialized(LoggerContext loggerContext) { loggerContext.removeObject(LoggingSystem.class.getName()); } + @Override + public BeanFactoryInitializationAotContribution processAheadOfTime(ConfigurableListableBeanFactory beanFactory) { + String key = BeanFactoryInitializationAotContribution.class.getName(); + LoggerContext context = getLoggerContext(); + BeanFactoryInitializationAotContribution contribution = (BeanFactoryInitializationAotContribution) context + .getObject(key); + context.removeObject(key); + return contribution; + } + /** * {@link LoggingSystemFactory} that returns {@link LogbackLoggingSystem} if possible. */ diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/logging/logback/SpringBootJoranConfigurator.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/logging/logback/SpringBootJoranConfigurator.java index 61e19042a351..cec2a5209d0c 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/logging/logback/SpringBootJoranConfigurator.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/logging/logback/SpringBootJoranConfigurator.java @@ -16,12 +16,51 @@ package org.springframework.boot.logging.logback; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.stream.Stream; + import ch.qos.logback.classic.joran.JoranConfigurator; +import ch.qos.logback.core.Context; +import ch.qos.logback.core.CoreConstants; +import ch.qos.logback.core.joran.spi.DefaultNestedComponentRegistry; import ch.qos.logback.core.joran.spi.ElementSelector; import ch.qos.logback.core.joran.spi.RuleStore; +import ch.qos.logback.core.joran.util.beans.BeanDescription; +import ch.qos.logback.core.joran.util.beans.BeanDescriptionCache; +import ch.qos.logback.core.model.ComponentModel; +import ch.qos.logback.core.model.Model; +import ch.qos.logback.core.model.ModelUtil; import ch.qos.logback.core.model.processor.DefaultProcessor; +import ch.qos.logback.core.spi.ContextAware; +import ch.qos.logback.core.spi.ContextAwareBase; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.SerializationHints; +import org.springframework.aot.hint.TypeReference; +import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution; +import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; import org.springframework.boot.logging.LoggingInitializationContext; +import org.springframework.core.io.ByteArrayResource; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.core.io.support.PropertiesLoaderUtils; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; /** * Extended version of the Logback {@link JoranConfigurator} that adds additional Spring @@ -57,4 +96,242 @@ public void addElementSelectorAndActionAssociations(RuleStore ruleStore) { ruleStore.addTransparentPathPart("springProfile"); } + boolean configureUsingAotGeneratedArtifacts() { + if (!new PatternRules(getContext()).load()) { + return false; + } + Model model = new ModelReader().read(); + processModel(model); + registerSafeConfiguration(model); + return true; + } + + @Override + public void processModel(Model model) { + super.processModel(model); + if (isAotProcessingInProgress()) { + getContext().putObject(BeanFactoryInitializationAotContribution.class.getName(), + new LogbackConfigurationAotContribution(model, + getModelInterpretationContext().getBeanDescriptionCache(), + getModelInterpretationContext().getDefaultNestedComponentRegistry(), getContext())); + } + } + + private boolean isAotProcessingInProgress() { + return Boolean.getBoolean("spring.aot.processing"); + } + + static final class LogbackConfigurationAotContribution implements BeanFactoryInitializationAotContribution { + + private final ModelWriter modelWriter; + + private final PatternRules patternRules; + + private LogbackConfigurationAotContribution(Model model, BeanDescriptionCache beanDescriptionCache, + DefaultNestedComponentRegistry nestedComponentRegistry, Context context) { + this.modelWriter = new ModelWriter(model, beanDescriptionCache, nestedComponentRegistry); + this.patternRules = new PatternRules(context); + } + + @Override + public void applyTo(GenerationContext generationContext, + BeanFactoryInitializationCode beanFactoryInitializationCode) { + this.modelWriter.writeTo(generationContext); + this.patternRules.save(generationContext); + } + + } + + private static final class ModelWriter { + + private static final String MODEL_RESOURCE_LOCATION = "META-INF/spring/logback-model"; + + private final Model model; + + private final BeanDescriptionCache beanDescriptionCache; + + private final DefaultNestedComponentRegistry nestedComponentRegistry; + + private ModelWriter(Model model, BeanDescriptionCache beanDescriptionCache, + DefaultNestedComponentRegistry nestedComponentRegistry) { + this.model = model; + this.beanDescriptionCache = beanDescriptionCache; + this.nestedComponentRegistry = nestedComponentRegistry; + } + + private void writeTo(GenerationContext generationContext) { + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + try (ObjectOutputStream output = new ObjectOutputStream(bytes)) { + output.writeObject(this.model); + } + catch (IOException ex) { + throw new RuntimeException(ex); + } + Resource modelResource = new ByteArrayResource(bytes.toByteArray()); + generationContext.getGeneratedFiles().addResourceFile(MODEL_RESOURCE_LOCATION, modelResource); + generationContext.getRuntimeHints().resources().registerPattern(MODEL_RESOURCE_LOCATION); + SerializationHints serializationHints = generationContext.getRuntimeHints().serialization(); + serializationTypes(this.model).forEach(serializationHints::registerType); + reflectionTypes(this.model).forEach((type) -> generationContext.getRuntimeHints().reflection().registerType( + TypeReference.of(type), MemberCategory.INTROSPECT_PUBLIC_METHODS, + MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS)); + } + + @SuppressWarnings("unchecked") + private Set> serializationTypes(Model model) { + Set> modelClasses = new HashSet<>(); + Class candidate = model.getClass(); + while (Model.class.isAssignableFrom(candidate)) { + if (modelClasses.add((Class) candidate)) { + ReflectionUtils.doWithFields(candidate, (field) -> { + if (Modifier.isStatic(field.getModifiers())) { + return; + } + ReflectionUtils.makeAccessible(field); + Object value = field.get(model); + if (value != null) { + Class fieldType = value.getClass(); + if (Serializable.class.isAssignableFrom(fieldType)) { + modelClasses.add((Class) fieldType); + } + } + }); + candidate = candidate.getSuperclass(); + } + } + for (Model submodel : model.getSubModels()) { + modelClasses.addAll(serializationTypes(submodel)); + } + return modelClasses; + } + + private Set reflectionTypes(Model model) { + Set reflectionTypes = new HashSet<>(); + if (model instanceof ComponentModel) { + String className = ((ComponentModel) model).getClassName(); + processComponent(className, reflectionTypes); + } + String tag = model.getTag(); + if (tag != null) { + String componentType = this.nestedComponentRegistry.findDefaultComponentTypeByTag(tag); + processComponent(componentType, reflectionTypes); + } + for (Model submodel : model.getSubModels()) { + reflectionTypes.addAll(reflectionTypes(submodel)); + } + return reflectionTypes; + } + + private void processComponent(String componentTypeName, Set reflectionTypes) { + if (componentTypeName != null) { + BeanDescription beanDescription = this.beanDescriptionCache + .getBeanDescription(loadComponentType(componentTypeName)); + reflectionTypes.addAll(parameterTypesNames(beanDescription.getPropertyNameToAdder().values())); + reflectionTypes.addAll(parameterTypesNames(beanDescription.getPropertyNameToSetter().values())); + reflectionTypes.add(componentTypeName); + } + } + + private Class loadComponentType(String componentType) { + try { + return ClassUtils.forName(componentType, getClass().getClassLoader()); + } + catch (Throwable ex) { + throw new RuntimeException("Failed to load component type '" + componentType + "'", ex); + } + } + + private Collection parameterTypesNames(Collection methods) { + return methods.stream() + .filter((method) -> !method.getDeclaringClass().equals(ContextAware.class) + && !method.getDeclaringClass().equals(ContextAwareBase.class)) + .map(Method::getParameterTypes).flatMap(Stream::of) + .filter((type) -> !type.isPrimitive() && !type.equals(String.class)).map(Class::getName).toList(); + } + + } + + private static final class ModelReader { + + private Model read() { + try (InputStream modelInput = getClass().getClassLoader() + .getResourceAsStream(ModelWriter.MODEL_RESOURCE_LOCATION)) { + try (ObjectInputStream input = new ObjectInputStream(modelInput)) { + Model model = (Model) input.readObject(); + ModelUtil.resetForReuse(model); + return model; + } + } + catch (Exception ex) { + throw new RuntimeException("Failed to load model from '" + ModelWriter.MODEL_RESOURCE_LOCATION + "'", + ex); + } + } + + } + + private static final class PatternRules { + + private static final String RESOURCE_LOCATION = "META-INF/spring/logback-pattern-rules"; + + private final Context context; + + private PatternRules(Context context) { + this.context = context; + } + + private boolean load() { + try { + ClassPathResource resource = new ClassPathResource(RESOURCE_LOCATION); + if (!resource.exists()) { + return false; + } + Properties properties = PropertiesLoaderUtils.loadProperties(resource); + Map patternRuleRegistry = getRegistryMap(); + for (String word : properties.stringPropertyNames()) { + patternRuleRegistry.put(word, properties.getProperty(word)); + } + return true; + } + catch (Exception ex) { + throw new RuntimeException(ex); + } + } + + @SuppressWarnings("unchecked") + private Map getRegistryMap() { + Map patternRuleRegistry = (Map) this.context + .getObject(CoreConstants.PATTERN_RULE_REGISTRY); + if (patternRuleRegistry == null) { + patternRuleRegistry = new HashMap<>(); + this.context.putObject(CoreConstants.PATTERN_RULE_REGISTRY, patternRuleRegistry); + } + return patternRuleRegistry; + } + + private void save(GenerationContext generationContext) { + Map registryMap = getRegistryMap(); + generationContext.getGeneratedFiles().addResourceFile(RESOURCE_LOCATION, () -> asInputStream(registryMap)); + generationContext.getRuntimeHints().resources().registerPattern(RESOURCE_LOCATION); + for (String ruleClassName : registryMap.values()) { + generationContext.getRuntimeHints().reflection().registerType(TypeReference.of(ruleClassName), + MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS); + } + } + + private InputStream asInputStream(Map patternRuleRegistry) { + Properties properties = new Properties(); + patternRuleRegistry.forEach(properties::setProperty); + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + try { + properties.store(bytes, ""); + } + catch (IOException ex) { + throw new RuntimeException(ex); + } + return new ByteArrayInputStream(bytes.toByteArray()); + } + + } + } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/logging/logback/LogbackConfigurationAotContributionTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/logging/logback/LogbackConfigurationAotContributionTests.java new file mode 100644 index 000000000000..3d44a884d089 --- /dev/null +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/logging/logback/LogbackConfigurationAotContributionTests.java @@ -0,0 +1,190 @@ +/* + * Copyright 2012-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.logging.logback; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Map; +import java.util.Properties; +import java.util.function.Predicate; +import java.util.stream.Stream; + +import ch.qos.logback.classic.LoggerContext; +import ch.qos.logback.classic.encoder.PatternLayoutEncoder; +import ch.qos.logback.core.CoreConstants; +import ch.qos.logback.core.FileAppender; +import ch.qos.logback.core.Layout; +import ch.qos.logback.core.model.ComponentModel; +import ch.qos.logback.core.model.ImplicitModel; +import ch.qos.logback.core.model.Model; +import ch.qos.logback.core.rolling.SizeAndTimeBasedRollingPolicy; +import ch.qos.logback.core.rolling.TimeBasedFileNamingAndTriggeringPolicy; +import ch.qos.logback.core.util.FileSize; +import org.assertj.core.api.Condition; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.LoggerFactory; + +import org.springframework.aot.generate.GeneratedFiles.Kind; +import org.springframework.aot.generate.InMemoryGeneratedFiles; +import org.springframework.aot.hint.JavaSerializationHint; +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.SerializationHints; +import org.springframework.aot.hint.TypeReference; +import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; +import org.springframework.aot.test.generate.TestGenerationContext; +import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution; +import org.springframework.boot.logging.logback.SpringBootJoranConfigurator.LogbackConfigurationAotContribution; +import org.springframework.core.io.InputStreamSource; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link LogbackConfigurationAotContribution}. + * + * @author Andy Wilkinson + */ +class LogbackConfigurationAotContributionTests { + + @BeforeEach + @AfterEach + void prepare() { + LoggerContext context = (LoggerContext) LoggerFactory.getILoggerFactory(); + context.reset(); + } + + @Test + void contributionOfBasicModel() { + TestGenerationContext generationContext = applyContribution(new Model()); + InMemoryGeneratedFiles generatedFiles = generationContext.getGeneratedFiles(); + assertThat(generatedFiles).has(resource("META-INF/spring/logback-model")); + assertThat(generatedFiles).has(resource("META-INF/spring/logback-pattern-rules")); + SerializationHints serializationHints = generationContext.getRuntimeHints().serialization(); + assertThat(serializationHints.javaSerializationHints().map(JavaSerializationHint::getType) + .map(TypeReference::getName)) + .containsExactlyInAnyOrder(namesOf(Model.class, ArrayList.class, Boolean.class, Integer.class)); + assertThat(generationContext.getRuntimeHints().reflection().typeHints()).isEmpty(); + Properties patternRules = load( + generatedFiles.getGeneratedFile(Kind.RESOURCE, "META-INF/spring/logback-pattern-rules")); + assertThat(patternRules).isEmpty(); + } + + @Test + void patternRulesAreStoredAndRegisteredForReflection() { + LoggerContext context = (LoggerContext) LoggerFactory.getILoggerFactory(); + context.putObject(CoreConstants.PATTERN_RULE_REGISTRY, + Map.of("a", "com.example.Alpha", "b", "com.example.Bravo")); + TestGenerationContext generationContext = applyContribution(new Model()); + assertThat(invokePublicConstructorsOf("com.example.Alpha")).accepts(generationContext.getRuntimeHints()); + assertThat(invokePublicConstructorsOf("com.example.Bravo")).accepts(generationContext.getRuntimeHints()); + Properties patternRules = load(generationContext.getGeneratedFiles().getGeneratedFile(Kind.RESOURCE, + "META-INF/spring/logback-pattern-rules")); + assertThat(patternRules).hasSize(2); + assertThat(patternRules).containsEntry("a", "com.example.Alpha"); + assertThat(patternRules).containsEntry("b", "com.example.Bravo"); + } + + @Test + void componentModelClassAndSetterParametersAreRegisteredForReflection() { + ComponentModel component = new ComponentModel(); + component.setClassName(SizeAndTimeBasedRollingPolicy.class.getName()); + Model model = new Model(); + model.getSubModels().add(component); + TestGenerationContext generationContext = applyContribution(model); + assertThat(invokePublicConstructorsAndInspectAndInvokePublicMethodsOf(SizeAndTimeBasedRollingPolicy.class)) + .accepts(generationContext.getRuntimeHints()); + assertThat(invokePublicConstructorsAndInspectAndInvokePublicMethodsOf(FileAppender.class)) + .accepts(generationContext.getRuntimeHints()); + assertThat(invokePublicConstructorsAndInspectAndInvokePublicMethodsOf(FileSize.class)) + .accepts(generationContext.getRuntimeHints()); + assertThat(invokePublicConstructorsAndInspectAndInvokePublicMethodsOf( + TimeBasedFileNamingAndTriggeringPolicy.class)).accepts(generationContext.getRuntimeHints()); + } + + @Test + void implicitModelClassAndSetterParametersAreRegisteredForReflection() { + ImplicitModel implicit = new ImplicitModel(); + implicit.setTag("encoder"); + Model model = new Model(); + model.getSubModels().add(implicit); + TestGenerationContext generationContext = applyContribution(model); + assertThat(invokePublicConstructorsAndInspectAndInvokePublicMethodsOf(PatternLayoutEncoder.class)) + .accepts(generationContext.getRuntimeHints()); + assertThat(invokePublicConstructorsAndInspectAndInvokePublicMethodsOf(Layout.class)) + .accepts(generationContext.getRuntimeHints()); + assertThat(invokePublicConstructorsAndInspectAndInvokePublicMethodsOf(Charset.class)) + .accepts(generationContext.getRuntimeHints()); + } + + private Predicate invokePublicConstructorsOf(String name) { + return RuntimeHintsPredicates.reflection().onType(TypeReference.of(name)) + .withMemberCategory(MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS); + } + + private Predicate invokePublicConstructorsAndInspectAndInvokePublicMethodsOf(Class type) { + return RuntimeHintsPredicates.reflection().onType(TypeReference.of(type)).withMemberCategories( + MemberCategory.INVOKE_PUBLIC_CONSTRUCTORS, MemberCategory.INTROSPECT_PUBLIC_METHODS, + MemberCategory.INVOKE_PUBLIC_METHODS); + } + + private Properties load(InputStreamSource source) { + try (InputStream inputStream = source.getInputStream()) { + Properties properties = new Properties(); + properties.load(inputStream); + return properties; + } + catch (IOException ex) { + throw new RuntimeException(ex); + } + } + + private Condition resource(String name) { + return new Condition<>((files) -> files.getGeneratedFile(Kind.RESOURCE, name) != null, + "has a resource named '%s'", name); + } + + private TestGenerationContext applyContribution(Model model) { + LoggerContext context = (LoggerContext) LoggerFactory.getILoggerFactory(); + SpringBootJoranConfigurator configurator = new SpringBootJoranConfigurator(null); + configurator.setContext(context); + withSystemProperty("spring.aot.processing", "true", () -> configurator.processModel(model)); + LogbackConfigurationAotContribution contribution = (LogbackConfigurationAotContribution) context + .getObject(BeanFactoryInitializationAotContribution.class.getName()); + TestGenerationContext generationContext = new TestGenerationContext(); + contribution.applyTo(generationContext, null); + return generationContext; + } + + private String[] namesOf(Class... types) { + return Stream.of(types).map(Class::getName).toArray(String[]::new); + } + + private void withSystemProperty(String name, String value, Runnable action) { + System.setProperty(name, value); + try { + action.run(); + } + finally { + System.clearProperty(name); + } + } + +} diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/logging/logback/LogbackLoggingSystemTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/logging/logback/LogbackLoggingSystemTests.java index cf2b02d5db0d..5871e5885302 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/logging/logback/LogbackLoggingSystemTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/logging/logback/LogbackLoggingSystemTests.java @@ -46,6 +46,7 @@ import org.slf4j.LoggerFactory; import org.slf4j.bridge.SLF4JBridgeHandler; +import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution; import org.springframework.boot.convert.ApplicationConversionService; import org.springframework.boot.logging.AbstractLoggingSystemTests; import org.springframework.boot.logging.LogFile; @@ -629,6 +630,22 @@ void customCharset() { assertThat(encoder.getCharset()).isEqualTo(StandardCharsets.UTF_16); } + @Test + void whenContextHasNoAotContributionThenProcessAheadOfTimeReturnsNull() { + BeanFactoryInitializationAotContribution contribution = this.loggingSystem.processAheadOfTime(null); + assertThat(contribution).isNull(); + } + + @Test + void whenContextHasAotContributionThenProcessAheadOfTimeClearsAndReturnsIt() { + LoggerContext context = ((LoggerContext) LoggerFactory.getILoggerFactory()); + context.putObject(BeanFactoryInitializationAotContribution.class.getName(), + mock(BeanFactoryInitializationAotContribution.class)); + BeanFactoryInitializationAotContribution contribution = this.loggingSystem.processAheadOfTime(null); + assertThat(context.getObject(BeanFactoryInitializationAotContribution.class.getName())).isNull(); + assertThat(contribution).isNotNull(); + } + private void initialize(LoggingInitializationContext context, String configLocation, LogFile logFile) { this.loggingSystem.getSystemProperties((ConfigurableEnvironment) context.getEnvironment()).apply(logFile); this.loggingSystem.initialize(context, configLocation, logFile); diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/logging/logback/SpringBootJoranConfiguratorTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/logging/logback/SpringBootJoranConfiguratorTests.java index 64e7bf85ed41..20b06e6662b0 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/logging/logback/SpringBootJoranConfiguratorTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/logging/logback/SpringBootJoranConfiguratorTests.java @@ -27,6 +27,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution; import org.springframework.boot.context.properties.source.ConfigurationPropertySources; import org.springframework.boot.logging.LoggingInitializationContext; import org.springframework.boot.testsupport.system.CapturedOutput; @@ -188,6 +189,25 @@ void relaxedSpringPropertyWithDefaultValue() throws Exception { assertThat(this.context.getProperty("MINE")).isEqualTo("bar"); } + @Test + void addsAotContributionToContextDuringAotProcessing() throws Exception { + withSystemProperty("spring.aot.processing", "true", () -> { + initialize("property.xml"); + Object contribution = this.context.getObject(BeanFactoryInitializationAotContribution.class.getName()); + assertThat(contribution).isNotNull(); + }); + } + + private void withSystemProperty(String name, String value, Action action) throws Exception { + System.setProperty(name, value); + try { + action.perform(); + } + finally { + System.clearProperty(name); + } + } + private void doTestNestedProfile(boolean expected, String... profiles) throws JoranException { this.environment.setActiveProfiles(profiles); initialize("nested.xml"); @@ -206,4 +226,10 @@ private void initialize(String config) throws JoranException { this.configurator.doConfigure(getClass().getResourceAsStream(config)); } + private interface Action { + + void perform() throws Exception; + + } + } diff --git a/src/checkstyle/checkstyle-suppressions.xml b/src/checkstyle/checkstyle-suppressions.xml index d4dfb63f3bc4..683f96d85630 100644 --- a/src/checkstyle/checkstyle-suppressions.xml +++ b/src/checkstyle/checkstyle-suppressions.xml @@ -10,6 +10,7 @@ +