From 0458a4452cb277db0c9745dc217c3d94285da0db Mon Sep 17 00:00:00 2001 From: Abdul Muneer Kolarkunnu Date: Thu, 7 Nov 2024 12:49:34 +0530 Subject: [PATCH 1/4] [FEATURE]Improve test coverage for RemoteModel.java Added new tests for missing coverage. Mainly coverage was missing for catching exceptions in the methods initModel() and asyncPredict(). Also renamed some tests to match with testing methods. Resolves #1382 Signed-off-by: Abdul Muneer Kolarkunnu --- .../algorithms/remote/RemoteModelTest.java | 73 +++++++++++++++++-- 1 file changed, 65 insertions(+), 8 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index c14b329586..f60e803dd4 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -23,6 +23,7 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockedStatic; import org.mockito.MockitoAnnotations; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.MLModel; @@ -30,14 +31,17 @@ import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorProtocols; import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.engine.MLEngineClassLoader; +import org.opensearch.ml.engine.MLStaticMockBase; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import com.google.common.collect.ImmutableMap; -public class RemoteModelTest { +public class RemoteModelTest extends MLStaticMockBase { @Mock MLInput mlInput; @@ -45,6 +49,9 @@ public class RemoteModelTest { @Mock MLModel mlModel; + @Mock + RemoteConnectorExecutor remoteConnectorExecutor; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -73,7 +80,7 @@ public void test_predict_throw_IllegalStateException() { } @Test - public void predict_NullConnectorExecutor() { + public void asyncPredict_NullConnectorExecutor() { ActionListener actionListener = mock(ActionListener.class); remoteModel.asyncPredict(mlInput, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -86,7 +93,7 @@ public void predict_NullConnectorExecutor() { } @Test - public void predict_ModelDeployed_WrongInput() { + public void asyncPredict_ModelDeployed_WrongInput() { Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); when(mlModel.getConnector()).thenReturn(connector); remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); @@ -99,12 +106,63 @@ public void predict_ModelDeployed_WrongInput() { } @Test - public void initModel_RuntimeException() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Tag mismatch!"); + public void asyncPredict_Failure_With_RuntimeException() { + asyncPredict_Failure_With_Throwable( + new RuntimeException("Remote Connection Exception!"), + RuntimeException.class, + "Remote Connection Exception!" + ); + } + + @Test + public void asyncPredict_Failure_With_Throwable() { + asyncPredict_Failure_With_Throwable( + new Error("Remote Connection Error!"), + MLException.class, + "java.lang.Error: Remote Connection Error!" + ); + } + + private void asyncPredict_Failure_With_Throwable( + Throwable actualException, + Class expExceptionClass, + String expExceptionMessage + ) { + ActionListener actionListener = mock(ActionListener.class); + doThrow(actualException) + .when(remoteConnectorExecutor) + .executeAction(ConnectorAction.ActionType.PREDICT.toString(), mlInput, actionListener); + try (MockedStatic loader = mockStatic(MLEngineClassLoader.class)) { + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + loader + .when(() -> MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class)) + .thenReturn(remoteConnectorExecutor); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + remoteModel.asyncPredict(mlInput, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assert expExceptionClass.isInstance(argumentCaptor.getValue()); + assertEquals(expExceptionMessage, argumentCaptor.getValue().getMessage()); + } + } + + @Test + public void initModel_Failure_With_RuntimeException() { + initModel_Failure_With_Throwable(new IllegalArgumentException("Tag mismatch!"), IllegalArgumentException.class); + } + + @Test + public void initModel_Failure_With_Throwable() { + initModel_Failure_With_Throwable(new Error("Decryption Error!"), MLException.class); + } + + private void initModel_Failure_With_Throwable(Throwable actualException, Class expExcepClass) { + exceptionRule.expect(expExcepClass); + exceptionRule.expectMessage(actualException.getMessage()); Connector connector = createConnector(null); when(mlModel.getConnector()).thenReturn(connector); - doThrow(new IllegalArgumentException("Tag mismatch!")).when(encryptor).decrypt(any()); + doThrow(actualException).when(encryptor).decrypt(any()); remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); } @@ -129,7 +187,6 @@ public void initModel_WithHeader() { Assert.assertNotNull(executor.getConnector().getDecryptedHeaders()); assertEquals(1, executor.getConnector().getDecryptedHeaders().size()); assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization")); - remoteModel.close(); Assert.assertNull(remoteModel.getConnectorExecutor()); } From 14fd97f4536f7fe4bdf730b58b733468a33b4aa6 Mon Sep 17 00:00:00 2001 From: Abdul Muneer Kolarkunnu Date: Fri, 8 Nov 2024 11:33:37 +0530 Subject: [PATCH 2/4] [FEATURE]Improve test coverage for RemoteModel.java Added new tests for missing coverage. Mainly coverage was missing for catching exceptions in the methods initModel() and asyncPredict(). Also renamed some tests to match with testing methods. Resolves #1382 Signed-off-by: Abdul Muneer Kolarkunnu --- .../ml/engine/algorithms/remote/RemoteModelTest.java | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index f60e803dd4..4620f673c5 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -149,17 +149,21 @@ private void asyncPredict_Failure_With_Throwable( @Test public void initModel_Failure_With_RuntimeException() { - initModel_Failure_With_Throwable(new IllegalArgumentException("Tag mismatch!"), IllegalArgumentException.class); + initModel_Failure_With_Throwable(new IllegalArgumentException("Tag mismatch!"), IllegalArgumentException.class, "Tag mismatch!"); } @Test public void initModel_Failure_With_Throwable() { - initModel_Failure_With_Throwable(new Error("Decryption Error!"), MLException.class); + initModel_Failure_With_Throwable(new Error("Decryption Error!"), MLException.class, "Decryption Error!"); } - private void initModel_Failure_With_Throwable(Throwable actualException, Class expExcepClass) { + private void initModel_Failure_With_Throwable( + Throwable actualException, + Class expExcepClass, + String expExceptionMessage + ) { exceptionRule.expect(expExcepClass); - exceptionRule.expectMessage(actualException.getMessage()); + exceptionRule.expectMessage(expExceptionMessage); Connector connector = createConnector(null); when(mlModel.getConnector()).thenReturn(connector); doThrow(actualException).when(encryptor).decrypt(any()); From 963753908876865c017e4e6f7cdca584d229f529 Mon Sep 17 00:00:00 2001 From: Abdul Muneer Kolarkunnu Date: Mon, 11 Nov 2024 16:31:29 +0530 Subject: [PATCH 3/4] [FEATURE]Improve test coverage for RemoteModel.java Added new tests for missing coverage. Mainly coverage was missing for catching exceptions in the methods initModel() and asyncPredict(). Also renamed some tests to match with testing methods. Resolves #1382 Signed-off-by: Abdul Muneer Kolarkunnu --- .../engine/algorithms/remote/RemoteModelTest.java | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index 4620f673c5..b186a23893 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -15,6 +15,7 @@ import java.util.Arrays; import java.util.Map; +import java.util.Collections; import org.junit.Assert; import org.junit.Before; @@ -31,6 +32,7 @@ import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorProtocols; import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.transport.MLTaskResponse; @@ -94,6 +96,17 @@ public void asyncPredict_NullConnectorExecutor() { @Test public void asyncPredict_ModelDeployed_WrongInput() { + asyncPredict_ModelDeployed_WrongInput("pre_process_function not defined in connector"); + } + + @Test + public void asyncPredict_With_RemoteInferenceInputDataSet() { + when(mlInput.getInputDataset()).thenReturn( + new RemoteInferenceInputDataSet(Collections.emptyMap(), ConnectorAction.ActionType.BATCH_PREDICT)); + asyncPredict_ModelDeployed_WrongInput("no BATCH_PREDICT action found"); + } + + private void asyncPredict_ModelDeployed_WrongInput(String expExceptionMessage) { Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); when(mlModel.getConnector()).thenReturn(connector); remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); @@ -102,7 +115,7 @@ public void asyncPredict_ModelDeployed_WrongInput() { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assert argumentCaptor.getValue() instanceof RuntimeException; - assertEquals("pre_process_function not defined in connector", argumentCaptor.getValue().getMessage()); + assertEquals(expExceptionMessage, argumentCaptor.getValue().getMessage()); } @Test From 131f9435fdc293bdc0b496ac665cd4425b357d83 Mon Sep 17 00:00:00 2001 From: Abdul Muneer Kolarkunnu Date: Tue, 12 Nov 2024 08:15:09 +0530 Subject: [PATCH 4/4] [FEATURE]Improve test coverage for RemoteModel.java Added new tests for missing coverage. Mainly coverage was missing for catching exceptions in the methods initModel() and asyncPredict(). Also renamed some tests to match with testing methods. Resolves #1382 Signed-off-by: Abdul Muneer Kolarkunnu --- .../opensearch/ml/engine/algorithms/remote/RemoteModelTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index b186a23893..075019834c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -14,8 +14,8 @@ import static org.mockito.Mockito.when; import java.util.Arrays; -import java.util.Map; import java.util.Collections; +import java.util.Map; import org.junit.Assert; import org.junit.Before;