diff --git a/common/src/main/java/org/opensearch/ml/common/MLTask.java b/common/src/main/java/org/opensearch/ml/common/MLTask.java index 84acb4d123..9080181e62 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTask.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTask.java @@ -20,6 +20,9 @@ import java.io.IOException; import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.USER; @@ -54,7 +57,7 @@ public class MLTask implements ToXContentObject, Writeable { private Float progress; private final String outputIndex; @Setter - private String workerNode; + private List workerNodes; private final Instant createTime; private Instant lastUpdateTime; @Setter @@ -72,7 +75,7 @@ public MLTask( MLInputDataType inputType, Float progress, String outputIndex, - String workerNode, + List workerNodes, Instant createTime, Instant lastUpdateTime, String error, @@ -87,7 +90,7 @@ public MLTask( this.inputType = inputType; this.progress = progress; this.outputIndex = outputIndex; - this.workerNode = workerNode; + this.workerNodes = workerNodes; this.createTime = createTime; this.lastUpdateTime = lastUpdateTime; this.error = error; @@ -108,7 +111,7 @@ public MLTask(StreamInput input) throws IOException { } this.progress = input.readOptionalFloat(); this.outputIndex = input.readOptionalString(); - this.workerNode = input.readString(); + this.workerNodes = input.readStringList(); this.createTime = input.readInstant(); this.lastUpdateTime = input.readInstant(); this.error = input.readOptionalString(); @@ -135,7 +138,7 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeOptionalFloat(progress); out.writeOptionalString(outputIndex); - out.writeString(workerNode); + out.writeStringCollection(workerNodes); out.writeInstant(createTime); out.writeInstant(lastUpdateTime); out.writeOptionalString(error); @@ -174,8 +177,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params if (outputIndex != null) { builder.field(OUTPUT_INDEX_FIELD, outputIndex); } - if (workerNode != null) { - builder.field(WORKER_NODE_FIELD, workerNode); + if (workerNodes != null) { + builder.field(WORKER_NODE_FIELD, workerNodes); } if (createTime != null) { builder.field(CREATE_TIME_FIELD, createTime.toEpochMilli()); @@ -207,7 +210,7 @@ public static MLTask parse(XContentParser parser) throws IOException { MLInputDataType inputType = null; Float progress = null; String outputIndex = null; - String workerNode = null; + List workerNodes = null; Instant createTime = null; Instant lastUpdateTime = null; String error = null; @@ -245,7 +248,15 @@ public static MLTask parse(XContentParser parser) throws IOException { outputIndex = parser.text(); break; case WORKER_NODE_FIELD: - workerNode = parser.text(); + if (XContentParser.Token.START_ARRAY == parser.currentToken()) { + workerNodes = new ArrayList<>(); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + workerNodes.add(parser.text()); + } + } else { + String[] nodes = parser.text().split(","); + workerNodes = Arrays.asList(nodes); + } break; case CREATE_TIME_FIELD: createTime = Instant.ofEpochMilli(parser.longValue()); @@ -276,7 +287,7 @@ public static MLTask parse(XContentParser parser) throws IOException { .inputType(inputType) .progress(progress) .outputIndex(outputIndex) - .workerNode(workerNode) + .workerNodes(workerNodes) .createTime(createTime) .lastUpdateTime(lastUpdateTime) .error(error) diff --git a/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java b/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java index 36d1892af4..39d391bf33 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java @@ -5,6 +5,11 @@ package org.opensearch.ml.common; +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; + import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -32,7 +37,7 @@ public void setup() { .functionName(FunctionName.KMEANS) .state(MLTaskState.RUNNING) .inputType(MLInputDataType.DATA_FRAME) - .workerNode("node1") + .workerNodes(Arrays.asList("node1")) .progress(0.0f) .outputIndex("test_index") .error("test_error") @@ -57,7 +62,7 @@ public void toXContent() throws IOException { Assert.assertEquals( "{\"task_id\":\"dummy taskId\",\"model_id\":\"test_model_id\",\"task_type\":\"PREDICTION\"," + "\"function_name\":\"KMEANS\",\"state\":\"RUNNING\",\"input_type\":\"DATA_FRAME\",\"progress\":0.0," - + "\"output_index\":\"test_index\",\"worker_node\":\"node1\",\"create_time\":1641599940000," + + "\"output_index\":\"test_index\",\"worker_node\":[\"node1\"],\"create_time\":1641599940000," + "\"last_update_time\":1641600000000,\"error\":\"test_error\",\"is_async\":false}", taskContent ); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java index 11c287a521..d870265702 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java @@ -24,6 +24,7 @@ import java.io.IOException; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.function.Consumer; @@ -48,7 +49,7 @@ public void setUp() throws Exception { .functionName(functionName) .state(MLTaskState.RUNNING) .inputType(MLInputDataType.DATA_FRAME) - .workerNode("mlTaskNode1") + .workerNodes(Arrays.asList("mlTaskNode1")) .progress(0.0f) .outputIndex("test_index") .error("test_error") diff --git a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java index 690e3ae274..f79900a1ca 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java @@ -27,6 +27,7 @@ import java.io.UncheckedIOException; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -52,7 +53,7 @@ public void setUp() throws Exception { .functionName(functionName) .state(MLTaskState.RUNNING) .inputType(MLInputDataType.DATA_FRAME) - .workerNode("mlTaskNode1") + .workerNodes(Arrays.asList("mlTaskNode1")) .progress(0.0f) .outputIndex("test_index") .error("test_error") diff --git a/common/src/test/java/org/opensearch/ml/common/transport/load/LoadModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/load/LoadModelInputTest.java index c70f3e5e1c..be626d5576 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/load/LoadModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/load/LoadModelInputTest.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -44,7 +45,7 @@ public void setUp() throws Exception { .functionName(FunctionName.LINEAR_REGRESSION) .state(MLTaskState.RUNNING) .inputType(MLInputDataType.DATA_FRAME) - .workerNode("node1") + .workerNodes(Arrays.asList("node1")) .progress(0.0f) .outputIndex("test_index") .error("test_error") diff --git a/common/src/test/java/org/opensearch/ml/common/transport/load/LoadModelNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/load/LoadModelNodesRequestTest.java index 157852a691..ff224de006 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/load/LoadModelNodesRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/load/LoadModelNodesRequestTest.java @@ -20,6 +20,7 @@ import java.net.InetAddress; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.Arrays; import java.util.Collections; import static org.junit.Assert.*; @@ -70,7 +71,7 @@ public void setUp() throws Exception { .functionName(FunctionName.LINEAR_REGRESSION) .state(MLTaskState.RUNNING) .inputType(MLInputDataType.DATA_FRAME) - .workerNode("node1") + .workerNodes(Arrays.asList("node1")) .progress(0.0f) .outputIndex("test_index") .error("test_error") diff --git a/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java index ee25f49fbb..10b2e0b07c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java @@ -17,6 +17,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.Arrays; import static org.junit.Assert.*; @@ -34,7 +35,7 @@ public void setUp() { .inputType(MLInputDataType.DATA_FRAME) .progress(1.3f) .outputIndex("some index") - .workerNode("some node") + .workerNodes(Arrays.asList("some node")) .createTime(Instant.ofEpochMilli(123)) .lastUpdateTime(Instant.ofEpochMilli(123)) .error("error") @@ -58,7 +59,7 @@ public void writeTo_Success() throws IOException { assertEquals(response.mlTask.getInputType(), parsedResponse.mlTask.getInputType()); assertEquals(response.mlTask.getProgress(), parsedResponse.mlTask.getProgress()); assertEquals(response.mlTask.getOutputIndex(), parsedResponse.mlTask.getOutputIndex()); - assertEquals(response.mlTask.getWorkerNode(), parsedResponse.mlTask.getWorkerNode()); + assertEquals(response.mlTask.getWorkerNodes(), parsedResponse.mlTask.getWorkerNodes()); assertEquals(response.mlTask.getCreateTime(), parsedResponse.mlTask.getCreateTime()); assertEquals(response.mlTask.getLastUpdateTime(), parsedResponse.mlTask.getLastUpdateTime()); assertEquals(response.mlTask.getError(), parsedResponse.mlTask.getError()); @@ -79,7 +80,7 @@ public void toXContentTest() throws IOException { "\"input_type\":\"DATA_FRAME\"," + "\"progress\":1.3," + "\"output_index\":\"some index\"," + - "\"worker_node\":\"some node\"," + + "\"worker_node\":[\"some node\"]," + "\"create_time\":123," + "\"last_update_time\":123," + "\"error\":\"error\"," + diff --git a/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java index e66d9124ba..892dae5208 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java @@ -137,8 +137,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { String taskId = response.getId(); diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload/TransportUploadModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/upload/TransportUploadModelAction.java index b3d553d795..5dd5030ee8 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload/TransportUploadModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload/TransportUploadModelAction.java @@ -52,6 +52,7 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @Log4j2 @@ -127,12 +128,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { String nodeId = node.getId(); - mlTask.setWorkerNode(nodeId); + mlTask.setWorkerNodes(ImmutableList.of(nodeId)); mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { String taskId = response.getId(); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java index 874def5a01..9f2b9a9b84 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java @@ -6,6 +6,7 @@ package org.opensearch.ml.model; import java.util.DoubleSummaryStatistics; +import java.util.List; import java.util.Queue; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -29,16 +30,30 @@ public class MLModelCache { private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) MLModelState modelState; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) FunctionName functionName; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Predictable predictor; + private @Getter(AccessLevel.PROTECTED) Set targetWorkerNodes; private final Set workerNodes; private final Queue modelInferenceDurationQueue; private final Queue predictRequestDurationQueue; public MLModelCache() { + targetWorkerNodes = ConcurrentHashMap.newKeySet(); workerNodes = ConcurrentHashMap.newKeySet(); modelInferenceDurationQueue = new ConcurrentLinkedQueue<>(); predictRequestDurationQueue = new ConcurrentLinkedQueue<>(); } + public void setTargetWorkerNodes(List targetWorkerNodes) { + if (targetWorkerNodes == null || targetWorkerNodes.size() == 0) { + throw new IllegalArgumentException("Null or empty target worker nodes"); + } + this.targetWorkerNodes.clear(); + this.targetWorkerNodes.addAll(targetWorkerNodes); + } + + public String[] getTargetWorkerNodes() { + return targetWorkerNodes.toArray(new String[0]); + } + public void removeWorkerNode(String nodeId) { workerNodes.remove(nodeId); } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 9c3ac29875..081328e55c 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -8,6 +8,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -41,7 +42,7 @@ public MLModelCacheHelper(ClusterService clusterService, Settings settings) { * @param state model state * @param functionName function name */ - public synchronized void initModelState(String modelId, MLModelState state, FunctionName functionName) { + public synchronized void initModelState(String modelId, MLModelState state, FunctionName functionName, List targetWorkerNodes) { if (isModelRunningOnNode(modelId)) { throw new MLLimitExceededException("Duplicate load model task"); } @@ -49,6 +50,7 @@ public synchronized void initModelState(String modelId, MLModelState state, Func MLModelCache modelCache = new MLModelCache(); modelCache.setModelState(state); modelCache.setFunctionName(functionName); + modelCache.setTargetWorkerNodes(targetWorkerNodes); modelCaches.put(modelId, modelCache); } @@ -254,6 +256,10 @@ public MLModelProfile getModelProfile(String modelId) { if (modelCache.getPredictor() != null) { builder.predictor(modelCache.getPredictor().toString()); } + String[] targetWorkerNodes = modelCache.getTargetWorkerNodes(); + if (targetWorkerNodes.length > 0) { + builder.targetWorkerNodes(targetWorkerNodes); + } String[] workerNodes = modelCache.getWorkerNodes(); if (workerNodes.length > 0) { builder.workerNodes(workerNodes); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 9b943681a2..fe02a64999 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -430,7 +430,7 @@ public void loadModel( listener.onFailure(new IllegalArgumentException("Exceed max model per node limit")); return; } - modelCacheHelper.initModelState(modelId, MLModelState.LOADING, functionName); + modelCacheHelper.initModelState(modelId, MLModelState.LOADING, functionName, mlTask.getWorkerNodes()); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { checkAndAddRunningTask(mlTask, maxLoadTasksPerNode); this.getModel(modelId, threadedActionListener(LOAD_THREAD_POOL, ActionListener.wrap(mlModel -> { diff --git a/plugin/src/main/java/org/opensearch/ml/profile/MLModelProfile.java b/plugin/src/main/java/org/opensearch/ml/profile/MLModelProfile.java index 80da27d9c5..385a1c91ce 100644 --- a/plugin/src/main/java/org/opensearch/ml/profile/MLModelProfile.java +++ b/plugin/src/main/java/org/opensearch/ml/profile/MLModelProfile.java @@ -24,6 +24,7 @@ public class MLModelProfile implements ToXContentFragment, Writeable { private final MLModelState modelState; private final String predictor; + private final String[] targetWorkerNodes; private final String[] workerNodes; private final MLPredictRequestStats modelInferenceStats; private final MLPredictRequestStats predictRequestStats; @@ -32,12 +33,14 @@ public class MLModelProfile implements ToXContentFragment, Writeable { public MLModelProfile( MLModelState modelState, String predictor, + String[] targetWorkerNodes, String[] workerNodes, MLPredictRequestStats modelInferenceStats, MLPredictRequestStats predictRequestStats ) { this.modelState = modelState; this.predictor = predictor; + this.targetWorkerNodes = targetWorkerNodes; this.workerNodes = workerNodes; this.modelInferenceStats = modelInferenceStats; this.predictRequestStats = predictRequestStats; @@ -52,6 +55,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (predictor != null) { builder.field("predictor", predictor); } + if (targetWorkerNodes != null) { + builder.field("target_worker_nodes", targetWorkerNodes); + } if (workerNodes != null) { builder.field("worker_nodes", workerNodes); } @@ -72,6 +78,7 @@ public MLModelProfile(StreamInput in) throws IOException { this.modelState = null; } this.predictor = in.readOptionalString(); + this.targetWorkerNodes = in.readOptionalStringArray(); this.workerNodes = in.readOptionalStringArray(); if (in.readBoolean()) { this.modelInferenceStats = new MLPredictRequestStats(in); @@ -94,6 +101,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalString(predictor); + out.writeOptionalStringArray(targetWorkerNodes); out.writeOptionalStringArray(workerNodes); if (modelInferenceStats != null) { out.writeBoolean(true); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 1c7039bf96..86caa7b366 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -60,6 +60,8 @@ import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; +import com.google.common.collect.ImmutableList; + /** * MLPredictTaskRunner is responsible for running predict tasks. */ @@ -159,7 +161,7 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); @@ -225,7 +227,8 @@ public void testTransportUploadModelActionDoExecuteWithCreateTaskException() { listener.onFailure(new Exception("Failed to create upload model task")); return null; }).when(mlTaskManager).createMLTask(any(), any()); - + when(node1.getId()).thenReturn("NodeId1"); + when(clusterService.localNode()).thenReturn(node1); transportUploadModelAction.doExecute(task, prepareRequest(), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java index ec2d988077..5aa7b03e40 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java @@ -12,7 +12,9 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT; import static org.opensearch.ml.utils.TestHelper.clusterSetting; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; @@ -47,6 +49,8 @@ public class MLModelCacheHelperTests extends OpenSearchTestCase { private TextEmbeddingModel predictor; private int maxMonitoringRequests; + private List targetWorkerNodes; + @Before public void setup() { MockitoAnnotations.openMocks(this); @@ -61,11 +65,13 @@ public void setup() { modelId = "model_id1"; nodeId = "node_id1"; predictor = spy(new TextEmbeddingModel()); + targetWorkerNodes = new ArrayList<>(); + targetWorkerNodes.add(nodeId); } public void testModelState() { assertFalse(cacheHelper.isModelLoaded(modelId)); - cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes); assertFalse(cacheHelper.isModelLoaded(modelId)); cacheHelper.setModelState(modelId, MLModelState.LOADED); assertTrue(cacheHelper.isModelLoaded(modelId)); @@ -75,8 +81,8 @@ public void testModelState() { public void testModelState_DuplicateError() { expectedEx.expect(MLLimitExceededException.class); expectedEx.expectMessage("Duplicate load model task"); - cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); - cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes); + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes); } public void testPredictor_NotFoundException() { @@ -86,7 +92,7 @@ public void testPredictor_NotFoundException() { } public void testPredictor() { - cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes); assertNull(cacheHelper.getPredictor(modelId)); cacheHelper.setPredictor(modelId, predictor); assertEquals(predictor, cacheHelper.getPredictor(modelId)); @@ -94,7 +100,7 @@ public void testPredictor() { public void testGetAndRemoveModel() { assertFalse(cacheHelper.isModelRunningOnNode(modelId)); - cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes); String[] loadedModels = cacheHelper.getLoadedModels(); assertEquals(0, loadedModels.length); @@ -110,7 +116,7 @@ public void testGetAndRemoveModel() { } public void testRemoveModel_WrongModelId() { - cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes); cacheHelper.removeModel("wrong_model_id"); assertArrayEquals(new String[] { modelId }, cacheHelper.getAllModels()); } @@ -163,7 +169,7 @@ public void testRemoveWorkerNode_ModelState() { } public void testRemoveModel_Loaded() { - cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes); cacheHelper.setModelState(modelId, MLModelState.LOADED); cacheHelper.setPredictor(modelId, predictor); cacheHelper.removeModel(modelId); @@ -179,7 +185,7 @@ public void testClearWorkerNodes_NullModelState() { } public void testClearWorkerNodes_ModelState() { - cacheHelper.initModelState(modelId, MLModelState.LOADED, FunctionName.TEXT_EMBEDDING); + cacheHelper.initModelState(modelId, MLModelState.LOADED, FunctionName.TEXT_EMBEDDING, targetWorkerNodes); cacheHelper.addWorkerNode(modelId, nodeId); cacheHelper.clearWorkerNodes(); assertArrayEquals(new String[] { modelId }, cacheHelper.getAllModels()); @@ -206,7 +212,7 @@ public void testSyncWorkerNodes_NullModelState() { public void testSyncWorkerNodes_ModelState() { String modelId2 = "model_id2"; - cacheHelper.initModelState(modelId2, MLModelState.LOADED, FunctionName.TEXT_EMBEDDING); + cacheHelper.initModelState(modelId2, MLModelState.LOADED, FunctionName.TEXT_EMBEDDING, targetWorkerNodes); cacheHelper.addWorkerNode(modelId, nodeId); cacheHelper.addWorkerNode(modelId2, nodeId); @@ -243,7 +249,7 @@ public void testGetModelProfile_WrongModelId() { } public void testGetModelProfile() { - cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes); cacheHelper.setModelState(modelId, MLModelState.LOADED); cacheHelper.setPredictor(modelId, predictor); cacheHelper.addWorkerNode(modelId, nodeId); @@ -266,7 +272,7 @@ public void testGetModelProfile() { } public void testGetModelProfile_Loading() { - cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING); + cacheHelper.initModelState(modelId, MLModelState.LOADING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes); MLModelProfile modelProfile = cacheHelper.getModelProfile(modelId); assertNotNull(modelProfile); assertEquals(MLModelState.LOADING, modelProfile.getModelState()); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java index 8e2ce9f2fb..76a63441f0 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java @@ -67,6 +67,8 @@ import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import com.google.common.collect.ImmutableList; + public class RestMLProfileActionTests extends OpenSearchTestCase { @Rule public ExpectedException thrown = ExpectedException.none(); @@ -115,7 +117,7 @@ public void setup() throws IOException { .inputType(MLInputDataType.DATA_FRAME) .progress(0.4f) .outputIndex("test_index") - .workerNode("test_node") + .workerNodes(ImmutableList.of("test_node")) .createTime(Instant.ofEpochMilli(123)) .lastUpdateTime(Instant.ofEpochMilli(123)) .error("error")