Skip to content

Commit

Permalink
ImportSelector.getCandidateFilter() for transitive filtering of classes
Browse files Browse the repository at this point in the history
Closes gh-24175
  • Loading branch information
jhoeller committed Feb 5, 2020
1 parent c2367b3 commit d93303c
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 60 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,6 +34,7 @@
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;
import java.util.function.Predicate;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -110,6 +111,9 @@ class ConfigurationClassParser {

private static final PropertySourceFactory DEFAULT_PROPERTY_SOURCE_FACTORY = new DefaultPropertySourceFactory();

private static final Predicate<String> DEFAULT_CANDIDATE_FILTER = className ->
(className.startsWith("java.lang.annotation.") || className.startsWith("org.springframework.stereotype."));

private static final Comparator<DeferredImportSelectorHolder> DEFERRED_IMPORT_COMPARATOR =
(o1, o2) -> AnnotationAwareOrderComparator.INSTANCE.compare(o1.getImportSelector(), o2.getImportSelector());

Expand Down Expand Up @@ -191,15 +195,15 @@ else if (bd instanceof AbstractBeanDefinition && ((AbstractBeanDefinition) bd).h
protected final void parse(@Nullable String className, String beanName) throws IOException {
Assert.notNull(className, "No bean class name for configuration class bean definition");
MetadataReader reader = this.metadataReaderFactory.getMetadataReader(className);
processConfigurationClass(new ConfigurationClass(reader, beanName));
processConfigurationClass(new ConfigurationClass(reader, beanName), DEFAULT_CANDIDATE_FILTER);
}

protected final void parse(Class<?> clazz, String beanName) throws IOException {
processConfigurationClass(new ConfigurationClass(clazz, beanName));
processConfigurationClass(new ConfigurationClass(clazz, beanName), DEFAULT_CANDIDATE_FILTER);
}

protected final void parse(AnnotationMetadata metadata, String beanName) throws IOException {
processConfigurationClass(new ConfigurationClass(metadata, beanName));
processConfigurationClass(new ConfigurationClass(metadata, beanName), DEFAULT_CANDIDATE_FILTER);
}

/**
Expand All @@ -217,7 +221,7 @@ public Set<ConfigurationClass> getConfigurationClasses() {
}


protected void processConfigurationClass(ConfigurationClass configClass) throws IOException {
protected void processConfigurationClass(ConfigurationClass configClass, Predicate<String> filter) throws IOException {
if (this.conditionEvaluator.shouldSkip(configClass.getMetadata(), ConfigurationPhase.PARSE_CONFIGURATION)) {
return;
}
Expand All @@ -240,9 +244,9 @@ protected void processConfigurationClass(ConfigurationClass configClass) throws
}

// Recursively process the configuration class and its superclass hierarchy.
SourceClass sourceClass = asSourceClass(configClass);
SourceClass sourceClass = asSourceClass(configClass, filter);
do {
sourceClass = doProcessConfigurationClass(configClass, sourceClass);
sourceClass = doProcessConfigurationClass(configClass, sourceClass, filter);
}
while (sourceClass != null);

Expand All @@ -258,12 +262,13 @@ protected void processConfigurationClass(ConfigurationClass configClass) throws
* @return the superclass, or {@code null} if none found or previously processed
*/
@Nullable
protected final SourceClass doProcessConfigurationClass(ConfigurationClass configClass, SourceClass sourceClass)
protected final SourceClass doProcessConfigurationClass(
ConfigurationClass configClass, SourceClass sourceClass, Predicate<String> filter)
throws IOException {

if (configClass.getMetadata().isAnnotated(Component.class.getName())) {
// Recursively process any member (nested) classes first
processMemberClasses(configClass, sourceClass);
processMemberClasses(configClass, sourceClass, filter);
}

// Process any @PropertySource annotations
Expand Down Expand Up @@ -302,7 +307,7 @@ protected final SourceClass doProcessConfigurationClass(ConfigurationClass confi
}

// Process any @Import annotations
processImports(configClass, sourceClass, getImports(sourceClass), true);
processImports(configClass, sourceClass, getImports(sourceClass), filter, true);

// Process any @ImportResource annotations
AnnotationAttributes importResource =
Expand Down Expand Up @@ -343,7 +348,9 @@ protected final SourceClass doProcessConfigurationClass(ConfigurationClass confi
/**
* Register member (nested) classes that happen to be configuration classes themselves.
*/
private void processMemberClasses(ConfigurationClass configClass, SourceClass sourceClass) throws IOException {
private void processMemberClasses(ConfigurationClass configClass, SourceClass sourceClass,
Predicate<String> filter) throws IOException {

Collection<SourceClass> memberClasses = sourceClass.getMemberClasses();
if (!memberClasses.isEmpty()) {
List<SourceClass> candidates = new ArrayList<>(memberClasses.size());
Expand All @@ -361,7 +368,7 @@ private void processMemberClasses(ConfigurationClass configClass, SourceClass so
else {
this.importStack.push(configClass);
try {
processConfigurationClass(candidate.asConfigClass(configClass));
processConfigurationClass(candidate.asConfigClass(configClass), filter);
}
finally {
this.importStack.pop();
Expand Down Expand Up @@ -543,7 +550,8 @@ private void collectImports(SourceClass sourceClass, Set<SourceClass> imports, S
}

private void processImports(ConfigurationClass configClass, SourceClass currentSourceClass,
Collection<SourceClass> importCandidates, boolean checkForCircularImports) {
Collection<SourceClass> importCandidates, Predicate<String> candidateFilter,
boolean checkForCircularImports) {

if (importCandidates.isEmpty()) {
return;
Expand All @@ -561,13 +569,17 @@ private void processImports(ConfigurationClass configClass, SourceClass currentS
Class<?> candidateClass = candidate.loadClass();
ImportSelector selector = ParserStrategyUtils.instantiateClass(candidateClass, ImportSelector.class,
this.environment, this.resourceLoader, this.registry);
Predicate<String> selectorFilter = selector.getCandidateFilter();
if (selectorFilter != null) {
candidateFilter = candidateFilter.or(selectorFilter);
}
if (selector instanceof DeferredImportSelector) {
this.deferredImportSelectorHandler.handle(configClass, (DeferredImportSelector) selector);
}
else {
String[] importClassNames = selector.selectImports(currentSourceClass.getMetadata());
Collection<SourceClass> importSourceClasses = asSourceClasses(importClassNames);
processImports(configClass, currentSourceClass, importSourceClasses, false);
Collection<SourceClass> importSourceClasses = asSourceClasses(importClassNames, candidateFilter);
processImports(configClass, currentSourceClass, importSourceClasses, candidateFilter, false);
}
}
else if (candidate.isAssignable(ImportBeanDefinitionRegistrar.class)) {
Expand All @@ -584,7 +596,7 @@ else if (candidate.isAssignable(ImportBeanDefinitionRegistrar.class)) {
// process it as an @Configuration class
this.importStack.registerImport(
currentSourceClass.getMetadata(), candidate.getMetadata().getClassName());
processConfigurationClass(candidate.asConfigClass(configClass));
processConfigurationClass(candidate.asConfigClass(configClass), candidateFilter);
}
}
}
Expand Down Expand Up @@ -624,19 +636,19 @@ ImportRegistry getImportRegistry() {
/**
* Factory method to obtain a {@link SourceClass} from a {@link ConfigurationClass}.
*/
private SourceClass asSourceClass(ConfigurationClass configurationClass) throws IOException {
private SourceClass asSourceClass(ConfigurationClass configurationClass, Predicate<String> filter) throws IOException {
AnnotationMetadata metadata = configurationClass.getMetadata();
if (metadata instanceof StandardAnnotationMetadata) {
return asSourceClass(((StandardAnnotationMetadata) metadata).getIntrospectedClass());
return asSourceClass(((StandardAnnotationMetadata) metadata).getIntrospectedClass(), filter);
}
return asSourceClass(metadata.getClassName());
return asSourceClass(metadata.getClassName(), filter);
}

/**
* Factory method to obtain a {@link SourceClass} from a {@link Class}.
*/
SourceClass asSourceClass(@Nullable Class<?> classType) throws IOException {
if (classType == null || classType.getName().startsWith("java.lang.annotation.")) {
SourceClass asSourceClass(@Nullable Class<?> classType, Predicate<String> filter) throws IOException {
if (classType == null || filter.test(classType.getName())) {
return this.objectSourceClass;
}
try {
Expand All @@ -649,41 +661,38 @@ SourceClass asSourceClass(@Nullable Class<?> classType) throws IOException {
}
catch (Throwable ex) {
// Enforce ASM via class name resolution
return asSourceClass(classType.getName());
return asSourceClass(classType.getName(), filter);
}
}

/**
* Factory method to obtain {@link SourceClass SourceClasss} from class names.
*/
private Collection<SourceClass> asSourceClasses(String... classNames) throws IOException {
private Collection<SourceClass> asSourceClasses(String[] classNames, Predicate<String> filter) throws IOException {
List<SourceClass> annotatedClasses = new ArrayList<>(classNames.length);
for (String className : classNames) {
annotatedClasses.add(asSourceClass(className));
annotatedClasses.add(asSourceClass(className, filter));
}
return annotatedClasses;
}

/**
* Factory method to obtain a {@link SourceClass} from a class name.
*/
SourceClass asSourceClass(@Nullable String className) throws IOException {
if (className == null || className.startsWith("java.lang.annotation.")) {
SourceClass asSourceClass(@Nullable String className, Predicate<String> filter) throws IOException {
if (className == null || filter.test(className)) {
return this.objectSourceClass;
}
if (className.startsWith("java")) {
// Never use ASM for core java types
try {
return new SourceClass(ClassUtils.forName(className,
this.resourceLoader.getClassLoader()));
return new SourceClass(ClassUtils.forName(className, this.resourceLoader.getClassLoader()));
}
catch (ClassNotFoundException ex) {
throw new NestedIOException(
"Failed to load class [" + className + "]", ex);
throw new NestedIOException("Failed to load class [" + className + "]", ex);
}
}
return new SourceClass(
this.metadataReaderFactory.getMetadataReader(className));
return new SourceClass(this.metadataReaderFactory.getMetadataReader(className));
}


Expand Down Expand Up @@ -748,8 +757,7 @@ private class DeferredImportSelectorHandler {
* @param importSelector the selector to handle
*/
public void handle(ConfigurationClass configClass, DeferredImportSelector importSelector) {
DeferredImportSelectorHolder holder = new DeferredImportSelectorHolder(
configClass, importSelector);
DeferredImportSelectorHolder holder = new DeferredImportSelectorHolder(configClass, importSelector);
if (this.deferredImportSelectors == null) {
DeferredImportSelectorGroupingHandler handler = new DeferredImportSelectorGroupingHandler();
handler.register(holder);
Expand All @@ -775,7 +783,6 @@ public void process() {
this.deferredImportSelectors = new ArrayList<>();
}
}

}


Expand All @@ -786,8 +793,7 @@ private class DeferredImportSelectorGroupingHandler {
private final Map<AnnotationMetadata, ConfigurationClass> configurationClasses = new HashMap<>();

public void register(DeferredImportSelectorHolder deferredImport) {
Class<? extends Group> group = deferredImport.getImportSelector()
.getImportGroup();
Class<? extends Group> group = deferredImport.getImportSelector().getImportGroup();
DeferredImportSelectorGrouping grouping = this.groupings.computeIfAbsent(
(group != null ? group : deferredImport),
key -> new DeferredImportSelectorGrouping(createGroup(group)));
Expand All @@ -798,12 +804,13 @@ public void register(DeferredImportSelectorHolder deferredImport) {

public void processGroupImports() {
for (DeferredImportSelectorGrouping grouping : this.groupings.values()) {
Predicate<String> candidateFilter = grouping.getCandidateFilter();
grouping.getImports().forEach(entry -> {
ConfigurationClass configurationClass = this.configurationClasses.get(
entry.getMetadata());
ConfigurationClass configurationClass = this.configurationClasses.get(entry.getMetadata());
try {
processImports(configurationClass, asSourceClass(configurationClass),
asSourceClasses(entry.getImportClassName()), false);
processImports(configurationClass, asSourceClass(configurationClass, candidateFilter),
Collections.singleton(asSourceClass(entry.getImportClassName(), candidateFilter)),
candidateFilter, false);
}
catch (BeanDefinitionStoreException ex) {
throw ex;
Expand All @@ -818,15 +825,12 @@ public void processGroupImports() {
}

private Group createGroup(@Nullable Class<? extends Group> type) {
Class<? extends Group> effectiveType = (type != null ? type
: DefaultDeferredImportSelectorGroup.class);
Group group = ParserStrategyUtils.instantiateClass(effectiveType, Group.class,
Class<? extends Group> effectiveType = (type != null ? type : DefaultDeferredImportSelectorGroup.class);
return ParserStrategyUtils.instantiateClass(effectiveType, Group.class,
ConfigurationClassParser.this.environment,
ConfigurationClassParser.this.resourceLoader,
ConfigurationClassParser.this.registry);
return group;
}

}


Expand Down Expand Up @@ -861,6 +865,10 @@ private static class DeferredImportSelectorGrouping {
this.group = group;
}

public Group getGroup() {
return this.group;
}

public void add(DeferredImportSelectorHolder deferredImport) {
this.deferredImports.add(deferredImport);
}
Expand All @@ -876,6 +884,17 @@ public Iterable<Group.Entry> getImports() {
}
return this.group.selectImports();
}

public Predicate<String> getCandidateFilter() {
Predicate<String> mergedFilter = DEFAULT_CANDIDATE_FILTER;
for (DeferredImportSelectorHolder deferredImport : this.deferredImports) {
Predicate<String> selectorFilter = deferredImport.getImportSelector().getCandidateFilter();
if (selectorFilter != null) {
mergedFilter = mergedFilter.or(selectorFilter);
}
}
return mergedFilter;
}
}


Expand Down Expand Up @@ -957,7 +976,7 @@ public Collection<SourceClass> getMemberClasses() throws IOException {
Class<?>[] declaredClasses = sourceClass.getDeclaredClasses();
List<SourceClass> members = new ArrayList<>(declaredClasses.length);
for (Class<?> declaredClass : declaredClasses) {
members.add(asSourceClass(declaredClass));
members.add(asSourceClass(declaredClass, DEFAULT_CANDIDATE_FILTER));
}
return members;
}
Expand All @@ -974,7 +993,7 @@ public Collection<SourceClass> getMemberClasses() throws IOException {
List<SourceClass> members = new ArrayList<>(memberClassNames.length);
for (String memberClassName : memberClassNames) {
try {
members.add(asSourceClass(memberClassName));
members.add(asSourceClass(memberClassName, DEFAULT_CANDIDATE_FILTER));
}
catch (IOException ex) {
// Let's skip it if it's not resolvable - we're just looking for candidates
Expand All @@ -989,22 +1008,23 @@ public Collection<SourceClass> getMemberClasses() throws IOException {

public SourceClass getSuperClass() throws IOException {
if (this.source instanceof Class) {
return asSourceClass(((Class<?>) this.source).getSuperclass());
return asSourceClass(((Class<?>) this.source).getSuperclass(), DEFAULT_CANDIDATE_FILTER);
}
return asSourceClass(((MetadataReader) this.source).getClassMetadata().getSuperClassName());
return asSourceClass(
((MetadataReader) this.source).getClassMetadata().getSuperClassName(), DEFAULT_CANDIDATE_FILTER);
}

public Set<SourceClass> getInterfaces() throws IOException {
Set<SourceClass> result = new LinkedHashSet<>();
if (this.source instanceof Class) {
Class<?> sourceClass = (Class<?>) this.source;
for (Class<?> ifcClass : sourceClass.getInterfaces()) {
result.add(asSourceClass(ifcClass));
result.add(asSourceClass(ifcClass, DEFAULT_CANDIDATE_FILTER));
}
}
else {
for (String className : this.metadata.getInterfaceNames()) {
result.add(asSourceClass(className));
result.add(asSourceClass(className, DEFAULT_CANDIDATE_FILTER));
}
}
return result;
Expand All @@ -1018,7 +1038,7 @@ public Set<SourceClass> getAnnotations() {
Class<?> annType = ann.annotationType();
if (!annType.getName().startsWith("java")) {
try {
result.add(asSourceClass(annType));
result.add(asSourceClass(annType, DEFAULT_CANDIDATE_FILTER));
}
catch (Throwable ex) {
// An annotation not present on the classpath is being ignored
Expand Down Expand Up @@ -1060,7 +1080,7 @@ private SourceClass getRelated(String className) throws IOException {
if (this.source instanceof Class) {
try {
Class<?> clazz = ClassUtils.forName(className, ((Class<?>) this.source).getClassLoader());
return asSourceClass(clazz);
return asSourceClass(clazz, DEFAULT_CANDIDATE_FILTER);
}
catch (ClassNotFoundException ex) {
// Ignore -> fall back to ASM next, except for core java types.
Expand All @@ -1070,7 +1090,7 @@ private SourceClass getRelated(String className) throws IOException {
return new SourceClass(metadataReaderFactory.getMetadataReader(className));
}
}
return asSourceClass(className);
return asSourceClass(className, DEFAULT_CANDIDATE_FILTER);
}

@Override
Expand Down
Loading

0 comments on commit d93303c

Please sign in to comment.