diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 68aacd2cf..3ac0f7979 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -8,14 +8,12 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; import lombok.NonNull; import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; -import org.opensearch.action.ActionFuture; import org.opensearch.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; @@ -109,29 +107,6 @@ public void inferenceSentences( }, listener::onFailure)); } - /** - * Abstraction to call predict function of api of MLClient with provided targetResponseFilters. It uses the - * custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent - * using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of - * inputText. We are not making this function generic enough to take any function or TaskType as currently we need - * to run only TextEmbedding tasks only. Please note this method is a blocking method, use this only when the processing - * needs block waiting for response, otherwise please use {@link #inferenceSentences(String, List, ActionListener)} - * instead. - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen. - * @return {@link List} of {@link List} of {@link String} represents the text embedding vector result. - * @throws ExecutionException If the underlying task failed, this exception will be thrown in the future.get(). - * @throws InterruptedException If the thread is interrupted, this will be thrown. - */ - public List> inferenceSentences(@NonNull final String modelId, @NonNull final List inputText) - throws ExecutionException, InterruptedException { - final MLInput mlInput = createMLInput(TARGET_RESPONSE_FILTERS, inputText); - final ActionFuture outputActionFuture = mlClient.predict(modelId, mlInput); - final List> vector = buildVectorFromResponse(outputActionFuture.get()); - log.debug("Inference Response for input sentence {} is : {} ", inputText, vector); - return vector; - } - private MLInput createMLInput(final List targetResponseFilters, List inputText) { final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 36271d83c..2d1841ea0 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -10,13 +10,14 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.concurrent.ExecutionException; +import java.util.function.BiConsumer; import java.util.function.Supplier; import java.util.stream.IntStream; import lombok.extern.log4j.Log4j2; import org.apache.commons.lang3.StringUtils; +import org.opensearch.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.index.mapper.MapperService; import org.opensearch.ingest.AbstractProcessor; @@ -80,17 +81,33 @@ private void validateEmbeddingConfiguration(Map fieldMap) { @Override public IngestDocument execute(IngestDocument ingestDocument) { - validateEmbeddingFieldsValue(ingestDocument); - Map knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); + return ingestDocument; + } + + /** + * This method will be invoked by PipelineService to make async inference and then delegate the handler to + * process the inference response or failure. + * @param ingestDocument {@link IngestDocument} which is the document passed to processor. + * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. + */ + @Override + public void execute(IngestDocument ingestDocument, BiConsumer handler) { + // When received a bulk indexing request, the pipeline will be executed in this method, (see + // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/action/bulk/TransportBulkAction.java#L226). + // Before the pipeline execution, the pipeline will be marked as resolved (means executed), + // and then this overriding method will be invoked when executing the text embedding processor. + // After the inference completes, the handler will invoke the doInternalExecute method again to run actual write operation. try { - List> vectors = mlCommonsClientAccessor.inferenceSentences(this.modelId, createInferenceList(knnMap)); - appendVectorFieldsToDocument(ingestDocument, knnMap, vectors); - } catch (ExecutionException | InterruptedException e) { - log.error("Text embedding processor failed with exception: ", e); - throw new RuntimeException("Text embedding processor failed with exception", e); + validateEmbeddingFieldsValue(ingestDocument); + Map knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); + mlCommonsClientAccessor.inferenceSentences(this.modelId, createInferenceList(knnMap), ActionListener.wrap(vectors -> { + appendVectorFieldsToDocument(ingestDocument, knnMap, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); })); + } catch (Exception e) { + handler.accept(null, e); } - log.debug("Text embedding completed, returning ingestDocument!"); - return ingestDocument; + } void appendVectorFieldsToDocument(IngestDocument ingestDocument, Map knnMap, List> vectors) { diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 91335d970..81523d8d7 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -5,25 +5,17 @@ package org.opensearch.neuralsearch.ml; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; -import lombok.SneakyThrows; - import org.junit.Before; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; -import org.opensearch.action.ActionFuture; import org.opensearch.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.input.MLInput; @@ -35,8 +27,6 @@ import org.opensearch.neuralsearch.constants.TestCommonConstants; import org.opensearch.test.OpenSearchTestCase; -import com.google.common.collect.ImmutableList; - public class MLCommonsClientAccessorTests extends OpenSearchTestCase { @Mock @@ -124,26 +114,6 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() { Mockito.verifyNoMoreInteractions(resultListener); } - @SneakyThrows - public void test_blockingInferenceSentences() { - ActionFuture actionFuture = mock(ActionFuture.class); - when(client.predict(anyString(), any(MLInput.class))).thenReturn(actionFuture); - List tensorsList = new ArrayList<>(); - - List tensors = new ArrayList<>(); - ModelTensor tensor = mock(ModelTensor.class); - when(tensor.getData()).thenReturn(TestCommonConstants.PREDICT_VECTOR_ARRAY); - tensors.add(tensor); - - ModelTensors modelTensors = new ModelTensors(tensors); - tensorsList.add(modelTensors); - - ModelTensorOutput mlOutput = new ModelTensorOutput(tensorsList); - when(actionFuture.get()).thenReturn(mlOutput); - List> result = accessor.inferenceSentences("modelId", ImmutableList.of("mock")); - assertEquals(TestCommonConstants.PREDICT_VECTOR_ARRAY[0], result.get(0).get(0)); - } - private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index 71a7ac8bf..1e8312401 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -7,7 +7,6 @@ import java.nio.file.Files; import java.nio.file.Path; -import java.util.Locale; import java.util.Map; import org.apache.http.HttpHeaders; @@ -36,13 +35,8 @@ public void testTextEmbeddingProcessor() throws Exception { } private String uploadTextEmbeddingModel() throws Exception { - String currentPath = System.getProperty("user.dir"); - Path testClusterPath = Path.of(currentPath).getParent().resolveSibling("testclusters/integTest-0/data"); - Path path = Path.of(testClusterPath + "/all-MiniLM-L6-v2.zip"); - Files.copy(Path.of(classLoader.getResource("model/all-MiniLM-L6-v2.zip").toURI()), path); String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI())); - String request = String.format(Locale.getDefault(), requestBody, path); - return uploadModel(request); + return uploadModel(requestBody); } private void createTextEmbeddingIndex() throws Exception { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 15af12157..ec28ea285 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -7,14 +7,14 @@ import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.function.BiConsumer; import java.util.function.Supplier; import org.junit.Before; @@ -22,6 +22,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchParseException; +import org.opensearch.action.ActionListener; import org.opensearch.common.settings.Settings; import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; @@ -59,7 +60,6 @@ private TextEmbeddingProcessor createInstance(List> vector) throws E config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); - when(mlCommonsClientAccessor.inferenceSentences(anyString(), anyList())).thenReturn(vector); return processor; } @@ -95,8 +95,17 @@ public void testExecute_successful() throws Exception { sourceAndMetadata.put("key2", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - IngestDocument document = processor.execute(ingestDocument); - assert document.getSourceAndMetadata().containsKey("key1"); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); } public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeException() throws Exception { @@ -112,12 +121,10 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); - when(accessor.inferenceSentences(anyString(), anyList())).thenThrow(new InterruptedException()); - try { - processor.execute(ingestDocument); - } catch (RuntimeException e) { - assertEquals("Text embedding processor failed with exception", e.getMessage()); - } + doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(RuntimeException.class)); } public void testExecute_withListTypeInput_successful() throws Exception { @@ -128,8 +135,17 @@ public void testExecute_withListTypeInput_successful() throws Exception { sourceAndMetadata.put("key2", list2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(6)); - IngestDocument document = processor.execute(ingestDocument); - assert document.getSourceAndMetadata().containsKey("key1"); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); } public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentException() throws Exception { @@ -137,11 +153,10 @@ public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentExcep sourceAndMetadata.put("key1", " "); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - try { - processor.execute(ingestDocument); - } catch (IllegalArgumentException e) { - assertEquals("field [key1] has empty string value, can not process it", e.getMessage()); - } + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() throws Exception { @@ -150,11 +165,10 @@ public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() sourceAndMetadata.put("key1", list1); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - try { - processor.execute(ingestDocument); - } catch (IllegalArgumentException e) { - assertEquals("list type field [key1] has empty string, can not process it", e.getMessage()); - } + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testExecute_listHasNonStringValue_throwIllegalArgumentException() throws Exception { @@ -163,11 +177,9 @@ public void testExecute_listHasNonStringValue_throwIllegalArgumentException() th sourceAndMetadata.put("key2", list2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - try { - processor.execute(ingestDocument); - } catch (IllegalArgumentException e) { - assertEquals("list type field [key2] has non string value, can not process it", e.getMessage()); - } + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testExecute_listHasNull_throwIllegalArgumentException() throws Exception { @@ -179,11 +191,9 @@ public void testExecute_listHasNull_throwIllegalArgumentException() throws Excep sourceAndMetadata.put("key2", list); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - try { - processor.execute(ingestDocument); - } catch (IllegalArgumentException e) { - assertEquals("list type field [key2] has null, can not process it", e.getMessage()); - } + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testExecute_withMapTypeInput_successful() throws Exception { @@ -194,8 +204,18 @@ public void testExecute_withMapTypeInput_successful() throws Exception { sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - IngestDocument document = processor.execute(ingestDocument); - assert document.getSourceAndMetadata().containsKey("key1"); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() throws Exception { @@ -206,11 +226,9 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() thr sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - try { - processor.execute(ingestDocument); - } catch (IllegalArgumentException e) { - assertEquals("map type field [key2] has non-string type, can not process it", e.getMessage()); - } + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() throws Exception { @@ -221,11 +239,9 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() t sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - try { - processor.execute(ingestDocument); - } catch (IllegalArgumentException e) { - assertEquals("map type field [key2] has empty string, can not process it", e.getMessage()); - } + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() throws Exception { @@ -235,13 +251,27 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() throw sourceAndMetadata.put("key2", ret); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - try { - processor.execute(ingestDocument); - } catch (IllegalArgumentException e) { - assertEquals("map type field [key2] reached max depth limit, can not process it", e.getMessage()); - return; - } - fail("Shouldn't be here, expected exception!"); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + public void testExecute_MLClientAccessorThrowFail_handlerFailure() throws Exception { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onFailure(new IllegalArgumentException("illegal argument")); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } private Map createMaxDepthLimitExceedMap(Supplier maxDepthSupplier) { @@ -267,17 +297,21 @@ public void testExecute_hybridTypeInput_successful() throws Exception { assert document.getSourceAndMetadata().containsKey("key2"); } - public void testExecute_simpleTypeInputWithNonStringValue_throwIllegalArgumentException() throws Exception { + public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentException() throws Exception { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key1", 100); sourceAndMetadata.put("key2", 100.232D); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onFailure(new IllegalArgumentException("illegal argument")); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - try { - processor.execute(ingestDocument); - } catch (IllegalArgumentException e) { - assertEquals("field [key1] is neither string nor nested type, can not process it", e.getMessage()); - } + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testGetType_successful() throws Exception { diff --git a/src/test/resources/model/all-MiniLM-L6-v2.zip b/src/test/resources/model/all-MiniLM-L6-v2.zip index 4afc4ceda..90d46c2c1 100644 Binary files a/src/test/resources/model/all-MiniLM-L6-v2.zip and b/src/test/resources/model/all-MiniLM-L6-v2.zip differ diff --git a/src/test/resources/processor/UploadModelRequestBody.json b/src/test/resources/processor/UploadModelRequestBody.json index 95b26451c..d56a61f3e 100644 --- a/src/test/resources/processor/UploadModelRequestBody.json +++ b/src/test/resources/processor/UploadModelRequestBody.json @@ -9,5 +9,5 @@ "framework_type": "sentence_transformers", "all_config": "{\"architectures\":[\"BertModel\"],\"max_position_embeddings\":512,\"model_type\":\"bert\",\"num_attention_heads\":12,\"num_hidden_layers\":6}" }, - "url": "file://%s" + "url": "https://github.com/opensearch-project/ml-commons/blob/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/all-MiniLM-L6-v2_torchscript_sentence-transformer.zip?raw=true" }