Skip to content

Commit

Permalink
[FEATURE]Improve test coverage for RemoteModel.java
Browse files Browse the repository at this point in the history
Added new tests for missing coverage. Mainly coverage was missing for catching exceptions in the methods initModel() and asyncPredict().
Also renamed some tests to match with testing methods.

Resolves opensearch-project#1382

Signed-off-by: Abdul Muneer Kolarkunnu <[email protected]>
  • Loading branch information
akolarkunnu committed Nov 11, 2024
1 parent 14fd97f commit 9637539
Showing 1 changed file with 14 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import java.util.Arrays;
import java.util.Map;
import java.util.Collections;

import org.junit.Assert;
import org.junit.Before;
Expand All @@ -31,6 +32,7 @@
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.ConnectorProtocols;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.transport.MLTaskResponse;
Expand Down Expand Up @@ -94,6 +96,17 @@ public void asyncPredict_NullConnectorExecutor() {

@Test
public void asyncPredict_ModelDeployed_WrongInput() {
asyncPredict_ModelDeployed_WrongInput("pre_process_function not defined in connector");
}

@Test
public void asyncPredict_With_RemoteInferenceInputDataSet() {
when(mlInput.getInputDataset()).thenReturn(
new RemoteInferenceInputDataSet(Collections.emptyMap(), ConnectorAction.ActionType.BATCH_PREDICT));
asyncPredict_ModelDeployed_WrongInput("no BATCH_PREDICT action found");
}

private void asyncPredict_ModelDeployed_WrongInput(String expExceptionMessage) {
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
when(mlModel.getConnector()).thenReturn(connector);
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
Expand All @@ -102,7 +115,7 @@ public void asyncPredict_ModelDeployed_WrongInput() {
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assert argumentCaptor.getValue() instanceof RuntimeException;
assertEquals("pre_process_function not defined in connector", argumentCaptor.getValue().getMessage());
assertEquals(expExceptionMessage, argumentCaptor.getValue().getMessage());
}

@Test
Expand Down

0 comments on commit 9637539

Please sign in to comment.