diff --git a/plugin/build.gradle b/plugin/build.gradle index ce18502cd9..8f68d06b10 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -240,11 +240,7 @@ List jacocoExclusions = [ 'org.opensearch.ml.profile.MLModelProfile', 'org.opensearch.ml.profile.MLPredictRequestStats', 'org.opensearch.ml.action.load.TransportLoadModelAction', - 'org.opensearch.ml.action.load.TransportLoadModelOnNodeAction', 'org.opensearch.ml.model.MLModelManager', - 'org.opensearch.ml.action.unload.TransportUnloadModelAction', - 'org.opensearch.ml.action.forward.TransportForwardAction', - 'org.opensearch.ml.action.syncup.TransportSyncUpOnNodeAction', 'org.opensearch.ml.stats.MLClusterLevelStat', 'org.opensearch.ml.stats.MLStatLevel', 'org.opensearch.ml.utils.IndexUtils', @@ -258,7 +254,6 @@ List jacocoExclusions = [ 'org.opensearch.ml.task.MLTrainAndPredictTaskRunner', 'org.opensearch.ml.task.MLExecuteTaskRunner', 'org.opensearch.ml.action.profile.MLProfileTransportAction', - 'org.opensearch.ml.action.syncup.TransportSyncUpOnNodeAction', 'org.opensearch.ml.action.models.DeleteModelTransportAction.1', 'org.opensearch.ml.rest.RestMLPredictionAction' ] diff --git a/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java index b2b3ef9687..448141beb1 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java @@ -19,6 +19,7 @@ import static org.mockito.Mockito.when; import static org.opensearch.ml.common.MLTaskState.FAILED; import static org.opensearch.ml.common.transport.forward.MLForwardRequestType.LOAD_MODEL_DONE; +import static org.opensearch.ml.common.transport.forward.MLForwardRequestType.UPLOAD_MODEL; import static org.opensearch.ml.utils.TestHelper.ML_ROLE; import java.util.Arrays; @@ -40,11 +41,14 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.forward.MLForwardInput; import org.opensearch.ml.common.transport.forward.MLForwardRequest; import org.opensearch.ml.common.transport.forward.MLForwardResponse; import org.opensearch.ml.common.transport.sync.MLSyncUpAction; import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest; +import org.opensearch.ml.common.transport.upload.MLUploadInput; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.task.MLTaskCache; import org.opensearch.ml.task.MLTaskManager; @@ -125,6 +129,29 @@ public void testDoExecute_LoadModelDone_Error() { verify(mlTaskManager, never()).updateMLTask(anyString(), any(), anyLong(), anyBoolean()); } + public void testDoExecute_LoadModelDone_NoError() { + Set workerNodes = new HashSet<>(); + workerNodes.add(nodeId1); + workerNodes.add(nodeId2); + when(mlTaskManager.getWorkNodes(anyString())).thenReturn(workerNodes); + when(mlModelManager.getWorkerNodes(anyString())).thenReturn(new String[] { nodeId1, nodeId2 }); + + MLForwardInput forwardInput = MLForwardInput + .builder() + .requestType(LOAD_MODEL_DONE) + .taskId(taskId) + .modelId(modelId) + .workerNodeId(nodeId1) + .build(); + MLForwardRequest forwardRequest = MLForwardRequest.builder().forwardInput(forwardInput).build(); + forwardAction.doExecute(task, forwardRequest, listener); + ArgumentCaptor response = ArgumentCaptor.forClass(MLForwardResponse.class); + verify(listener).onResponse(response.capture()); + assertEquals("ok", response.getValue().getStatus()); + assertNull(response.getValue().getMlOutput()); + verify(mlTaskManager, never()).updateMLTask(anyString(), any(), anyLong(), anyBoolean()); + } + public void testDoExecute_LoadModelDone_Error_NullTaskWorkerNodes() { when(mlTaskManager.getWorkNodes(anyString())).thenReturn(null); MLTaskCache mlTaskCache = MLTaskCache.builder().mlTask(createMlTask(MLTaskType.UPLOAD_MODEL)).build(); @@ -198,6 +225,23 @@ public void testDoExecute_LoadModel_Exception() { assertEquals(error, exception.getValue().getMessage()); } + public void testDoExecute_UploadModel() { + MLForwardInput forwardInput = MLForwardInput + .builder() + .requestType(UPLOAD_MODEL) + .mlTask(createMlTask(MLTaskType.UPLOAD_MODEL)) + .taskId(taskId) + .uploadInput(prepareInput()) + .build(); + MLForwardRequest forwardRequest = MLForwardRequest.builder().forwardInput(forwardInput).build(); + forwardAction.doExecute(task, forwardRequest, listener); + ArgumentCaptor response = ArgumentCaptor.forClass(MLForwardResponse.class); + verify(listener).onResponse(response.capture()); + assertEquals("ok", response.getValue().getStatus()); + assertNull(response.getValue().getMlOutput()); + verify(mlModelManager).uploadMLModel(any(), any()); + } + private MLTask createMlTask(MLTaskType mlTaskType) { return MLTask .builder() @@ -208,4 +252,20 @@ private MLTask createMlTask(MLTaskType mlTaskType) { .taskType(mlTaskType) .build(); } + + private MLUploadInput prepareInput() { + MLUploadInput uploadInput = MLUploadInput + .builder() + .functionName(FunctionName.BATCH_RCF) + .loadModel(true) + .version("1.0") + .modelName("Test Model") + .modelConfig( + new TextEmbeddingModelConfig("CUSTOM", 123, TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS, "all config") + ) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .url("http://test_url") + .build(); + return uploadInput; + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/load/TransportLoadModelOnNodeActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/load/TransportLoadModelOnNodeActionTests.java new file mode 100644 index 0000000000..48534b0e48 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/load/TransportLoadModelOnNodeActionTests.java @@ -0,0 +1,327 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.load; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE; +import static org.opensearch.ml.utils.TestHelper.clusterSetting; + +import java.io.IOException; +import java.net.InetAddress; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutorService; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionListenerResponseHandler; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.breaker.MLCircuitBreakerService; +import org.opensearch.ml.common.dataset.MLInputDataType; +import org.opensearch.ml.common.exception.MLLimitExceededException; +import org.opensearch.ml.common.transport.forward.MLForwardResponse; +import org.opensearch.ml.common.transport.load.LoadModelInput; +import org.opensearch.ml.common.transport.load.LoadModelNodeRequest; +import org.opensearch.ml.common.transport.load.LoadModelNodeResponse; +import org.opensearch.ml.common.transport.load.LoadModelNodesRequest; +import org.opensearch.ml.common.transport.load.LoadModelNodesResponse; +import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class TransportLoadModelOnNodeActionTests extends OpenSearchTestCase { + + @Mock + private TransportService transportService; + + @Mock + private ModelHelper modelHelper; + + @Mock + private MLTaskManager mlTaskManager; + + @Mock + private MLModelManager mlModelManager; + + @Mock + private ClusterService clusterService; + + @Mock + private ThreadPool threadPool; + + @Mock + private Client client; + + @Mock + private NamedXContentRegistry xContentRegistry; + + @Mock + private MLCircuitBreakerService mlCircuitBreakerService; + + @Mock + private MLStats mlStats; + + @Mock + private ExecutorService executorService; + + @Mock + private ActionFilters actionFilters; + + private TransportLoadModelOnNodeAction action; + + private ThreadContext threadContext; + + private Settings settings; + + private DiscoveryNode localNode; + private DiscoveryNode localNode1; + private DiscoveryNode localNode2; + private DiscoveryNode localNode3; + private DiscoveryNode clusterManagerNode; + private final String clusterManagerNodeId = "clusterManagerNode"; + + private MLTask mlTask; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + settings = Settings.builder().put(ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE.getKey(), 1).build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(anyString())).thenReturn(executorService); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + action = new TransportLoadModelOnNodeAction( + transportService, + actionFilters, + modelHelper, + mlTaskManager, + mlModelManager, + clusterService, + threadPool, + client, + xContentRegistry, + mlCircuitBreakerService, + mlStats, + settings + ); + + clusterManagerNode = new DiscoveryNode( + clusterManagerNodeId, + buildNewFakeTransportAddress(), + emptyMap(), + emptySet(), + Version.CURRENT + ); + + localNode = new DiscoveryNode( + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + localNode1 = new DiscoveryNode( + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + localNode2 = new DiscoveryNode( + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + localNode3 = new DiscoveryNode( + "foo3", + "foo3", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + + when(clusterService.getClusterName()).thenReturn(new ClusterName("Local Cluster")); + when(clusterService.localNode()).thenReturn(localNode); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse("successful"); + return null; + }).when(mlModelManager).loadModel(any(), any(), any(), any()); + MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class); + doAnswer(invocation -> { + ActionListenerResponseHandler handler = invocation.getArgument(3); + handler.handleResponse(forwardResponse); + return null; + }).when(transportService).sendRequest(any(), any(), any(), any()); + + DiscoveryNodes nodes = DiscoveryNodes.builder().add(clusterManagerNode).add(localNode1).add(localNode1).add(localNode1).build(); + ClusterState clusterState = new ClusterState( + new ClusterName("Local Cluster"), + 123l, + "111111", + null, + null, + nodes, + null, + null, + 0, + false + ); + when(clusterService.state()).thenReturn(clusterState); + + Instant time = Instant.now(); + mlTask = MLTask + .builder() + .taskId("mlTaskTaskId") + .modelId("mlTaskModelId") + .taskType(MLTaskType.PREDICTION) + .functionName(FunctionName.LINEAR_REGRESSION) + .state(MLTaskState.RUNNING) + .inputType(MLInputDataType.DATA_FRAME) + .workerNode("node1") + .progress(0.0f) + .outputIndex("test_index") + .error("test_error") + .createTime(time.minus(1, ChronoUnit.MINUTES)) + .lastUpdateTime(time) + .build(); + + } + + public void testConstructor() { + assertNotNull(action); + } + + public void testNewResponses() { + final LoadModelNodesRequest nodesRequest = prepareRequest(); + Map modelToLoadStatus = new HashMap<>(); + modelToLoadStatus.put("modelName:version", "response"); + LoadModelNodeResponse response = new LoadModelNodeResponse(localNode, modelToLoadStatus); + final List responses = Arrays.asList(new LoadModelNodeResponse[] { response }); + final List failures = new ArrayList(); + LoadModelNodesResponse response1 = action.newResponse(nodesRequest, responses, failures); + assertNotNull(response1); + } + + public void testNewRequest() { + final LoadModelNodesRequest nodesRequest = prepareRequest(); + final LoadModelNodeRequest request = action.newNodeRequest(nodesRequest); + assertNotNull(request); + } + + public void testNewNodeResponse() throws IOException { + Map modelToLoadStatus = new HashMap<>(); + modelToLoadStatus.put("modelName:version", "response"); + LoadModelNodeResponse response = new LoadModelNodeResponse(localNode, modelToLoadStatus); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + final LoadModelNodeResponse response1 = action.newNodeResponse(output.bytes().streamInput()); + assertNotNull(response1); + } + + public void testNodeOperation_Success() { + final LoadModelNodesRequest nodesRequest = prepareRequest(); + final LoadModelNodeRequest request = action.newNodeRequest(nodesRequest); + final LoadModelNodeResponse response = action.nodeOperation(request); + assertNotNull(response); + } + + public void testNodeOperation_LoadModelException() { + when(mlModelManager.checkAndAddRunningTask(any(), any())).thenThrow(NullPointerException.class); + final LoadModelNodesRequest nodesRequest = prepareRequest(); + final LoadModelNodeRequest request = action.newNodeRequest(nodesRequest); + final LoadModelNodeResponse response = action.nodeOperation(request); + assertNotNull(response); + } + + public void testNodeOperation_Exception() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new RuntimeException("Something went wrong")); + return null; + }).when(mlModelManager).loadModel(any(), any(), any(), any()); + final LoadModelNodesRequest nodesRequest = prepareRequest(); + final LoadModelNodeRequest request = action.newNodeRequest(nodesRequest); + final LoadModelNodeResponse response = action.nodeOperation(request); + assertNotNull(response); + } + + public void testNodeOperation_MLLimitExceededException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new MLLimitExceededException("Limit exceeded exception")); + return null; + }).when(mlModelManager).loadModel(any(), any(), any(), any()); + final LoadModelNodesRequest nodesRequest = prepareRequest(); + final LoadModelNodeRequest request = action.newNodeRequest(nodesRequest); + final LoadModelNodeResponse response = action.nodeOperation(request); + assertNotNull(response); + } + + public void testNodeOperation_ErrorMessageNotNull() { + when(mlModelManager.checkAndAddRunningTask(any(), any())).thenReturn("Error message"); + final LoadModelNodesRequest nodesRequest = prepareRequest(); + final LoadModelNodeRequest request = action.newNodeRequest(nodesRequest); + final LoadModelNodeResponse response = action.nodeOperation(request); + assertNotNull(response); + } + + private LoadModelNodesRequest prepareRequest() { + DiscoveryNode[] nodeIds = { localNode1, localNode2, localNode3 }; + LoadModelInput loadModelInput = new LoadModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", mlTask); + return new LoadModelNodesRequest(nodeIds, loadModelInput); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java new file mode 100644 index 0000000000..d3f676cc42 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java @@ -0,0 +1,269 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.syncup; + +import static java.util.Collections.emptyMap; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.when; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE; +import static org.opensearch.ml.utils.TestHelper.ML_ROLE; +import static org.opensearch.ml.utils.TestHelper.clusterSetting; + +import java.io.File; +import java.io.IOException; +import java.net.InetAddress; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.junit.Before; +import org.junit.rules.TemporaryFolder; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.sync.MLSyncUpInput; +import org.opensearch.ml.common.transport.sync.MLSyncUpNodeRequest; +import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; +import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest; +import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse; +import org.opensearch.ml.engine.MLEngine; +import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableSet; + +public class TransportSyncUpOnNodeActionTests extends OpenSearchTestCase { + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private ModelHelper modelHelper; + + @Mock + private MLTaskManager mlTaskManager; + + @Mock + private MLModelManager mlModelManager; + + @Mock + private ClusterService clusterService; + + @Mock + private ThreadPool threadPool; + + @Mock + private Client client; + + @Mock + private NamedXContentRegistry xContentRegistry; + + @Mock + private MLEngine mlEngine; + + private Settings settings; + + public TemporaryFolder testFolder = new TemporaryFolder(); + + private TransportSyncUpOnNodeAction action; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + mockSettings(true); + when(clusterService.getClusterName()).thenReturn(new ClusterName("Local Cluster")); + action = new TransportSyncUpOnNodeAction( + transportService, + actionFilters, + modelHelper, + mlTaskManager, + mlModelManager, + clusterService, + threadPool, + client, + xContentRegistry, + mlEngine + ); + } + + public void testConstructor() { + assertNotNull(action); + } + + public void testNewResponse() { + final MLSyncUpNodesRequest nodesRequest = Mockito.mock(MLSyncUpNodesRequest.class); + final List responses = new ArrayList(); + final List failures = new ArrayList(); + final MLSyncUpNodesResponse response = action.newResponse(nodesRequest, responses, failures); + assertNotNull(response); + } + + public void testNewRequest() { + final MLSyncUpNodeRequest request = action.newNodeRequest(new MLSyncUpNodesRequest(new String[] {}, prepareRequest())); + assertNotNull(request); + } + + public void testNewNodeResponse() throws IOException { + final DiscoveryNode mlNode1 = new DiscoveryNode( + "123", + buildNewFakeTransportAddress(), + emptyMap(), + ImmutableSet.of(ML_ROLE), + Version.CURRENT + ); + String[] loadedModelIds = new String[] { "123" }; + String[] runningLoadModelTaskIds = new String[] { "1" }; + MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(mlNode1, "LOADED", loadedModelIds, runningLoadModelTaskIds); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + final MLSyncUpNodeResponse response1 = action.newNodeResponse(output.bytes().streamInput()); + assertNotNull(response1); + } + + public void testNodeOperation_AddedWorkerNodes() throws IOException { + testFolder.create(); + File file1 = testFolder.newFolder(); + File file2 = testFolder.newFolder(); + File file3 = testFolder.newFolder(); + for (int i = 0; i < 5; i++) { + File.createTempFile("Hello" + i, "1.txt", file1); + File.createTempFile("Hello" + i, "1.txt", file2); + File.createTempFile("Hello" + i, "1.txt", file3); + } + when(mlEngine.getModelCachePath(any())).thenReturn(Paths.get(file3.getCanonicalPath())); + when(mlEngine.getLoadModelPath(any())).thenReturn(Paths.get(file2.getCanonicalPath())); + when(mlEngine.getUploadModelPath(any())).thenReturn(Paths.get(file1.getCanonicalPath())); + DiscoveryNode localNode = new DiscoveryNode( + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + when(clusterService.localNode()).thenReturn(localNode); + when(mlEngine.getUploadModelRootPath()).thenReturn(Paths.get(file1.getCanonicalPath())); + when(mlEngine.getLoadModelRootPath()).thenReturn(Paths.get(file2.getCanonicalPath())); + when(mlEngine.getModelCacheRootPath()).thenReturn(Paths.get(file3.getCanonicalPath())); + final MLSyncUpNodeRequest request = action.newNodeRequest(new MLSyncUpNodesRequest(new String[] {}, prepareRequest())); + final MLSyncUpNodeResponse response = action.nodeOperation(request); + assertNotNull(response); + file1.deleteOnExit(); + file2.deleteOnExit(); + file3.deleteOnExit(); + testFolder.delete(); + } + + public void testNodeOperation_RemovedWorkerNodes() throws IOException { + testFolder.create(); + File file1 = testFolder.newFolder(); + File file2 = testFolder.newFolder(); + File file3 = testFolder.newFolder(); + for (int i = 0; i < 5; i++) { + File.createTempFile("Hello" + i, "1.txt", file1); + File.createTempFile("Hello" + i, "1.txt", file2); + File.createTempFile("Hello" + i, "1.txt", file3); + } + when(mlEngine.getModelCachePath(any())).thenReturn(Paths.get(file3.getCanonicalPath())); + when(mlEngine.getLoadModelPath(any())).thenReturn(Paths.get(file2.getCanonicalPath())); + when(mlEngine.getUploadModelPath(any())).thenReturn(Paths.get(file1.getCanonicalPath())); + DiscoveryNode localNode = new DiscoveryNode( + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + when(clusterService.localNode()).thenReturn(localNode); + when(mlEngine.getUploadModelRootPath()).thenReturn(Paths.get(file1.getCanonicalPath())); + when(mlEngine.getLoadModelRootPath()).thenReturn(Paths.get(file2.getCanonicalPath())); + when(mlEngine.getModelCacheRootPath()).thenReturn(Paths.get(file3.getCanonicalPath())); + when(mlTaskManager.contains(any())).thenReturn(true); + when(mlTaskManager.containsModel(any())).thenReturn(true); + when(mlModelManager.isModelRunningOnNode(anyString())).thenReturn(true); + final MLSyncUpNodeRequest request = action.newNodeRequest(new MLSyncUpNodesRequest(new String[] {}, prepareRequest2())); + final MLSyncUpNodeResponse response = action.nodeOperation(request); + assertNotNull(response); + file1.deleteOnExit(); + file2.deleteOnExit(); + file3.deleteOnExit(); + testFolder.delete(); + } + + private MLSyncUpInput prepareRequest() { + Map addedWorkerNodes = new HashMap<>(); + addedWorkerNodes.put("modelId1", new String[] { "nodeId1", "nodeId2", "nodeId3" }); + Map> modelRoutingTable = new HashMap<>(); + Map> runningLoadModelTasks = new HashMap<>(); + final HashSet set = new HashSet<>(); + set.addAll(Arrays.asList(new String[] { "nodeId3", "nodeId4", "nodeId5" })); + modelRoutingTable.put("modelId2", set); + MLSyncUpInput syncUpInput = MLSyncUpInput + .builder() + .getLoadedModels(true) + .addedWorkerNodes(addedWorkerNodes) + .modelRoutingTable(modelRoutingTable) + .runningLoadModelTasks(runningLoadModelTasks) + .clearRoutingTable(true) + .syncRunningLoadModelTasks(true) + .build(); + return syncUpInput; + } + + private MLSyncUpInput prepareRequest2() { + Map removedWorkerNodes = new HashMap<>(); + removedWorkerNodes.put("modelId2", new String[] { "nodeId3", "nodeId4", "nodeId5" }); + Map> modelRoutingTable = new HashMap<>(); + Map> runningLoadModelTasks = new HashMap<>(); + final HashSet set = new HashSet<>(); + set.addAll(Arrays.asList(new String[] { "nodeId3", "nodeId4", "nodeId5" })); + modelRoutingTable.put("modelId2", set); + MLSyncUpInput syncUpInput = MLSyncUpInput + .builder() + .getLoadedModels(true) + .removedWorkerNodes(removedWorkerNodes) + .modelRoutingTable(modelRoutingTable) + .runningLoadModelTasks(runningLoadModelTasks) + .clearRoutingTable(false) + .syncRunningLoadModelTasks(true) + .build(); + return syncUpInput; + } + + private void mockSettings(boolean onlyRunOnMLNode) { + settings = Settings.builder().put(ML_COMMONS_ONLY_RUN_ON_ML_NODE.getKey(), onlyRunOnMLNode).build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ONLY_RUN_ON_ML_NODE); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/unload/TransportUnloadModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/unload/TransportUnloadModelActionTests.java new file mode 100644 index 0000000000..eb94c2492f --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/unload/TransportUnloadModelActionTests.java @@ -0,0 +1,189 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.unload; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.concurrent.ExecutorService; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.transport.unload.UnloadModelNodeRequest; +import org.opensearch.ml.common.transport.unload.UnloadModelNodeResponse; +import org.opensearch.ml.common.transport.unload.UnloadModelNodesRequest; +import org.opensearch.ml.common.transport.unload.UnloadModelNodesResponse; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.stats.MLStat; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class TransportUnloadModelActionTests extends OpenSearchTestCase { + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private MLModelManager mlModelManager; + + @Mock + private ClusterService clusterService; + + @Mock + private Client client; + + @Mock + private DiscoveryNodeHelper nodeFilter; + + @Mock + private MLStats mlStats; + + private ThreadContext threadContext; + + @Mock + private ExecutorService executorService; + + private TransportUnloadModelAction action; + + private DiscoveryNode localNode; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(anyString())).thenReturn(executorService); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + action = new TransportUnloadModelAction( + transportService, + actionFilters, + mlModelManager, + clusterService, + null, + client, + nodeFilter, + mlStats + ); + localNode = new DiscoveryNode( + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + when(clusterService.getClusterName()).thenReturn(new ClusterName("Local Cluster")); + when(clusterService.localNode()).thenReturn(localNode); + } + + public void testConstructor() { + assertNotNull(action); + } + + public void testNewNodeRequest() { + final UnloadModelNodesRequest request = new UnloadModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final UnloadModelNodeRequest unLoadrequest = action.newNodeRequest(request); + assertNotNull(unLoadrequest); + } + + public void testNewNodeStreamRequest() throws IOException { + java.util.Map modelToLoadStatus = new HashMap<>(); + modelToLoadStatus.put("modelName:version", "response"); + UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, modelToLoadStatus); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + final UnloadModelNodeResponse unLoadResponse = action.newNodeResponse(output.bytes().streamInput()); + assertNotNull(unLoadResponse); + } + + public void testNodeOperation() { + MLStat mlStat = mock(MLStat.class); + when(mlStats.getStat(any())).thenReturn(mlStat); + final UnloadModelNodesRequest request = new UnloadModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final UnloadModelNodeResponse response = action.nodeOperation(new UnloadModelNodeRequest(request)); + assertNotNull(response); + } + + public void testNewResponseWithUnloadedModelStatus() { + final UnloadModelNodesRequest nodesRequest = new UnloadModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final List responses = new ArrayList<>(); + java.util.Map modelToLoadStatus = new HashMap<>(); + modelToLoadStatus.put("modelName:version", "unloaded"); + UnloadModelNodeResponse response1 = new UnloadModelNodeResponse(localNode, modelToLoadStatus); + modelToLoadStatus.put("modelName:version", "unloaded"); + UnloadModelNodeResponse response2 = new UnloadModelNodeResponse(localNode, modelToLoadStatus); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final UnloadModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + assertNotNull(response); + + } + + public void testNewResponseWithNotFoundModelStatus() { + final UnloadModelNodesRequest nodesRequest = new UnloadModelNodesRequest( + new String[] { "nodeId1", "nodeId2" }, + new String[] { "modelId1", "modelId2" } + ); + final List responses = new ArrayList<>(); + java.util.Map modelToLoadStatus = new HashMap<>(); + modelToLoadStatus.put("modelName:version", "not_found"); + UnloadModelNodeResponse response1 = new UnloadModelNodeResponse(localNode, modelToLoadStatus); + modelToLoadStatus.put("modelName:version", "not_found"); + UnloadModelNodeResponse response2 = new UnloadModelNodeResponse(localNode, modelToLoadStatus); + responses.add(response1); + responses.add(response2); + final List failures = new ArrayList<>(); + final UnloadModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); + assertNotNull(response); + } +}