diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index 4620f673c5..b186a23893 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -15,6 +15,7 @@ import java.util.Arrays; import java.util.Map; +import java.util.Collections; import org.junit.Assert; import org.junit.Before; @@ -31,6 +32,7 @@ import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorProtocols; import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.transport.MLTaskResponse; @@ -94,6 +96,17 @@ public void asyncPredict_NullConnectorExecutor() { @Test public void asyncPredict_ModelDeployed_WrongInput() { + asyncPredict_ModelDeployed_WrongInput("pre_process_function not defined in connector"); + } + + @Test + public void asyncPredict_With_RemoteInferenceInputDataSet() { + when(mlInput.getInputDataset()).thenReturn( + new RemoteInferenceInputDataSet(Collections.emptyMap(), ConnectorAction.ActionType.BATCH_PREDICT)); + asyncPredict_ModelDeployed_WrongInput("no BATCH_PREDICT action found"); + } + + private void asyncPredict_ModelDeployed_WrongInput(String expExceptionMessage) { Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); when(mlModel.getConnector()).thenReturn(connector); remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); @@ -102,7 +115,7 @@ public void asyncPredict_ModelDeployed_WrongInput() { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assert argumentCaptor.getValue() instanceof RuntimeException; - assertEquals("pre_process_function not defined in connector", argumentCaptor.getValue().getMessage()); + assertEquals(expExceptionMessage, argumentCaptor.getValue().getMessage()); } @Test