Skip to content

Commit

Permalink
Change text embedding processor to async mode for better isolation (#27)
Browse files Browse the repository at this point in the history
* Change text embedding processor to async mode

Signed-off-by: Zan Niu <[email protected]>

* Address review comments

Signed-off-by: Zan Niu <[email protected]>

Signed-off-by: Zan Niu <[email protected]>
  • Loading branch information
zane-neo authored Oct 27, 2022
1 parent 38de48d commit d538ad1
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.action.ActionFuture;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
Expand Down Expand Up @@ -109,29 +107,6 @@ public void inferenceSentences(
}, listener::onFailure));
}

/**
 * Abstraction to call the predict API of MLClient with the provided targetResponseFilters. It uses the
 * custom model identified by modelId and runs the {@link FunctionName#TEXT_EMBEDDING} task. The result is
 * a {@link List} of {@link List} of {@link Float} — one embedding vector per entry of inputText, in the
 * same order. This function is deliberately not generic over function/TaskType, as currently only
 * TextEmbedding tasks are needed. Please note this is a blocking method: it waits on
 * {@code ActionFuture#get()} until inference completes. Use it only when the caller must block for the
 * response; otherwise use {@link #inferenceSentences(String, List, ActionListener)} instead.
 * @param modelId {@link String} id of the text-embedding model to invoke.
 * @param inputText {@link List} of {@link String} on which inference needs to happen.
 * @return {@link List} of {@link List} of {@link Float} representing the text embedding vectors.
 * @throws ExecutionException If the underlying task failed, this exception will be thrown by future.get().
 * @throws InterruptedException If the calling thread is interrupted while waiting on the future.
 */
public List<List<Float>> inferenceSentences(@NonNull final String modelId, @NonNull final List<String> inputText)
throws ExecutionException, InterruptedException {
final MLInput mlInput = createMLInput(TARGET_RESPONSE_FILTERS, inputText);
// Blocking call: predict() returns an ActionFuture and get() waits for the inference result.
final ActionFuture<MLOutput> outputActionFuture = mlClient.predict(modelId, mlInput);
final List<List<Float>> vector = buildVectorFromResponse(outputActionFuture.get());
log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
return vector;
}

private MLInput createMLInput(final List<String> targetResponseFilters, List<String> inputText) {
final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.ingest.AbstractProcessor;
Expand Down Expand Up @@ -80,17 +81,33 @@ private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {

@Override
public IngestDocument execute(IngestDocument ingestDocument) {
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument);
return ingestDocument;
}

/**
* This method will be invoked by PipelineService to make async inference and then delegate the handler to
* process the inference response or failure.
* @param ingestDocument {@link IngestDocument} which is the document passed to processor.
* @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done.
*/
@Override
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
// When received a bulk indexing request, the pipeline will be executed in this method, (see
// https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/action/bulk/TransportBulkAction.java#L226).
// Before the pipeline execution, the pipeline will be marked as resolved (means executed),
// and then this overriding method will be invoked when executing the text embedding processor.
// After the inference completes, the handler will invoke the doInternalExecute method again to run actual write operation.
try {
List<List<Float>> vectors = mlCommonsClientAccessor.inferenceSentences(this.modelId, createInferenceList(knnMap));
appendVectorFieldsToDocument(ingestDocument, knnMap, vectors);
} catch (ExecutionException | InterruptedException e) {
log.error("Text embedding processor failed with exception: ", e);
throw new RuntimeException("Text embedding processor failed with exception", e);
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument);
mlCommonsClientAccessor.inferenceSentences(this.modelId, createInferenceList(knnMap), ActionListener.wrap(vectors -> {
appendVectorFieldsToDocument(ingestDocument, knnMap, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
} catch (Exception e) {
handler.accept(null, e);
}
log.debug("Text embedding completed, returning ingestDocument!");
return ingestDocument;

}

void appendVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> knnMap, List<List<Float>> vectors) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,17 @@

package org.opensearch.neuralsearch.ml;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import lombok.SneakyThrows;

import org.junit.Before;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionFuture;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.input.MLInput;
Expand All @@ -35,8 +27,6 @@
import org.opensearch.neuralsearch.constants.TestCommonConstants;
import org.opensearch.test.OpenSearchTestCase;

import com.google.common.collect.ImmutableList;

public class MLCommonsClientAccessorTests extends OpenSearchTestCase {

@Mock
Expand Down Expand Up @@ -124,26 +114,6 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() {
Mockito.verifyNoMoreInteractions(resultListener);
}

// Verifies the blocking inferenceSentences overload: stubs the ML client's predict() future
// and asserts the returned vector matches the stubbed tensor data.
@SneakyThrows
public void test_blockingInferenceSentences() {
// Stub predict(...) so it returns a future whose get() yields our canned ModelTensorOutput.
ActionFuture actionFuture = mock(ActionFuture.class);
when(client.predict(anyString(), any(MLInput.class))).thenReturn(actionFuture);
List<ModelTensors> tensorsList = new ArrayList<>();

// Build a single tensor carrying the expected embedding values.
List<ModelTensor> tensors = new ArrayList<>();
ModelTensor tensor = mock(ModelTensor.class);
when(tensor.getData()).thenReturn(TestCommonConstants.PREDICT_VECTOR_ARRAY);
tensors.add(tensor);

ModelTensors modelTensors = new ModelTensors(tensors);
tensorsList.add(modelTensors);

ModelTensorOutput mlOutput = new ModelTensorOutput(tensorsList);
when(actionFuture.get()).thenReturn(mlOutput);
// The accessor should unwrap the future and convert the tensor data to List<List<Float>>.
List<List<Float>> result = accessor.inferenceSentences("modelId", ImmutableList.of("mock"));
assertEquals(TestCommonConstants.PREDICT_VECTOR_ARRAY[0], result.get(0).get(0));
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Locale;
import java.util.Map;

import org.apache.http.HttpHeaders;
Expand Down Expand Up @@ -36,13 +35,8 @@ public void testTextEmbeddingProcessor() throws Exception {
}

private String uploadTextEmbeddingModel() throws Exception {
String currentPath = System.getProperty("user.dir");
Path testClusterPath = Path.of(currentPath).getParent().resolveSibling("testclusters/integTest-0/data");
Path path = Path.of(testClusterPath + "/all-MiniLM-L6-v2.zip");
Files.copy(Path.of(classLoader.getResource("model/all-MiniLM-L6-v2.zip").toURI()), path);
String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI()));
String request = String.format(Locale.getDefault(), requestBody, path);
return uploadModel(request);
return uploadModel(requestBody);
}

private void createTextEmbeddingIndex() throws Exception {
Expand Down
Loading

0 comments on commit d538ad1

Please sign in to comment.