From d7319984fe7c55a685ae7b618927a73e38a40fef Mon Sep 17 00:00:00 2001
From: Andrea Di Maio
Date: Mon, 20 Jan 2025 18:28:27 +0100
Subject: [PATCH] Handle API request limits in WatsonxEmbeddingModel

The watsonx.ai embedding API allows a maximum of 1000 elements per
request. WatsonxEmbeddingModel.embedAll now walks the input in windows
of MAX_SIZE (1000) text segments, sends one embedding request per
window, and appends the returned embeddings to a single result list in
request order.

AiEmbeddingTest gains test_high_embedding_text_segments, which embeds
2001 segments against three scenario-chained WireMock stubs whose
expected request bodies hold 1000, 1000 and 1 inputs respectively.
---
 .../watsonx/deployment/AiEmbeddingTest.java   | 48 +++++++++++++++++++
 .../watsonx/WatsonxEmbeddingModel.java        | 41 +++++++++-------
 2 files changed, 72 insertions(+), 17 deletions(-)

diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiEmbeddingTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiEmbeddingTest.java
index 52a2229b5..974572b41 100644
--- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiEmbeddingTest.java
+++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiEmbeddingTest.java
@@ -5,6 +5,9 @@
 
 import java.util.Date;
 import java.util.List;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 import jakarta.inject.Inject;
 
@@ -13,6 +16,8 @@
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
 
+import com.github.tomakehurst.wiremock.stubbing.Scenario;
+
 import dev.langchain4j.data.embedding.Embedding;
 import dev.langchain4j.data.segment.TextSegment;
 import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -134,6 +139,49 @@ void test_embed_list_of_three_textsegment() throws Exception {
         assertEquals(vector, response.content().get(2).vectorAsList());
     }
 
+    @Test
+    public void test_high_embedding_text_segments() throws Exception {
+        mockServers.mockIAMBuilder(200)
+                .response(WireMockUtil.BEARER_TOKEN, new Date())
+                .build();
+
+        var RESPONSE = """
+                {
+                    "model_id": "%s",
+                    "results": [],
+                    "created_at": "2024-02-21T17:32:28Z",
+                    "input_token_count": 10
+                }
+                """;
+
+        Function<List<String>, EmbeddingRequest> createRequest = (List<String> elementsToEmbed) -> {
+            return new EmbeddingRequest(WireMockUtil.DEFAULT_EMBEDDING_MODEL, null, WireMockUtil.PROJECT_ID, elementsToEmbed,
+                    null);
+        };
+
+        var list = IntStream.rangeClosed(1, 2001).mapToObj(String::valueOf).collect(Collectors.toList());
+
+        mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200)
+                .scenario(Scenario.STARTED, "SECOND_CALL")
+                .body(mapper.writeValueAsString(createRequest.apply(list.subList(0, 1000))))
+                .response(RESPONSE.formatted(WireMockUtil.DEFAULT_EMBEDDING_MODEL))
+                .build();
+
+        mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200)
+                .scenario("SECOND_CALL", "THIRD_CALL")
+                .body(mapper.writeValueAsString(createRequest.apply(list.subList(1000, 2000))))
+                .response(RESPONSE.formatted(WireMockUtil.DEFAULT_EMBEDDING_MODEL))
+                .build();
+
+        mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_EMBEDDING_API, 200)
+                .scenario("THIRD_CALL", Scenario.STARTED)
+                .body(mapper.writeValueAsString(createRequest.apply(list.subList(2000, 2001))))
+                .response(RESPONSE.formatted(WireMockUtil.DEFAULT_EMBEDDING_MODEL))
+                .build();
+
+        embeddingModel.embedAll(list.stream().map(TextSegment::textSegment).toList());
+    }
+
     private List<Float> mockEmbeddingServer(String input) throws Exception {
         mockServers.mockIAMBuilder(200)
                 .response(WireMockUtil.BEARER_TOKEN, new Date())
diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxEmbeddingModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxEmbeddingModel.java
index bf126cd47..b5a0b20cc 100644
--- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxEmbeddingModel.java
+++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxEmbeddingModel.java
@@ -2,6 +2,7 @@
 
 import static io.quarkiverse.langchain4j.watsonx.WatsonxUtils.retryOn;
 
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.Callable;
@@ -21,6 +22,7 @@
 public class WatsonxEmbeddingModel extends Watsonx implements EmbeddingModel, TokenCountEstimator {
 
     private final EmbeddingParameters parameters;
+    private static final int MAX_SIZE = 1000;
 
     public WatsonxEmbeddingModel(Builder builder) {
         super(builder);
@@ -37,24 +39,29 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
         if (Objects.isNull(textSegments) || textSegments.isEmpty())
             return Response.from(List.of());
 
-        var inputs = textSegments.stream()
-                .map(TextSegment::text)
-                .collect(Collectors.toList());
-
-        EmbeddingRequest request = new EmbeddingRequest(modelId, spaceId, projectId, inputs, parameters);
-        EmbeddingResponse result = retryOn(new Callable<EmbeddingResponse>() {
-            @Override
-            public EmbeddingResponse call() throws Exception {
-                return client.embeddings(request, version);
-            }
-        });
+        List<Embedding> result = new ArrayList<>();
+
+        // Watsonx.ai embedding API allows a maximum of 1000 elements per request.
+        for (int fromIndex = 0; fromIndex < textSegments.size(); fromIndex += MAX_SIZE) {
+            int toIndex = Math.min(fromIndex + MAX_SIZE, textSegments.size());
+            List<String> subList = textSegments.subList(fromIndex, toIndex).stream()
+                    .map(TextSegment::text)
+                    .collect(Collectors.toList());
+
+            EmbeddingRequest request = new EmbeddingRequest(modelId, spaceId, projectId, subList, parameters);
+            EmbeddingResponse embeddingResponse = retryOn(new Callable<EmbeddingResponse>() {
+                @Override
+                public EmbeddingResponse call() throws Exception {
+                    return client.embeddings(request, version);
+                }
+            });
+            result.addAll(embeddingResponse.results().stream()
+                    .map(Result::embedding)
+                    .map(Embedding::from)
+                    .toList());
+        }
 
-        return Response.from(
-                result.results()
-                        .stream()
-                        .map(Result::embedding)
-                        .map(Embedding::from)
-                        .collect(Collectors.toList()));
+        return Response.from(result);
     }
 
     @Override