Skip to content

Commit

Permalink
unit tests coverage for load/unload/syncup (opensearch-project#592)
Browse files Browse the repository at this point in the history
* unit tests coverage for load/unload/syncup

Signed-off-by: Bhavana Ramaram <[email protected]>
Signed-off-by: Sicheng Song <[email protected]>
  • Loading branch information
rbhavna authored and b4sjoo committed Dec 5, 2022
1 parent df3babd commit 4e3b8f9
Show file tree
Hide file tree
Showing 5 changed files with 845 additions and 5 deletions.
5 changes: 0 additions & 5 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,7 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.profile.MLModelProfile',
'org.opensearch.ml.profile.MLPredictRequestStats',
'org.opensearch.ml.action.load.TransportLoadModelAction',
'org.opensearch.ml.action.load.TransportLoadModelOnNodeAction',
'org.opensearch.ml.model.MLModelManager',
'org.opensearch.ml.action.unload.TransportUnloadModelAction',
'org.opensearch.ml.action.forward.TransportForwardAction',
'org.opensearch.ml.action.syncup.TransportSyncUpOnNodeAction',
'org.opensearch.ml.stats.MLClusterLevelStat',
'org.opensearch.ml.stats.MLStatLevel',
'org.opensearch.ml.utils.IndexUtils',
Expand All @@ -258,7 +254,6 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.task.MLTrainAndPredictTaskRunner',
'org.opensearch.ml.task.MLExecuteTaskRunner',
'org.opensearch.ml.action.profile.MLProfileTransportAction',
'org.opensearch.ml.action.syncup.TransportSyncUpOnNodeAction',
'org.opensearch.ml.action.models.DeleteModelTransportAction.1',
'org.opensearch.ml.rest.RestMLPredictionAction'
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.common.transport.forward.MLForwardRequestType.LOAD_MODEL_DONE;
import static org.opensearch.ml.common.transport.forward.MLForwardRequestType.UPLOAD_MODEL;
import static org.opensearch.ml.utils.TestHelper.ML_ROLE;

import java.util.Arrays;
Expand All @@ -40,11 +41,14 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.forward.MLForwardInput;
import org.opensearch.ml.common.transport.forward.MLForwardRequest;
import org.opensearch.ml.common.transport.forward.MLForwardResponse;
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.common.transport.upload.MLUploadInput;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.task.MLTaskCache;
import org.opensearch.ml.task.MLTaskManager;
Expand Down Expand Up @@ -125,6 +129,29 @@ public void testDoExecute_LoadModelDone_Error() {
verify(mlTaskManager, never()).updateMLTask(anyString(), any(), anyLong(), anyBoolean());
}

public void testDoExecute_LoadModelDone_NoError() {
Set<String> workerNodes = new HashSet<>();
workerNodes.add(nodeId1);
workerNodes.add(nodeId2);
when(mlTaskManager.getWorkNodes(anyString())).thenReturn(workerNodes);
when(mlModelManager.getWorkerNodes(anyString())).thenReturn(new String[] { nodeId1, nodeId2 });

MLForwardInput forwardInput = MLForwardInput
.builder()
.requestType(LOAD_MODEL_DONE)
.taskId(taskId)
.modelId(modelId)
.workerNodeId(nodeId1)
.build();
MLForwardRequest forwardRequest = MLForwardRequest.builder().forwardInput(forwardInput).build();
forwardAction.doExecute(task, forwardRequest, listener);
ArgumentCaptor<MLForwardResponse> response = ArgumentCaptor.forClass(MLForwardResponse.class);
verify(listener).onResponse(response.capture());
assertEquals("ok", response.getValue().getStatus());
assertNull(response.getValue().getMlOutput());
verify(mlTaskManager, never()).updateMLTask(anyString(), any(), anyLong(), anyBoolean());
}

public void testDoExecute_LoadModelDone_Error_NullTaskWorkerNodes() {
when(mlTaskManager.getWorkNodes(anyString())).thenReturn(null);
MLTaskCache mlTaskCache = MLTaskCache.builder().mlTask(createMlTask(MLTaskType.UPLOAD_MODEL)).build();
Expand Down Expand Up @@ -198,6 +225,23 @@ public void testDoExecute_LoadModel_Exception() {
assertEquals(error, exception.getValue().getMessage());
}

public void testDoExecute_UploadModel() {
MLForwardInput forwardInput = MLForwardInput
.builder()
.requestType(UPLOAD_MODEL)
.mlTask(createMlTask(MLTaskType.UPLOAD_MODEL))
.taskId(taskId)
.uploadInput(prepareInput())
.build();
MLForwardRequest forwardRequest = MLForwardRequest.builder().forwardInput(forwardInput).build();
forwardAction.doExecute(task, forwardRequest, listener);
ArgumentCaptor<MLForwardResponse> response = ArgumentCaptor.forClass(MLForwardResponse.class);
verify(listener).onResponse(response.capture());
assertEquals("ok", response.getValue().getStatus());
assertNull(response.getValue().getMlOutput());
verify(mlModelManager).uploadMLModel(any(), any());
}

private MLTask createMlTask(MLTaskType mlTaskType) {
return MLTask
.builder()
Expand All @@ -208,4 +252,20 @@ private MLTask createMlTask(MLTaskType mlTaskType) {
.taskType(mlTaskType)
.build();
}

private MLUploadInput prepareInput() {
MLUploadInput uploadInput = MLUploadInput
.builder()
.functionName(FunctionName.BATCH_RCF)
.loadModel(true)
.version("1.0")
.modelName("Test Model")
.modelConfig(
new TextEmbeddingModelConfig("CUSTOM", 123, TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS, "all config")
)
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.url("http://test_url")
.build();
return uploadInput;
}
}
Loading

0 comments on commit 4e3b8f9

Please sign in to comment.