From b20717698176f34710778b1332c03351dcd74bd2 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 11 Jan 2023 01:59:53 +0800 Subject: [PATCH] Add retry mechanism for neural search inference (#91) Add basic retry mechanism for neural search inference Signed-off-by: Zan Niu --- .../ml/MLCommonsClientAccessor.java | 20 ++++++- .../neuralsearch/util/RetryUtil.java | 36 ++++++++++++ .../ml/MLCommonsClientAccessorTests.java | 46 +++++++++++++++ .../TextEmbeddingProcessorTests.java | 57 +++++++++++-------- 4 files changed, 134 insertions(+), 25 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/util/RetryUtil.java diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 3ac0f7979..98bb52f7c 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -25,6 +25,7 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.neuralsearch.util.RetryUtil; /** * This class will act as an abstraction on the MLCommons client for accessing the ML Capabilities @@ -98,13 +99,30 @@ public void inferenceSentences( @NonNull final String modelId, @NonNull final List inputText, @NonNull final ActionListener>> listener + ) { + inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener); + } + + private void inferenceSentencesWithRetry( + final List targetResponseFilters, + final String modelId, + final List inputText, + final int retryTime, + final ActionListener>> listener ) { MLInput mlInput = createMLInput(targetResponseFilters, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { final List> vector = buildVectorFromResponse(mlOutput); log.debug("Inference Response for input sentence {} is : {} ", inputText, vector); listener.onResponse(vector); - }, listener::onFailure)); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, retryTimeAdd, listener); + } else { + listener.onFailure(e); + } + })); } private MLInput createMLInput(final List targetResponseFilters, List inputText) { diff --git a/src/main/java/org/opensearch/neuralsearch/util/RetryUtil.java b/src/main/java/org/opensearch/neuralsearch/util/RetryUtil.java new file mode 100644 index 000000000..5c1d486d4 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/util/RetryUtil.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.util; + +import java.util.List; + +import org.apache.commons.lang3.exception.ExceptionUtils; +import org.opensearch.transport.NodeDisconnectedException; +import org.opensearch.transport.NodeNotConnectedException; + +import com.google.common.collect.ImmutableList; + +public class RetryUtil { + + private static final int MAX_RETRY = 3; + + private static final List> RETRYABLE_EXCEPTIONS = ImmutableList.of( + NodeNotConnectedException.class, + NodeDisconnectedException.class + ); + + /** + * + * @param e {@link Exception} which is the exception received to check if retryable. + * @param retryTime {@link int} which is the current retried times. + * @return {@link boolean} which is the result of if current exception needs retry or not. + */ + public static boolean shouldRetry(final Exception e, int retryTime) { + boolean hasRetryException = RETRYABLE_EXCEPTIONS.stream().anyMatch(x -> ExceptionUtils.indexOfThrowable(e, x) != -1); + return hasRetryException && retryTime < MAX_RETRY; + } + +} diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 81523d8d7..b7dfec083 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -5,6 +5,9 @@ package org.opensearch.neuralsearch.ml; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; + import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; @@ -17,6 +20,7 @@ import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListener; +import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; @@ -26,6 +30,7 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.neuralsearch.constants.TestCommonConstants; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.NodeNotConnectedException; public class MLCommonsClientAccessorTests extends OpenSearchTestCase { @@ -114,6 +119,47 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() { Mockito.verifyNoMoreInteractions(resultListener); } + public void testInferenceSentences_whenNodeNotConnectedException_thenRetry_3Times() { + final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( + mock(DiscoveryNode.class), + "Node not connected" + ); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(nodeNodeConnectedException); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentences( + TestCommonConstants.TARGET_RESPONSE_FILTERS, + TestCommonConstants.MODEL_ID, + TestCommonConstants.SENTENCES_LIST, + resultListener + ); + + Mockito.verify(client, times(4)) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onFailure(nodeNodeConnectedException); + } + + public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { + final IllegalStateException illegalStateException = new IllegalStateException("Illegal state"); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(illegalStateException); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentences( + TestCommonConstants.TARGET_RESPONSE_FILTERS, + TestCommonConstants.MODEL_ID, + TestCommonConstants.SENTENCES_LIST, + resultListener + ); + + Mockito.verify(client, times(1)) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onFailure(illegalStateException); + } + private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 5890f2edc..d4a92f103 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -17,6 +17,8 @@ import java.util.function.BiConsumer; import java.util.function.Supplier; +import lombok.SneakyThrows; + import org.junit.Before; import org.mockito.InjectMocks; import org.mockito.Mock; @@ -54,16 +56,17 @@ public void setup() { when(env.settings()).thenReturn(settings); } - private TextEmbeddingProcessor createInstance(List> vector) throws Exception { + @SneakyThrows + private TextEmbeddingProcessor createInstance(List> vector) { Map registry = new HashMap<>(); Map config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); - TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); - return processor; + return textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } - public void testTextEmbeddingProcessConstructor_whenConfigMapError_throwIllegalArgumentException() throws Exception { + @SneakyThrows + public void testTextEmbeddingProcessConstructor_whenConfigMapError_throwIllegalArgumentException() { Map registry = new HashMap<>(); Map config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); @@ -78,7 +81,8 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapError_throwIllegalA } } - public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalArgumentException() throws Exception { + @SneakyThrows + public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalArgumentException() { Map registry = new HashMap<>(); Map config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); @@ -89,7 +93,7 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalA } } - public void testExecute_successful() throws Exception { + public void testExecute_successful() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key1", "value1"); sourceAndMetadata.put("key2", "value2"); @@ -108,7 +112,8 @@ public void testExecute_successful() throws Exception { verify(handler).accept(any(IngestDocument.class), isNull()); } - public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeException() throws Exception { + @SneakyThrows + public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeException() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key1", "value1"); sourceAndMetadata.put("key2", "value2"); @@ -127,7 +132,8 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep verify(handler).accept(isNull(), any(RuntimeException.class)); } - public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() throws Exception { + @SneakyThrows + public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() { Map sourceAndMetadata = new HashMap<>(); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); Map registry = new HashMap<>(); @@ -144,7 +150,7 @@ public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() thr verify(handler).accept(any(IngestDocument.class), isNull()); } - public void testExecute_withListTypeInput_successful() throws Exception { + public void testExecute_withListTypeInput_successful() { List list1 = ImmutableList.of("test1", "test2", "test3"); List list2 = ImmutableList.of("test4", "test5", "test6"); Map sourceAndMetadata = new HashMap<>(); @@ -165,7 +171,7 @@ public void testExecute_withListTypeInput_successful() throws Exception { verify(handler).accept(any(IngestDocument.class), isNull()); } - public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentException() throws Exception { + public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentException() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key1", " "); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); @@ -176,7 +182,7 @@ public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentExcep verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } - public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() throws Exception { + public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() { List list1 = ImmutableList.of("", "test2", "test3"); Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key1", list1); @@ -188,7 +194,7 @@ public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } - public void testExecute_listHasNonStringValue_throwIllegalArgumentException() throws Exception { + public void testExecute_listHasNonStringValue_throwIllegalArgumentException() { List list2 = ImmutableList.of(1, 2, 3); Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key2", list2); @@ -199,7 +205,7 @@ public void testExecute_listHasNonStringValue_throwIllegalArgumentException() th verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } - public void testExecute_listHasNull_throwIllegalArgumentException() throws Exception { + public void testExecute_listHasNull_throwIllegalArgumentException() { List list = new ArrayList<>(); list.add("hello"); list.add(null); @@ -213,7 +219,7 @@ public void testExecute_listHasNull_throwIllegalArgumentException() throws Excep verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } - public void testExecute_withMapTypeInput_successful() throws Exception { + public void testExecute_withMapTypeInput_successful() { Map map1 = ImmutableMap.of("test1", "test2"); Map map2 = ImmutableMap.of("test4", "test5"); Map sourceAndMetadata = new HashMap<>(); @@ -235,7 +241,7 @@ public void testExecute_withMapTypeInput_successful() throws Exception { } - public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() throws Exception { + public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() { Map map1 = ImmutableMap.of("test1", "test2"); Map map2 = ImmutableMap.of("test3", 209.3D); Map sourceAndMetadata = new HashMap<>(); @@ -248,7 +254,7 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() thr verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } - public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() throws Exception { + public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() { Map map1 = ImmutableMap.of("test1", "test2"); Map map2 = ImmutableMap.of("test3", " "); Map sourceAndMetadata = new HashMap<>(); @@ -261,7 +267,7 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() t verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } - public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() throws Exception { + public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() { Map ret = createMaxDepthLimitExceedMap(() -> 1); Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key1", "hello world"); @@ -273,7 +279,7 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() throw verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } - public void testExecute_MLClientAccessorThrowFail_handlerFailure() throws Exception { + public void testExecute_MLClientAccessorThrowFail_handlerFailure() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key1", "value1"); sourceAndMetadata.put("key2", "value2"); @@ -303,7 +309,7 @@ private Map createMaxDepthLimitExceedMap(Supplier maxDe return innerMap; } - public void testExecute_hybridTypeInput_successful() throws Exception { + public void testExecute_hybridTypeInput_successful() { List list1 = ImmutableList.of("test1", "test2"); Map> map1 = ImmutableMap.of("test3", list1); Map sourceAndMetadata = new HashMap<>(); @@ -314,7 +320,7 @@ public void testExecute_hybridTypeInput_successful() throws Exception { assert document.getSourceAndMetadata().containsKey("key2"); } - public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentException() throws Exception { + public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentException() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key1", 100); sourceAndMetadata.put("key2", 100.232D); @@ -331,7 +337,7 @@ public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentE verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } - public void testGetType_successful() throws Exception { + public void testGetType_successful() { TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); assert processor.getType().equals(TextEmbeddingProcessor.TYPE); } @@ -348,7 +354,8 @@ public void testProcessResponse_successful() throws Exception { assertEquals(12, ingestDocument.getSourceAndMetadata().size()); } - public void testBuildVectorOutput_withPlainStringValue_successful() throws Exception { + @SneakyThrows + public void testBuildVectorOutput_withPlainStringValue_successful() { Map config = createPlainStringConfiguration(); IngestDocument ingestDocument = createPlainIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); @@ -373,8 +380,9 @@ public void testBuildVectorOutput_withPlainStringValue_successful() throws Excep assertTrue(result.containsKey("oriKey6_knn")); } + @SneakyThrows @SuppressWarnings("unchecked") - public void testBuildVectorOutput_withNestedMap_successful() throws Exception { + public void testBuildVectorOutput_withNestedMap_successful() { Map config = createNestedMapConfiguration(); IngestDocument ingestDocument = createNestedMapIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); @@ -421,7 +429,8 @@ private List> createMockVectorWithLength(int size) { return result; } - private TextEmbeddingProcessor createInstanceWithNestedMapConfiguration(Map fieldMap) throws Exception { + @SneakyThrows + private TextEmbeddingProcessor createInstanceWithNestedMapConfiguration(Map fieldMap) { Map registry = new HashMap<>(); Map config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");