Change text embedding processor to async mode for better isolation #27

Merged · 2 commits · Oct 27, 2022
@@ -8,14 +8,12 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
-import java.util.concurrent.ExecutionException;
 import java.util.stream.Collectors;

 import lombok.NonNull;
 import lombok.RequiredArgsConstructor;
 import lombok.extern.log4j.Log4j2;

-import org.opensearch.action.ActionFuture;
 import org.opensearch.action.ActionListener;
 import org.opensearch.ml.client.MachineLearningNodeClient;
 import org.opensearch.ml.common.FunctionName;
@@ -109,29 +107,6 @@ public void inferenceSentences(
         }, listener::onFailure));
     }

-    /**
-     * Abstraction to call the predict API of MLClient with the provided targetResponseFilters. It uses the
-     * custom model provided as modelId and runs the {@link FunctionName#TEXT_EMBEDDING} task. The result is
-     * returned as a {@link List} of {@link List} of {@link Float} in the order of inputText. We are not making
-     * this function generic enough to take any function or TaskType, as currently we only need to run
-     * TextEmbedding tasks. Please note this is a blocking method; use it only when the processing needs to
-     * block waiting for the response, otherwise please use
-     * {@link #inferenceSentences(String, List, ActionListener)} instead.
-     * @param modelId {@link String}
-     * @param inputText {@link List} of {@link String} on which inference needs to happen.
-     * @return {@link List} of {@link List} of {@link Float} representing the text embedding vector result.
-     * @throws ExecutionException If the underlying task failed, this exception is thrown by future.get().
-     * @throws InterruptedException If the thread is interrupted, this is thrown.
-     */
-    public List<List<Float>> inferenceSentences(@NonNull final String modelId, @NonNull final List<String> inputText)
-        throws ExecutionException, InterruptedException {
-        final MLInput mlInput = createMLInput(TARGET_RESPONSE_FILTERS, inputText);
-        final ActionFuture<MLOutput> outputActionFuture = mlClient.predict(modelId, mlInput);
-        final List<List<Float>> vector = buildVectorFromResponse(outputActionFuture.get());
-        log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
-        return vector;
-    }
-

     private MLInput createMLInput(final List<String> targetResponseFilters, List<String> inputText) {
         final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
         final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
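For callers of the accessor, the PR's change is visible as the switch below from blocking to listener-based invocation. A minimal sketch, assuming hypothetical caller-side helpers appendVectors and handleFailure (not from the PR):

// Before (removed in this PR): the calling thread blocks on future.get()
// inside the accessor until inference completes or fails.
List<List<Float>> vectors = accessor.inferenceSentences(modelId, sentences);
appendVectors(vectors);

// After: hand the continuation to an ActionListener; the call returns
// immediately and exactly one of the two callbacks fires later.
accessor.inferenceSentences(modelId, sentences, ActionListener.wrap(
    vectors -> appendVectors(vectors), // success path
    e -> handleFailure(e)              // failure path
));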
@@ -10,13 +10,14 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
-import java.util.concurrent.ExecutionException;
 import java.util.function.BiConsumer;
 import java.util.function.Supplier;
 import java.util.stream.IntStream;

 import lombok.extern.log4j.Log4j2;

+import org.apache.commons.lang3.StringUtils;
+import org.opensearch.action.ActionListener;
 import org.opensearch.env.Environment;
 import org.opensearch.index.mapper.MapperService;
 import org.opensearch.ingest.AbstractProcessor;
@@ -80,17 +81,33 @@ private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {

     @Override
     public IngestDocument execute(IngestDocument ingestDocument) {
-        validateEmbeddingFieldsValue(ingestDocument);
-        Map<String, Object> knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument);
         return ingestDocument;
     }

+    /**
+     * This method is invoked by the PipelineService to run inference asynchronously and then delegate to the
+     * handler to process the inference response or failure.
+     * @param ingestDocument {@link IngestDocument} which is the document passed to the processor.
+     * @param handler {@link BiConsumer} which is the handler invoked after the inference task is done.
+     */
+    @Override
Review thread:

jmazanec15 (Member): In the future, could we refactor this method to batch calls to the model so we can improve throughput?

navneet1v (Collaborator): @jmazanec15 what kind of batching are you suggesting here? Is it batching the writes or batching the inference calls?

zane-neo (Collaborator, author) on Oct 26, 2022: It's doable, but there are two things we need to consider.

  1. Effort: changing to bulk requires more work; we need to choose a proper threshold (either a batch size or a linger time) and decide whether to expose these settings to the user.
  2. Benefit: in a cluster, every pair of nodes keeps a TCP connection alive, so there is little connection overhead, and network time is relatively low compared with CPU time. Based on the performance testing, the bottleneck is the CPU rather than network I/O, so switching to batching may not increase performance dramatically.

jmazanec15 (Member): @zane-neo That makes sense. It might make sense to create a proof-of-concept test in the future to see what the benefit would be.

@navneet1v the idea would be to batch the inference calls. For the async action, instead of calling inference directly, submit the update to some kind of queue that will then create one large request for multiple docs and call the handlers once it completes.

navneet1v (Collaborator), quoting the above: I am not 100% convinced by the idea of batching. The reason is that batching the data and making one call helps only if the number of threads at the ML node level is small and multiple requests are competing for them (API call threads). Second, to create the batch we need to put the data in a synchronized queue, which will further slow down the processing. Third, batches create the problem that one single input sentence can delay the processing of all documents.

We can discuss this further, but before even doing the POC, we should first do a deep dive on the ML-Commons code to find out whether batching can really help.
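To make the queue idea concrete, here is a minimal sketch of the batching approach discussed in this thread. Everything in it is hypothetical (the plugin has no such component): BatchingInferenceQueue, PendingInference, and BatchInferenceFunction are invented names, and the two flush triggers correspond to the batch-size and linger-time thresholds zane-neo mentions. The shared catch block also illustrates navneet1v's third concern, that one bad input affects the whole batch:

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;

/** One enqueued sentence plus the callback that should receive its embedding. */
final class PendingInference {
    final String sentence;
    final BiConsumer<List<Float>, Exception> handler;

    PendingInference(String sentence, BiConsumer<List<Float>, Exception> handler) {
        this.sentence = sentence;
        this.handler = handler;
    }
}

/** Collects single requests and flushes them as one batched inference call. */
final class BatchingInferenceQueue {

    /** Stand-in for one predict() call that embeds many sentences at once. */
    interface BatchInferenceFunction {
        List<List<Float>> apply(List<String> sentences) throws Exception;
    }

    private final BlockingQueue<PendingInference> queue = new LinkedBlockingQueue<>();
    private final int maxBatchSize;
    private final BatchInferenceFunction inferBatch;
    private final ScheduledExecutorService flusher = Executors.newSingleThreadScheduledExecutor();

    BatchingInferenceQueue(int maxBatchSize, long lingerMillis, BatchInferenceFunction inferBatch) {
        this.maxBatchSize = maxBatchSize;
        this.inferBatch = inferBatch;
        // Linger-time trigger: periodically flush whatever has accumulated,
        // trading a little latency for a larger batch.
        flusher.scheduleWithFixedDelay(this::flush, lingerMillis, lingerMillis, TimeUnit.MILLISECONDS);
    }

    void submit(String sentence, BiConsumer<List<Float>, Exception> handler) {
        queue.add(new PendingInference(sentence, handler));
        // Batch-size trigger: flush as soon as enough requests are waiting.
        if (queue.size() >= maxBatchSize) {
            flush();
        }
    }

    private synchronized void flush() {
        List<PendingInference> batch = new ArrayList<>();
        queue.drainTo(batch, maxBatchSize);
        if (batch.isEmpty()) {
            return;
        }
        List<String> sentences = new ArrayList<>();
        for (PendingInference p : batch) {
            sentences.add(p.sentence);
        }
        try {
            // One large request for multiple documents, then fan the results
            // back out to the per-document handlers.
            List<List<Float>> vectors = inferBatch.apply(sentences);
            for (int i = 0; i < batch.size(); i++) {
                batch.get(i).handler.accept(vectors.get(i), null);
            }
        } catch (Exception e) {
            // A single bad input fails (or delays) every document in the batch.
            for (PendingInference p : batch) {
                p.handler.accept(null, e);
            }
        }
    }
}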

+    public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
Review thread:

Collaborator: Is this an override function?

zane-neo (Collaborator, author): Yes, this is an override function; I added the @Override annotation and also added proper Javadoc on this method.

+        // When a bulk indexing request is received, the pipeline is executed in this method (see
+        // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/action/bulk/TransportBulkAction.java#L226).
+        // Before the pipeline execution, the pipeline is marked as resolved (meaning executed),
+        // and then this overriding method is invoked when executing the text embedding processor.
+        // After the inference completes, the handler invokes the doInternalExecute method again to run the actual write operation.
         try {
-            List<List<Float>> vectors = mlCommonsClientAccessor.inferenceSentences(this.modelId, createInferenceList(knnMap));
-            appendVectorFieldsToDocument(ingestDocument, knnMap, vectors);
-        } catch (ExecutionException | InterruptedException e) {
-            log.error("Text embedding processor failed with exception: ", e);
-            throw new RuntimeException("Text embedding processor failed with exception", e);
+            validateEmbeddingFieldsValue(ingestDocument);
+            Map<String, Object> knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument);
+            mlCommonsClientAccessor.inferenceSentences(this.modelId, createInferenceList(knnMap), ActionListener.wrap(vectors -> {
+                appendVectorFieldsToDocument(ingestDocument, knnMap, vectors);
+                handler.accept(ingestDocument, null);
+            }, e -> { handler.accept(null, e); }));
+        } catch (Exception e) {
+            handler.accept(null, e);
         }
Review thread on lines +107 to 109:

Collaborator: Which function is throwing the exception that we are catching here?

zane-neo (Collaborator, author): The underlying validateEmbeddingFieldsValue is throwing IllegalArgumentException.

Collaborator: If the exception is coming from that function only, add the exception handling around that function only and take the predict API call out of the try block.

zane-neo (Collaborator, author): Checked again; not only does validateEmbeddingFieldsValue throw an exception, but so does mlClient.predict, and listener.onFailure is not invoked when the exception arises. This exception could propagate to execute(IngestDocument, BiConsumer); if we don't catch it, it will propagate further to the upper methods, which could impact the pipeline execution, and that is what we don't want. The parent method also uses try/catch to catch all exceptions; please refer to https://github.com/opensearch-project/OpenSearch/blob/8c9ca4e858e6333265080972cf57809dbc086208/server/src/main/java/org/opensearch/ingest/Processor.java#L64.

-        log.debug("Text embedding completed, returning ingestDocument!");
-        return ingestDocument;
-
     }

     void appendVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> knnMap, List<List<Float>> vectors) {
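The exception-handling thread above hinges on one subtlety: an "async" API can still fail synchronously, on the caller's thread, before any listener is registered. A runnable toy sketch under that assumption (AsyncFailureDemo and predictAsync are invented stand-ins, not the real client):

import java.util.function.BiConsumer;
import java.util.function.Consumer;

final class AsyncFailureDemo {

    // Stand-in for an async predict call: it validates eagerly, so it can throw
    // on the caller's thread before any callback is registered.
    static void predictAsync(String modelId, Consumer<String> onResponse, Consumer<Exception> onFailure) {
        if (modelId == null || modelId.isEmpty()) {
            // Synchronous failure: onFailure is never invoked for this one.
            throw new IllegalArgumentException("modelId must not be empty");
        }
        new Thread(() -> onResponse.accept("embedding via " + modelId)).start();
    }

    // Mirrors the shape of the processor's execute(doc, handler) above.
    static void execute(String modelId, BiConsumer<String, Exception> handler) {
        try {
            predictAsync(modelId, result -> handler.accept(result, null), e -> handler.accept(null, e));
        } catch (Exception e) {
            // Without this catch, the synchronous exception would escape to the
            // pipeline caller instead of reaching the handler.
            handler.accept(null, e);
        }
    }

    public static void main(String[] args) {
        BiConsumer<String, Exception> handler = (result, e) ->
            System.out.println(result != null ? "ok: " + result : "failed: " + e.getMessage());
        execute("", handler);                 // fails synchronously, still reaches handler
        execute("all-MiniLM-L6-v2", handler); // succeeds asynchronously
    }
}

Without the outer catch, the first call in main would throw past the handler and up into whatever invoked execute, which is exactly the pipeline-level leak the thread is about.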
@@ -5,25 +5,17 @@

 package org.opensearch.neuralsearch.ml;

-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyString;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;

 import lombok.SneakyThrows;

 import org.junit.Before;
 import org.mockito.InjectMocks;
 import org.mockito.Mock;
 import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
-import org.opensearch.action.ActionFuture;
 import org.opensearch.action.ActionListener;
 import org.opensearch.ml.client.MachineLearningNodeClient;
 import org.opensearch.ml.common.input.MLInput;
@@ -35,8 +27,6 @@
 import org.opensearch.neuralsearch.constants.TestCommonConstants;
 import org.opensearch.test.OpenSearchTestCase;

-import com.google.common.collect.ImmutableList;
-
 public class MLCommonsClientAccessorTests extends OpenSearchTestCase {

     @Mock
@@ -124,26 +114,6 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() {
         Mockito.verifyNoMoreInteractions(resultListener);
     }

-    @SneakyThrows
-    public void test_blockingInferenceSentences() {
-        ActionFuture actionFuture = mock(ActionFuture.class);
-        when(client.predict(anyString(), any(MLInput.class))).thenReturn(actionFuture);
-        List<ModelTensors> tensorsList = new ArrayList<>();
-
-        List<ModelTensor> tensors = new ArrayList<>();
-        ModelTensor tensor = mock(ModelTensor.class);
-        when(tensor.getData()).thenReturn(TestCommonConstants.PREDICT_VECTOR_ARRAY);
-        tensors.add(tensor);
-
-        ModelTensors modelTensors = new ModelTensors(tensors);
-        tensorsList.add(modelTensors);
-
-        ModelTensorOutput mlOutput = new ModelTensorOutput(tensorsList);
-        when(actionFuture.get()).thenReturn(mlOutput);
-        List<List<Float>> result = accessor.inferenceSentences("modelId", ImmutableList.of("mock"));
-        assertEquals(TestCommonConstants.PREDICT_VECTOR_ARRAY[0], result.get(0).get(0));
-    }
-
     private ModelTensorOutput createModelTensorOutput(final Float[] output) {
         final List<ModelTensors> tensorsList = new ArrayList<>();
         final List<ModelTensor> mlModelTensorList = new ArrayList<>();
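With the blocking test removed, the async path is the only one under test. A typical Mockito stub for that path looks like the sketch below; it mirrors the style of this test class but is illustrative rather than a snippet from the PR, and it assumes the client's predict overload takes the listener as its third argument:

// Complete the async call immediately by invoking the captured listener.
Mockito.doAnswer(invocation -> {
    ActionListener<MLOutput> listener = invocation.getArgument(2);
    listener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
    return null;
}).when(client).predict(Mockito.anyString(), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));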
@@ -7,7 +7,6 @@

 import java.nio.file.Files;
 import java.nio.file.Path;
-import java.util.Locale;
 import java.util.Map;

 import org.apache.http.HttpHeaders;
@@ -36,13 +35,8 @@ public void testTextEmbeddingProcessor() throws Exception {
     }

     private String uploadTextEmbeddingModel() throws Exception {
-        String currentPath = System.getProperty("user.dir");
-        Path testClusterPath = Path.of(currentPath).getParent().resolveSibling("testclusters/integTest-0/data");
-        Path path = Path.of(testClusterPath + "/all-MiniLM-L6-v2.zip");
-        Files.copy(Path.of(classLoader.getResource("model/all-MiniLM-L6-v2.zip").toURI()), path);
         String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI()));
-        String request = String.format(Locale.getDefault(), requestBody, path);
-        return uploadModel(request);
+        return uploadModel(requestBody);
     }

     private void createTextEmbeddingIndex() throws Exception {