Skip to content

Commit

Permalink
Add retry mechanism for neural search inference (opensearch-project#91)
Browse files Browse the repository at this point in the history
Add basic retry mechanism for neural search inference

Signed-off-by: Zan Niu <[email protected]>
  • Loading branch information
zane-neo authored and navneet1v committed Jan 27, 2023
1 parent 76b0076 commit 46803c3
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.neuralsearch.util.RetryUtil;

/**
* This class will act as an abstraction on the MLCommons client for accessing the ML Capabilities
Expand Down Expand Up @@ -98,13 +99,30 @@ public void inferenceSentences(
@NonNull final String modelId,
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<List<Float>>> listener
) {
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener);
}

private void inferenceSentencesWithRetry(
final List<String> targetResponseFilters,
final String modelId,
final List<String> inputText,
final int retryTime,
final ActionListener<List<List<Float>>> listener
) {
MLInput mlInput = createMLInput(targetResponseFilters, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
listener.onResponse(vector);
}, listener::onFailure));
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
}));
}

private MLInput createMLInput(final List<String> targetResponseFilters, List<String> inputText) {
Expand Down
36 changes: 36 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/util/RetryUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.util;

import java.util.List;

import org.apache.commons.lang3.exception.ExceptionUtils;
import org.opensearch.transport.NodeDisconnectedException;
import org.opensearch.transport.NodeNotConnectedException;

import com.google.common.collect.ImmutableList;

public class RetryUtil {

private static final int MAX_RETRY = 3;

private static final List<Class<? extends Throwable>> RETRYABLE_EXCEPTIONS = ImmutableList.of(
NodeNotConnectedException.class,
NodeDisconnectedException.class
);

/**
*
* @param e {@link Exception} which is the exception received to check if retryable.
* @param retryTime {@link int} which is the current retried times.
* @return {@link boolean} which is the result of if current exception needs retry or not.
*/
public static boolean shouldRetry(final Exception e, int retryTime) {
boolean hasRetryException = RETRYABLE_EXCEPTIONS.stream().anyMatch(x -> ExceptionUtils.indexOfThrowable(e, x) != -1);
return hasRetryException && retryTime < MAX_RETRY;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.neuralsearch.ml;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -17,6 +20,7 @@
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
Expand All @@ -26,6 +30,7 @@
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.neuralsearch.constants.TestCommonConstants;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.NodeNotConnectedException;

public class MLCommonsClientAccessorTests extends OpenSearchTestCase {

Expand Down Expand Up @@ -114,6 +119,47 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() {
Mockito.verifyNoMoreInteractions(resultListener);
}

public void testInferenceSentences_whenNodeNotConnectedException_thenRetry_3Times() {
final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException(
mock(DiscoveryNode.class),
"Node not connected"
);
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(nodeNodeConnectedException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
accessor.inferenceSentences(
TestCommonConstants.TARGET_RESPONSE_FILTERS,
TestCommonConstants.MODEL_ID,
TestCommonConstants.SENTENCES_LIST,
resultListener
);

Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onFailure(nodeNodeConnectedException);
}

public void testInferenceSentences_whenNotConnectionException_thenNoRetry() {
final IllegalStateException illegalStateException = new IllegalStateException("Illegal state");
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(illegalStateException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
accessor.inferenceSentences(
TestCommonConstants.TARGET_RESPONSE_FILTERS,
TestCommonConstants.MODEL_ID,
TestCommonConstants.SENTENCES_LIST,
resultListener
);

Mockito.verify(client, times(1))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onFailure(illegalStateException);
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import lombok.SneakyThrows;

import org.junit.Before;
import org.mockito.InjectMocks;
import org.mockito.Mock;
Expand Down Expand Up @@ -54,16 +56,17 @@ public void setup() {
when(env.settings()).thenReturn(settings);
}

private TextEmbeddingProcessor createInstance(List<List<Float>> vector) throws Exception {
@SneakyThrows
private TextEmbeddingProcessor createInstance(List<List<Float>> vector) {
Map<String, Processor.Factory> registry = new HashMap<>();
Map<String, Object> config = new HashMap<>();
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped"));
TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
return processor;
return textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
}

public void testTextEmbeddingProcessConstructor_whenConfigMapError_throwIllegalArgumentException() throws Exception {
@SneakyThrows
public void testTextEmbeddingProcessConstructor_whenConfigMapError_throwIllegalArgumentException() {
Map<String, Processor.Factory> registry = new HashMap<>();
Map<String, Object> config = new HashMap<>();
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
Expand All @@ -78,7 +81,8 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapError_throwIllegalA
}
}

public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalArgumentException() throws Exception {
@SneakyThrows
public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalArgumentException() {
Map<String, Processor.Factory> registry = new HashMap<>();
Map<String, Object> config = new HashMap<>();
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
Expand All @@ -89,7 +93,7 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalA
}
}

public void testExecute_successful() throws Exception {
public void testExecute_successful() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", "value1");
sourceAndMetadata.put("key2", "value2");
Expand All @@ -108,7 +112,8 @@ public void testExecute_successful() throws Exception {
verify(handler).accept(any(IngestDocument.class), isNull());
}

public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeException() throws Exception {
@SneakyThrows
public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeException() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", "value1");
sourceAndMetadata.put("key2", "value2");
Expand All @@ -127,7 +132,8 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep
verify(handler).accept(isNull(), any(RuntimeException.class));
}

public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() throws Exception {
@SneakyThrows
public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
Map<String, Processor.Factory> registry = new HashMap<>();
Expand All @@ -144,7 +150,7 @@ public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() thr
verify(handler).accept(any(IngestDocument.class), isNull());
}

public void testExecute_withListTypeInput_successful() throws Exception {
public void testExecute_withListTypeInput_successful() {
List<String> list1 = ImmutableList.of("test1", "test2", "test3");
List<String> list2 = ImmutableList.of("test4", "test5", "test6");
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand All @@ -165,7 +171,7 @@ public void testExecute_withListTypeInput_successful() throws Exception {
verify(handler).accept(any(IngestDocument.class), isNull());
}

public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentException() throws Exception {
public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentException() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", " ");
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
Expand All @@ -176,7 +182,7 @@ public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentExcep
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() throws Exception {
public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() {
List<String> list1 = ImmutableList.of("", "test2", "test3");
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", list1);
Expand All @@ -188,7 +194,7 @@ public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException()
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_listHasNonStringValue_throwIllegalArgumentException() throws Exception {
public void testExecute_listHasNonStringValue_throwIllegalArgumentException() {
List<Integer> list2 = ImmutableList.of(1, 2, 3);
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key2", list2);
Expand All @@ -199,7 +205,7 @@ public void testExecute_listHasNonStringValue_throwIllegalArgumentException() th
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_listHasNull_throwIllegalArgumentException() throws Exception {
public void testExecute_listHasNull_throwIllegalArgumentException() {
List<String> list = new ArrayList<>();
list.add("hello");
list.add(null);
Expand All @@ -213,7 +219,7 @@ public void testExecute_listHasNull_throwIllegalArgumentException() throws Excep
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_withMapTypeInput_successful() throws Exception {
public void testExecute_withMapTypeInput_successful() {
Map<String, String> map1 = ImmutableMap.of("test1", "test2");
Map<String, String> map2 = ImmutableMap.of("test4", "test5");
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand All @@ -235,7 +241,7 @@ public void testExecute_withMapTypeInput_successful() throws Exception {

}

public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() throws Exception {
public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() {
Map<String, String> map1 = ImmutableMap.of("test1", "test2");
Map<String, Double> map2 = ImmutableMap.of("test3", 209.3D);
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand All @@ -248,7 +254,7 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() thr
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() throws Exception {
public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() {
Map<String, String> map1 = ImmutableMap.of("test1", "test2");
Map<String, String> map2 = ImmutableMap.of("test3", " ");
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand All @@ -261,7 +267,7 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() t
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() throws Exception {
public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() {
Map<String, Object> ret = createMaxDepthLimitExceedMap(() -> 1);
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", "hello world");
Expand All @@ -273,7 +279,7 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() throw
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_MLClientAccessorThrowFail_handlerFailure() throws Exception {
public void testExecute_MLClientAccessorThrowFail_handlerFailure() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", "value1");
sourceAndMetadata.put("key2", "value2");
Expand Down Expand Up @@ -303,7 +309,7 @@ private Map<String, Object> createMaxDepthLimitExceedMap(Supplier<Integer> maxDe
return innerMap;
}

public void testExecute_hybridTypeInput_successful() throws Exception {
public void testExecute_hybridTypeInput_successful() {
List<String> list1 = ImmutableList.of("test1", "test2");
Map<String, List<String>> map1 = ImmutableMap.of("test3", list1);
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand All @@ -314,7 +320,7 @@ public void testExecute_hybridTypeInput_successful() throws Exception {
assert document.getSourceAndMetadata().containsKey("key2");
}

public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentException() throws Exception {
public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentException() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", 100);
sourceAndMetadata.put("key2", 100.232D);
Expand All @@ -331,7 +337,7 @@ public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentE
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testGetType_successful() throws Exception {
public void testGetType_successful() {
TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
assert processor.getType().equals(TextEmbeddingProcessor.TYPE);
}
Expand All @@ -348,7 +354,8 @@ public void testProcessResponse_successful() throws Exception {
assertEquals(12, ingestDocument.getSourceAndMetadata().size());
}

public void testBuildVectorOutput_withPlainStringValue_successful() throws Exception {
@SneakyThrows
public void testBuildVectorOutput_withPlainStringValue_successful() {
Map<String, Object> config = createPlainStringConfiguration();
IngestDocument ingestDocument = createPlainIngestDocument();
TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config);
Expand All @@ -373,8 +380,9 @@ public void testBuildVectorOutput_withPlainStringValue_successful() throws Excep
assertTrue(result.containsKey("oriKey6_knn"));
}

@SneakyThrows
@SuppressWarnings("unchecked")
public void testBuildVectorOutput_withNestedMap_successful() throws Exception {
public void testBuildVectorOutput_withNestedMap_successful() {
Map<String, Object> config = createNestedMapConfiguration();
IngestDocument ingestDocument = createNestedMapIngestDocument();
TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config);
Expand Down Expand Up @@ -421,7 +429,8 @@ private List<List<Float>> createMockVectorWithLength(int size) {
return result;
}

private TextEmbeddingProcessor createInstanceWithNestedMapConfiguration(Map<String, Object> fieldMap) throws Exception {
@SneakyThrows
private TextEmbeddingProcessor createInstanceWithNestedMapConfiguration(Map<String, Object> fieldMap) {
Map<String, Processor.Factory> registry = new HashMap<>();
Map<String, Object> config = new HashMap<>();
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
Expand Down

0 comments on commit 46803c3

Please sign in to comment.