From fd8ed4b7fd410b85b90f914a0f2a2dcb9b2befd6 Mon Sep 17 00:00:00 2001 From: LangChain4j Date: Sun, 15 Sep 2024 10:11:07 +0200 Subject: [PATCH] Closes https://github.com/langchain4j/langchain4j/issues/886 --- .../reactor/TokenStreamToFluxAdapter.java | 33 ++++++++++++++++ ...angchain4j.spi.services.TokenStreamAdapter | 1 + .../reactor/AiServiceWithFluxTest.java | 39 +++++++++++++++++++ .../reactor/TokenStreamToFluxAdapterTest.java | 39 +++++++++++++++++++ .../langchain4j/service/spring/AiService.java | 2 - pom.xml | 8 ++++ 6 files changed, 120 insertions(+), 2 deletions(-) create mode 100644 langchain4j-reactor/src/main/java/dev/langchain4j/reactor/TokenStreamToFluxAdapter.java create mode 100644 langchain4j-reactor/src/main/resources/META-INF/services/dev.langchain4j.spi.services.TokenStreamAdapter create mode 100644 langchain4j-reactor/src/test/java/dev/langchain4j/reactor/AiServiceWithFluxTest.java create mode 100644 langchain4j-reactor/src/test/java/dev/langchain4j/reactor/TokenStreamToFluxAdapterTest.java diff --git a/langchain4j-reactor/src/main/java/dev/langchain4j/reactor/TokenStreamToFluxAdapter.java b/langchain4j-reactor/src/main/java/dev/langchain4j/reactor/TokenStreamToFluxAdapter.java new file mode 100644 index 00000000..ad4d301d --- /dev/null +++ b/langchain4j-reactor/src/main/java/dev/langchain4j/reactor/TokenStreamToFluxAdapter.java @@ -0,0 +1,33 @@ +package dev.langchain4j.reactor; + +import dev.langchain4j.service.TokenStream; +import dev.langchain4j.spi.services.TokenStreamAdapter; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Sinks; + +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; + +public class TokenStreamToFluxAdapter implements TokenStreamAdapter { + + @Override + public boolean canAdaptTokenStreamTo(Type type) { + if (type instanceof ParameterizedType parameterizedType) { + if (parameterizedType.getRawType() == Flux.class) { + Type[] typeArguments = parameterizedType.getActualTypeArguments(); + return typeArguments.length == 1 && typeArguments[0] == String.class; + } + } + return false; + } + + @Override + public Object adapt(TokenStream tokenStream) { + Sinks.Many sink = Sinks.many().unicast().onBackpressureBuffer(); + tokenStream.onNext(sink::tryEmitNext) + .onComplete(aiMessageResponse -> sink.tryEmitComplete()) + .onError(sink::tryEmitError) + .start(); + return sink.asFlux(); + } +} diff --git a/langchain4j-reactor/src/main/resources/META-INF/services/dev.langchain4j.spi.services.TokenStreamAdapter b/langchain4j-reactor/src/main/resources/META-INF/services/dev.langchain4j.spi.services.TokenStreamAdapter new file mode 100644 index 00000000..c4898c47 --- /dev/null +++ b/langchain4j-reactor/src/main/resources/META-INF/services/dev.langchain4j.spi.services.TokenStreamAdapter @@ -0,0 +1 @@ +dev.langchain4j.reactor.TokenStreamToFluxAdapter \ No newline at end of file diff --git a/langchain4j-reactor/src/test/java/dev/langchain4j/reactor/AiServiceWithFluxTest.java b/langchain4j-reactor/src/test/java/dev/langchain4j/reactor/AiServiceWithFluxTest.java new file mode 100644 index 00000000..01f2c35c --- /dev/null +++ b/langchain4j-reactor/src/test/java/dev/langchain4j/reactor/AiServiceWithFluxTest.java @@ -0,0 +1,39 @@ +package dev.langchain4j.reactor; + +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.chat.mock.StreamingChatModelMock; +import dev.langchain4j.service.AiServices; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; + +import java.util.List; + +public class AiServiceWithFluxTest { + + interface Assistant { + + Flux stream(String userMessage); + } + + @Test + void should_stream() { + + // given + List tokens = List.of("The", " capital", " of", " Germany", " is", " Berlin", "."); + + StreamingChatLanguageModel model = StreamingChatModelMock.thatAlwaysStreams(tokens); + + Assistant assistant = AiServices.builder(Assistant.class) + .streamingChatLanguageModel(model) + .build(); + + // when + Flux flux = assistant.stream("What is the capital of Germany?"); + + // then + StepVerifier.create(flux) + .expectNextSequence(tokens) + .verifyComplete(); + } +} diff --git a/langchain4j-reactor/src/test/java/dev/langchain4j/reactor/TokenStreamToFluxAdapterTest.java b/langchain4j-reactor/src/test/java/dev/langchain4j/reactor/TokenStreamToFluxAdapterTest.java new file mode 100644 index 00000000..aea70aef --- /dev/null +++ b/langchain4j-reactor/src/test/java/dev/langchain4j/reactor/TokenStreamToFluxAdapterTest.java @@ -0,0 +1,39 @@ +package dev.langchain4j.reactor; + +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; + +import java.lang.reflect.Type; + +import static org.assertj.core.api.Assertions.assertThat; + +class TokenStreamToFluxAdapterTest { + + interface Assistant { + + Flux fluxOfString(); + + Flux flux(); + + Flux fluxOfObject(); + } + + @Test + void test_canAdapt() { + + TokenStreamToFluxAdapter adapter = new TokenStreamToFluxAdapter(); + + assertThat(adapter.canAdaptTokenStreamTo(getReturnTypeOfMethod("fluxOfString"))).isTrue(); + + assertThat(adapter.canAdaptTokenStreamTo(getReturnTypeOfMethod("flux"))).isFalse(); + assertThat(adapter.canAdaptTokenStreamTo(getReturnTypeOfMethod("fluxOfObject"))).isFalse(); + } + + private static Type getReturnTypeOfMethod(String methodName) { + try { + return Assistant.class.getDeclaredMethod(methodName).getGenericReturnType(); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } +} \ No newline at end of file diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiService.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiService.java index c2186988..169c3f7e 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiService.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiService.java @@ -96,6 +96,4 @@ * this attribute specifies the names of beans containing methods annotated with {@link Tool} that should be used by this AI Service. */ String[] tools() default {}; - - // TODO support Flux return type for AI Service method(s) (for streaming) } \ No newline at end of file diff --git a/pom.xml b/pom.xml index ce9d5df5..22aec9c9 100644 --- a/pom.xml +++ b/pom.xml @@ -27,6 +27,8 @@ langchain4j-redis-spring-boot-starter langchain4j-qianfan-spring-boot-starter langchain4j-milvus-spring-boot-starter + + langchain4j-reactor @@ -56,6 +58,12 @@ ${spring.boot.version} + + org.springframework.boot + spring-boot-starter-webflux + ${spring.boot.version} + + org.springframework.boot spring-boot-autoconfigure-processor