From 0711f17d76bb05a47ae72e95bb3cff91a0458f82 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Thu, 12 Oct 2023 17:13:14 -0700 Subject: [PATCH] fix multiple docs support (#1516) Signed-off-by: Yaliang Wu --- .../engine/algorithms/remote/RemoteConnectorExecutor.java | 6 +++++- .../algorithms/remote/HttpJsonConnectorExecutorTest.java | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index cf99934779..c26c79b452 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -37,7 +37,11 @@ default ModelTensorOutput executePredict(MLInput mlInput) { List textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size()); List tempTensorOutputs = new ArrayList<>(); preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tempTensorOutputs); - processedDocs += Math.max(tempTensorOutputs.size(), 1); + int tensorCount = 0; + if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) { + tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size(); + } + processedDocs += Math.max(tensorCount, 1); tensorOutputs.addAll(tempTensorOutputs); } } else { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index d8cdb4e9d5..6666628d19 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -213,7 +213,8 @@ public void executePredict_TextDocsInput() throws IOException { when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); - Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); Assert .assertArrayEquals(