Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Change text embedding processor to async mode for better isolation #35

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.action.ActionFuture;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
Expand Down Expand Up @@ -109,29 +107,6 @@ public void inferenceSentences(
}, listener::onFailure));
}

/**
* Abstraction to call predict function of api of MLClient with provided targetResponseFilters. It uses the
* custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent
* using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of
* inputText. We are not making this function generic enough to take any function or TaskType as currently we need
* to run only TextEmbedding tasks only. Please note this method is a blocking method, use this only when the processing
* needs block waiting for response, otherwise please use {@link #inferenceSentences(String, List, ActionListener)}
* instead.
* @param modelId {@link String}
* @param inputText {@link List} of {@link String} on which inference needs to happen.
* @return {@link List} of {@link List} of {@link String} represents the text embedding vector result.
* @throws ExecutionException If the underlying task failed, this exception will be thrown in the future.get().
* @throws InterruptedException If the thread is interrupted, this will be thrown.
*/
public List<List<Float>> inferenceSentences(@NonNull final String modelId, @NonNull final List<String> inputText)
throws ExecutionException, InterruptedException {
final MLInput mlInput = createMLInput(TARGET_RESPONSE_FILTERS, inputText);
final ActionFuture<MLOutput> outputActionFuture = mlClient.predict(modelId, mlInput);
final List<List<Float>> vector = buildVectorFromResponse(outputActionFuture.get());
log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
return vector;
}

private MLInput createMLInput(final List<String> targetResponseFilters, List<String> inputText) {
final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.ingest.AbstractProcessor;
Expand Down Expand Up @@ -80,17 +81,33 @@ private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {

@Override
public IngestDocument execute(IngestDocument ingestDocument) {
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument);
return ingestDocument;
}

/**
* This method will be invoked by PipelineService to make async inference and then delegate the handler to
* process the inference response or failure.
* @param ingestDocument {@link IngestDocument} which is the document passed to processor.
* @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done.
*/
@Override
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
// When received a bulk indexing request, the pipeline will be executed in this method, (see
// https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/action/bulk/TransportBulkAction.java#L226).
// Before the pipeline execution, the pipeline will be marked as resolved (means executed),
// and then this overriding method will be invoked when executing the text embedding processor.
// After the inference completes, the handler will invoke the doInternalExecute method again to run actual write operation.
try {
List<List<Float>> vectors = mlCommonsClientAccessor.inferenceSentences(this.modelId, createInferenceList(knnMap));
appendVectorFieldsToDocument(ingestDocument, knnMap, vectors);
} catch (ExecutionException | InterruptedException e) {
log.error("Text embedding processor failed with exception: ", e);
throw new RuntimeException("Text embedding processor failed with exception", e);
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument);
mlCommonsClientAccessor.inferenceSentences(this.modelId, createInferenceList(knnMap), ActionListener.wrap(vectors -> {
appendVectorFieldsToDocument(ingestDocument, knnMap, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
} catch (Exception e) {
handler.accept(null, e);
}
log.debug("Text embedding completed, returning ingestDocument!");
return ingestDocument;

}

void appendVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> knnMap, List<List<Float>> vectors) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,17 @@

package org.opensearch.neuralsearch.ml;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import lombok.SneakyThrows;

import org.junit.Before;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionFuture;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.input.MLInput;
Expand All @@ -35,8 +27,6 @@
import org.opensearch.neuralsearch.constants.TestCommonConstants;
import org.opensearch.test.OpenSearchTestCase;

import com.google.common.collect.ImmutableList;

public class MLCommonsClientAccessorTests extends OpenSearchTestCase {

@Mock
Expand Down Expand Up @@ -124,26 +114,6 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() {
Mockito.verifyNoMoreInteractions(resultListener);
}

@SneakyThrows
public void test_blockingInferenceSentences() {
ActionFuture actionFuture = mock(ActionFuture.class);
when(client.predict(anyString(), any(MLInput.class))).thenReturn(actionFuture);
List<ModelTensors> tensorsList = new ArrayList<>();

List<ModelTensor> tensors = new ArrayList<>();
ModelTensor tensor = mock(ModelTensor.class);
when(tensor.getData()).thenReturn(TestCommonConstants.PREDICT_VECTOR_ARRAY);
tensors.add(tensor);

ModelTensors modelTensors = new ModelTensors(tensors);
tensorsList.add(modelTensors);

ModelTensorOutput mlOutput = new ModelTensorOutput(tensorsList);
when(actionFuture.get()).thenReturn(mlOutput);
List<List<Float>> result = accessor.inferenceSentences("modelId", ImmutableList.of("mock"));
assertEquals(TestCommonConstants.PREDICT_VECTOR_ARRAY[0], result.get(0).get(0));
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Locale;
import java.util.Map;

import org.apache.http.HttpHeaders;
Expand Down Expand Up @@ -36,13 +35,8 @@ public void testTextEmbeddingProcessor() throws Exception {
}

private String uploadTextEmbeddingModel() throws Exception {
String currentPath = System.getProperty("user.dir");
Path testClusterPath = Path.of(currentPath).getParent().resolveSibling("testclusters/integTest-0/data");
Path path = Path.of(testClusterPath + "/all-MiniLM-L6-v2.zip");
Files.copy(Path.of(classLoader.getResource("model/all-MiniLM-L6-v2.zip").toURI()), path);
String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI()));
String request = String.format(Locale.getDefault(), requestBody, path);
return uploadModel(request);
return uploadModel(requestBody);
}

private void createTextEmbeddingIndex() throws Exception {
Expand Down
Loading