From 79892f1b1383e27ddc873ec1b8850e8957b7e77b Mon Sep 17 00:00:00 2001 From: Necrosis <60231561+N3cr0s1s@users.noreply.github.com> Date: Mon, 11 Nov 2024 10:03:39 +0100 Subject: [PATCH] fix #1193 ModerationModel is not auto configured (#52) Fix issue [langchain4j/langchain4j#1193](https://github.com/langchain4j/langchain4j/issues/1193) Added ModerationModel support to `AiService`,`AiServiceFactory` and `AiServiceAutoConfig` --- .../langchain4j/service/spring/AiService.java | 9 ++++- .../service/spring/AiServiceFactory.java | 10 +++++ .../service/spring/AiServicesAutoConfig.java | 12 ++++++ .../AiServiceWithModerationModel.java | 12 ++++++ ...ServiceWithModerationModelApplication.java | 34 +++++++++++++++++ .../AiServiceWithModerationModelIT.java | 38 +++++++++++++++++++ 6 files changed, 114 insertions(+), 1 deletion(-) create mode 100644 langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withModerationModel/AiServiceWithModerationModel.java create mode 100644 langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withModerationModel/AiServiceWithModerationModelApplication.java create mode 100644 langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withModerationModel/AiServiceWithModerationModelIT.java diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiService.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiService.java index 169c3f7e..43794649 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiService.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiService.java @@ -5,6 +5,7 @@ import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.moderation.ModerationModel; import dev.langchain4j.rag.RetrievalAugmentor; import dev.langchain4j.rag.content.retriever.ContentRetriever; import dev.langchain4j.service.AiServices; @@ -91,9 +92,15 @@ */ String retrievalAugmentor() default ""; + /** + * When the {@link #wiringMode()} is set to {@link AiServiceWiringMode#EXPLICIT}, + * this attribute specifies the name of a {@link ModerationModel} bean that should be used by this AI Service. + */ + String moderationModel() default ""; + /** * When the {@link #wiringMode()} is set to {@link AiServiceWiringMode#EXPLICIT}, * this attribute specifies the names of beans containing methods annotated with {@link Tool} that should be used by this AI Service. */ String[] tools() default {}; -} \ No newline at end of file +} diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java index 7e215c8e..e2a2bdfc 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java @@ -4,6 +4,7 @@ import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.moderation.ModerationModel; import dev.langchain4j.rag.RetrievalAugmentor; import dev.langchain4j.rag.content.retriever.ContentRetriever; import dev.langchain4j.service.AiServices; @@ -22,6 +23,7 @@ class AiServiceFactory implements FactoryBean { private ChatMemoryProvider chatMemoryProvider; private ContentRetriever contentRetriever; private RetrievalAugmentor retrievalAugmentor; + private ModerationModel moderationModel; private List tools; public AiServiceFactory(Class aiServiceClass) { @@ -52,6 +54,10 @@ public void setRetrievalAugmentor(RetrievalAugmentor retrievalAugmentor) { this.retrievalAugmentor = retrievalAugmentor; } + public void setModerationModel(ModerationModel moderationModel) { + this.moderationModel = moderationModel; + } + public void setTools(List tools) { this.tools = tools; } @@ -83,6 +89,10 @@ public Object getObject() { builder = builder.contentRetriever(contentRetriever); } + if (moderationModel != null) { + builder = builder.moderationModel(moderationModel); + } + if (!isNullOrEmpty(tools)) { builder = builder.tools(tools); } diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java index 575769e8..90567db3 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java @@ -6,6 +6,7 @@ import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.moderation.ModerationModel; import dev.langchain4j.rag.RetrievalAugmentor; import dev.langchain4j.rag.content.retriever.ContentRetriever; import org.springframework.beans.MutablePropertyValues; @@ -43,6 +44,7 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { String[] chatMemoryProviders = beanFactory.getBeanNamesForType(ChatMemoryProvider.class); String[] contentRetrievers = beanFactory.getBeanNamesForType(ContentRetriever.class); String[] retrievalAugmentors = beanFactory.getBeanNamesForType(RetrievalAugmentor.class); + String[] moderationModels = beanFactory.getBeanNamesForType(ModerationModel.class); Set tools = new HashSet<>(); for (String beanName : beanFactory.getBeanDefinitionNames()) { @@ -129,6 +131,16 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { propertyValues ); + addBeanReference( + ModerationModel.class, + aiServiceAnnotation, + aiServiceAnnotation.moderationModel(), + moderationModels, + "moderationModel", + "moderationModel", + propertyValues + ); + if (aiServiceAnnotation.wiringMode() == EXPLICIT) { propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools()))); } else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) { diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withModerationModel/AiServiceWithModerationModel.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withModerationModel/AiServiceWithModerationModel.java new file mode 100644 index 00000000..c68abcc9 --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withModerationModel/AiServiceWithModerationModel.java @@ -0,0 +1,12 @@ +package dev.langchain4j.service.spring.mode.automatic.withModerationModel; + +import dev.langchain4j.service.Moderate; +import dev.langchain4j.service.spring.AiService; + +@AiService +interface AiServiceWithModerationModel { + + @Moderate + String chat(String userMessage); + +} diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withModerationModel/AiServiceWithModerationModelApplication.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withModerationModel/AiServiceWithModerationModelApplication.java new file mode 100644 index 00000000..3343c073 --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withModerationModel/AiServiceWithModerationModelApplication.java @@ -0,0 +1,34 @@ +package dev.langchain4j.service.spring.mode.automatic.withModerationModel; + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.moderation.Moderation; +import dev.langchain4j.model.moderation.ModerationModel; +import dev.langchain4j.model.output.Response; +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.Bean; + +import java.util.List; + +@SpringBootApplication +class AiServiceWithModerationModelApplication { + + @Bean + ModerationModel moderationModel() { + return new ModerationModel() { + @Override + public Response moderate(String s) { + return Response.from(Moderation.flagged("Flagged")); + } + + @Override + public Response moderate(List list) { + return Response.from(Moderation.flagged("Flagged")); + } + }; + } + + public static void main(String[] args) { + SpringApplication.run(AiServiceWithModerationModelApplication.class, args); + } +} diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withModerationModel/AiServiceWithModerationModelIT.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withModerationModel/AiServiceWithModerationModelIT.java new file mode 100644 index 00000000..90d6a0d8 --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withModerationModel/AiServiceWithModerationModelIT.java @@ -0,0 +1,38 @@ +package dev.langchain4j.service.spring.mode.automatic.withModerationModel; + +import dev.langchain4j.service.ModerationException; +import dev.langchain4j.service.spring.AiServicesAutoConfig; +import org.junit.jupiter.api.Test; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static dev.langchain4j.service.spring.mode.ApiKeys.OPENAI_API_KEY; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class AiServiceWithModerationModelIT { + + ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(AiServicesAutoConfig.class)); + + @Test + void should_create_AI_service_with_moderation_model() { + contextRunner + .withPropertyValues( + "langchain4j.open-ai.chat-model.api-key=" + OPENAI_API_KEY, + "langchain4j.open-ai.chat-model.max-tokens=20", + "langchain4j.open-ai.chat-model.temperature=0.0" + ) + .withUserConfiguration(AiServiceWithModerationModelApplication.class) + .run(context -> { + + // given + AiServiceWithModerationModel aiService = context.getBean(AiServiceWithModerationModel.class); + + // when & then + assertThatThrownBy(() -> aiService.chat("I'm violating content policy")) + .isInstanceOf(ModerationException.class) + .hasMessageContaining("Flagged"); + + }); + } +}