Skip to content

Commit

Permalink
fix [BUG] #1606 AiServicesAutoConfig is unable to detect AiService in…
Browse files Browse the repository at this point in the history
… the package specified by @componentscan
  • Loading branch information
qing-wq committed Aug 19, 2024
1 parent 17e2470 commit 34e27fa
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package dev.langchain4j.service.spring;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.boot.autoconfigure.AutoConfigurationPackages;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.core.type.filter.AnnotationTypeFilter;
import org.springframework.stereotype.Component;

import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

@Component
public class AiServiceScannerProcessor implements BeanDefinitionRegistryPostProcessor {

@Override
public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
ClassPathAiServiceScanner classPathAiServiceScanner = new ClassPathAiServiceScanner(registry, false);
classPathAiServiceScanner.addIncludeFilter(new AnnotationTypeFilter(AiService.class));
Set<String> basePackages = getBasePackages((ConfigurableListableBeanFactory) registry);
for (String basePackage : basePackages) {
classPathAiServiceScanner.scan(basePackage);
}
}

private Set<String> getBasePackages(ConfigurableListableBeanFactory beanFactory) {
Set<String> basePackages = new LinkedHashSet<>();

List<String> autoConfigPackages = AutoConfigurationPackages.get(beanFactory);
basePackages.addAll(autoConfigPackages);

String[] beanNames = beanFactory.getBeanNamesForAnnotation(ComponentScan.class);
for (String beanName : beanNames) {
Class<?> beanClass = beanFactory.getType(beanName);
if (beanClass != null) {
ComponentScan componentScan = beanClass.getAnnotation(ComponentScan.class);
if (componentScan != null) {
Collections.addAll(basePackages, componentScan.value());
Collections.addAll(basePackages, componentScan.basePackages());
for (Class<?> basePackageClass : componentScan.basePackageClasses()) {
basePackages.add(basePackageClass.getPackage().getName());
}
}
}
}

return basePackages;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,12 @@
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import org.reflections.Reflections;
import org.reflections.util.ConfigurationBuilder;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.beans.factory.support.ManagedList;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;

import java.lang.reflect.Method;
Expand Down Expand Up @@ -60,13 +55,9 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
}
}

findAiServices(beanFactory).forEach(aiServiceClass -> {

if (beanFactory.getBeanNamesForType(aiServiceClass).length > 0) {
// User probably wants to configure AI Service bean manually
// TODO or better fail because user should not annotate it with @AiService then?
return;
}
String[] aiServiceBean = beanFactory.getBeanNamesForAnnotation(AiService.class);
for (String aiService : aiServiceBean) {
Class<?> aiServiceClass = beanFactory.getType(aiService);

GenericBeanDefinition aiServiceBeanDefinition = new GenericBeanDefinition();
aiServiceBeanDefinition.setBeanClass(AiServiceFactory.class);
Expand Down Expand Up @@ -144,21 +135,12 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
}

BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory;
registry.registerBeanDefinition(lowercaseFirstLetter(aiServiceClass.getSimpleName()), aiServiceBeanDefinition);
});
registry.removeBeanDefinition(aiService);
registry.registerBeanDefinition(lowercaseFirstLetter(aiService), aiServiceBeanDefinition);
}
};
}

private static Set<Class<?>> findAiServices(ConfigurableListableBeanFactory beanFactory) {
String[] applicationBean = beanFactory.getBeanNamesForAnnotation(SpringBootApplication.class);
BeanDefinition applicationBeanDefinition = beanFactory.getBeanDefinition(applicationBean[0]);
String basePackage = applicationBeanDefinition.getResolvableType().resolve().getPackage().getName();
Reflections reflections = new Reflections((new ConfigurationBuilder()).forPackage(basePackage));
Set<Class<?>> classes = reflections.getTypesAnnotatedWith(AiService.class);
classes.removeIf(clazz -> !clazz.getName().startsWith(basePackage));
return classes;
}

private static void addBeanReference(Class<?> beanType,
AiService aiServiceAnnotation,
String customBeanName,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package dev.langchain4j.service.spring;

import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.context.annotation.ClassPathBeanDefinitionScanner;

public class ClassPathAiServiceScanner extends ClassPathBeanDefinitionScanner {

public ClassPathAiServiceScanner(BeanDefinitionRegistry registry, boolean useDefaultFilters) {
super(registry, useDefaultFilters);
}

@Override
protected boolean isCandidateComponent(AnnotatedBeanDefinition beanDefinition) {
return beanDefinition.getMetadata().isInterface() && beanDefinition.getMetadata().isIndependent();
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
package dev.langchain4j.spring;

import dev.langchain4j.rag.spring.RagAutoConfig;
import dev.langchain4j.service.spring.AiServiceScannerProcessor;
import dev.langchain4j.service.spring.AiServicesAutoConfig;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.context.annotation.Import;

@AutoConfiguration
@Import({
AiServicesAutoConfig.class,
RagAutoConfig.class
RagAutoConfig.class,
AiServiceScannerProcessor.class
})
public class LangChain4jAutoConfig {
}

0 comments on commit 34e27fa

Please sign in to comment.