Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add retry mechanism for neural search inference. #91

Merged
merged 6 commits into from
Jan 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.neuralsearch.util.RetryUtil;

/**
* This class will act as an abstraction on the MLCommons client for accessing the ML Capabilities
Expand Down Expand Up @@ -98,13 +99,30 @@ public void inferenceSentences(
@NonNull final String modelId,
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<List<Float>>> listener
) {
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener);
}

private void inferenceSentencesWithRetry(
final List<String> targetResponseFilters,
final String modelId,
final List<String> inputText,
final int retryTime,
final ActionListener<List<List<Float>>> listener
) {
MLInput mlInput = createMLInput(targetResponseFilters, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
listener.onResponse(vector);
}, listener::onFailure));
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
}));
}

private MLInput createMLInput(final List<String> targetResponseFilters, List<String> inputText) {
Expand Down
36 changes: 36 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/util/RetryUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.util;

import java.util.List;

import org.apache.commons.lang3.exception.ExceptionUtils;
import org.opensearch.transport.NodeDisconnectedException;
import org.opensearch.transport.NodeNotConnectedException;

import com.google.common.collect.ImmutableList;

public class RetryUtil {

private static final int MAX_RETRY = 3;

private static final List<Class<? extends Throwable>> RETRYABLE_EXCEPTIONS = ImmutableList.of(
NodeNotConnectedException.class,
NodeDisconnectedException.class
);

/**
*
* @param e {@link Exception} which is the exception received to check if retryable.
* @param retryTime {@link int} which is the current retried times.
* @return {@link boolean} which is the result of if current exception needs retry or not.
*/
public static boolean shouldRetry(final Exception e, int retryTime) {
boolean hasRetryException = RETRYABLE_EXCEPTIONS.stream().anyMatch(x -> ExceptionUtils.indexOfThrowable(e, x) != -1);
return hasRetryException && retryTime < MAX_RETRY;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.neuralsearch.ml;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -17,6 +20,7 @@
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
Expand All @@ -26,6 +30,7 @@
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.neuralsearch.constants.TestCommonConstants;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.NodeNotConnectedException;

public class MLCommonsClientAccessorTests extends OpenSearchTestCase {

Expand Down Expand Up @@ -114,6 +119,47 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() {
Mockito.verifyNoMoreInteractions(resultListener);
}

public void testInferenceSentences_whenNodeNotConnectedException_thenRetry_3Times() {
final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException(
mock(DiscoveryNode.class),
"Node not connected"
);
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(nodeNodeConnectedException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
accessor.inferenceSentences(
TestCommonConstants.TARGET_RESPONSE_FILTERS,
TestCommonConstants.MODEL_ID,
TestCommonConstants.SENTENCES_LIST,
resultListener
);

Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onFailure(nodeNodeConnectedException);
}

public void testInferenceSentences_whenNotConnectionException_thenNoRetry() {
final IllegalStateException illegalStateException = new IllegalStateException("Illegal state");
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(illegalStateException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
accessor.inferenceSentences(
TestCommonConstants.TARGET_RESPONSE_FILTERS,
TestCommonConstants.MODEL_ID,
TestCommonConstants.SENTENCES_LIST,
resultListener
);

Mockito.verify(client, times(1))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onFailure(illegalStateException);
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import lombok.SneakyThrows;

import org.junit.Before;
import org.mockito.InjectMocks;
import org.mockito.Mock;
Expand Down Expand Up @@ -54,16 +56,17 @@ public void setup() {
when(env.settings()).thenReturn(settings);
}

private TextEmbeddingProcessor createInstance(List<List<Float>> vector) throws Exception {
@SneakyThrows
private TextEmbeddingProcessor createInstance(List<List<Float>> vector) {
Map<String, Processor.Factory> registry = new HashMap<>();
Map<String, Object> config = new HashMap<>();
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped"));
TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
return processor;
return textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
}

public void testTextEmbeddingProcessConstructor_whenConfigMapError_throwIllegalArgumentException() throws Exception {
@SneakyThrows
public void testTextEmbeddingProcessConstructor_whenConfigMapError_throwIllegalArgumentException() {
Map<String, Processor.Factory> registry = new HashMap<>();
Map<String, Object> config = new HashMap<>();
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
Expand All @@ -78,7 +81,8 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapError_throwIllegalA
}
}

public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalArgumentException() throws Exception {
@SneakyThrows
public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalArgumentException() {
Map<String, Processor.Factory> registry = new HashMap<>();
Map<String, Object> config = new HashMap<>();
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
Expand All @@ -89,7 +93,7 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalA
}
}

public void testExecute_successful() throws Exception {
public void testExecute_successful() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", "value1");
sourceAndMetadata.put("key2", "value2");
Expand All @@ -108,7 +112,8 @@ public void testExecute_successful() throws Exception {
verify(handler).accept(any(IngestDocument.class), isNull());
}

public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeException() throws Exception {
@SneakyThrows
public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeException() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", "value1");
sourceAndMetadata.put("key2", "value2");
Expand All @@ -127,7 +132,8 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep
verify(handler).accept(isNull(), any(RuntimeException.class));
}

public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() throws Exception {
@SneakyThrows
public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
Map<String, Processor.Factory> registry = new HashMap<>();
Expand All @@ -144,7 +150,7 @@ public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() thr
verify(handler).accept(any(IngestDocument.class), isNull());
}

public void testExecute_withListTypeInput_successful() throws Exception {
public void testExecute_withListTypeInput_successful() {
List<String> list1 = ImmutableList.of("test1", "test2", "test3");
List<String> list2 = ImmutableList.of("test4", "test5", "test6");
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand All @@ -165,7 +171,7 @@ public void testExecute_withListTypeInput_successful() throws Exception {
verify(handler).accept(any(IngestDocument.class), isNull());
}

public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentException() throws Exception {
public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentException() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", " ");
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
Expand All @@ -176,7 +182,7 @@ public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentExcep
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() throws Exception {
public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() {
List<String> list1 = ImmutableList.of("", "test2", "test3");
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", list1);
Expand All @@ -188,7 +194,7 @@ public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException()
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_listHasNonStringValue_throwIllegalArgumentException() throws Exception {
public void testExecute_listHasNonStringValue_throwIllegalArgumentException() {
List<Integer> list2 = ImmutableList.of(1, 2, 3);
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key2", list2);
Expand All @@ -199,7 +205,7 @@ public void testExecute_listHasNonStringValue_throwIllegalArgumentException() th
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_listHasNull_throwIllegalArgumentException() throws Exception {
public void testExecute_listHasNull_throwIllegalArgumentException() {
List<String> list = new ArrayList<>();
list.add("hello");
list.add(null);
Expand All @@ -213,7 +219,7 @@ public void testExecute_listHasNull_throwIllegalArgumentException() throws Excep
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_withMapTypeInput_successful() throws Exception {
public void testExecute_withMapTypeInput_successful() {
Map<String, String> map1 = ImmutableMap.of("test1", "test2");
Map<String, String> map2 = ImmutableMap.of("test4", "test5");
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand All @@ -235,7 +241,7 @@ public void testExecute_withMapTypeInput_successful() throws Exception {

}

public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() throws Exception {
public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() {
Map<String, String> map1 = ImmutableMap.of("test1", "test2");
Map<String, Double> map2 = ImmutableMap.of("test3", 209.3D);
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand All @@ -248,7 +254,7 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() thr
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() throws Exception {
public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() {
Map<String, String> map1 = ImmutableMap.of("test1", "test2");
Map<String, String> map2 = ImmutableMap.of("test3", " ");
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand All @@ -261,7 +267,7 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() t
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() throws Exception {
public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() {
Map<String, Object> ret = createMaxDepthLimitExceedMap(() -> 1);
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", "hello world");
Expand All @@ -273,7 +279,7 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() throw
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_MLClientAccessorThrowFail_handlerFailure() throws Exception {
public void testExecute_MLClientAccessorThrowFail_handlerFailure() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", "value1");
sourceAndMetadata.put("key2", "value2");
Expand Down Expand Up @@ -303,7 +309,7 @@ private Map<String, Object> createMaxDepthLimitExceedMap(Supplier<Integer> maxDe
return innerMap;
}

public void testExecute_hybridTypeInput_successful() throws Exception {
public void testExecute_hybridTypeInput_successful() {
List<String> list1 = ImmutableList.of("test1", "test2");
Map<String, List<String>> map1 = ImmutableMap.of("test3", list1);
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand All @@ -314,7 +320,7 @@ public void testExecute_hybridTypeInput_successful() throws Exception {
assert document.getSourceAndMetadata().containsKey("key2");
}

public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentException() throws Exception {
public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentException() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", 100);
sourceAndMetadata.put("key2", 100.232D);
Expand All @@ -331,7 +337,7 @@ public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentE
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testGetType_successful() throws Exception {
public void testGetType_successful() {
TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
assert processor.getType().equals(TextEmbeddingProcessor.TYPE);
}
Expand All @@ -348,7 +354,8 @@ public void testProcessResponse_successful() throws Exception {
assertEquals(12, ingestDocument.getSourceAndMetadata().size());
}

public void testBuildVectorOutput_withPlainStringValue_successful() throws Exception {
@SneakyThrows
public void testBuildVectorOutput_withPlainStringValue_successful() {
Map<String, Object> config = createPlainStringConfiguration();
IngestDocument ingestDocument = createPlainIngestDocument();
TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config);
Expand All @@ -373,8 +380,9 @@ public void testBuildVectorOutput_withPlainStringValue_successful() throws Excep
assertTrue(result.containsKey("oriKey6_knn"));
}

@SneakyThrows
@SuppressWarnings("unchecked")
public void testBuildVectorOutput_withNestedMap_successful() throws Exception {
public void testBuildVectorOutput_withNestedMap_successful() {
Map<String, Object> config = createNestedMapConfiguration();
IngestDocument ingestDocument = createNestedMapIngestDocument();
TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config);
Expand Down Expand Up @@ -421,7 +429,8 @@ private List<List<Float>> createMockVectorWithLength(int size) {
return result;
}

private TextEmbeddingProcessor createInstanceWithNestedMapConfiguration(Map<String, Object> fieldMap) throws Exception {
@SneakyThrows
private TextEmbeddingProcessor createInstanceWithNestedMapConfiguration(Map<String, Object> fieldMap) {
Map<String, Processor.Factory> registry = new HashMap<>();
Map<String, Object> config = new HashMap<>();
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
Expand Down