From 845346b49a7c1c45cf1721a013b31329ac2792d6 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 31 Jan 2023 12:09:47 -0800 Subject: [PATCH] add planning work nodes to model (#715) * add planning work nodes to model Signed-off-by: Yaliang Wu * add test Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu --- .../org/opensearch/ml/common/CommonValue.java | 3 +++ .../org/opensearch/ml/common/MLModel.java | 21 ++++++++++++++++++- .../action/load/TransportLoadModelAction.java | 10 ++++++++- .../load/TransportLoadModelActionTests.java | 13 +++++++++++- 4 files changed, 44 insertions(+), 3 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 40c9b89e0b..9b4a35e8ae 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -91,6 +91,9 @@ public class CommonValue { + MLModel.CURRENT_WORKER_NODE_COUNT_FIELD + "\" : {\"type\": \"integer\"},\n" + " \"" + + MLModel.PLANNING_WORKER_NODES_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + MLModel.MODEL_CONFIG_FIELD + "\" : {\"properties\":{\"" + MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index 46519cfba9..2eacdc80dc 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -21,6 +21,8 @@ import java.io.IOException; import java.time.Instant; +import java.util.ArrayList; +import java.util.List; import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.USER; @@ -54,6 +56,7 @@ public class MLModel implements ToXContentObject { public static final String TOTAL_CHUNKS_FIELD = "total_chunks"; public static final String PLANNING_WORKER_NODE_COUNT_FIELD = "planning_worker_node_count"; public static final String CURRENT_WORKER_NODE_COUNT_FIELD = "current_worker_node_count"; + public static final String PLANNING_WORKER_NODES_FIELD = "planning_worker_nodes"; private String name; private FunctionName algorithm; @@ -81,6 +84,7 @@ public class MLModel implements ToXContentObject { private Integer planningWorkerNodeCount; // plan to deploy model to how many nodes private Integer currentWorkerNodeCount; // model is deployed to how many nodes + private String[] planningWorkerNodes; // plan to deploy model to these nodes @Builder(toBuilder = true) public MLModel(String name, FunctionName algorithm, @@ -101,7 +105,8 @@ public MLModel(String name, String modelId, Integer chunkNumber, Integer totalChunks, Integer planningWorkerNodeCount, - Integer currentWorkerNodeCount) { + Integer currentWorkerNodeCount, + String[] planningWorkerNodes) { this.name = name; this.algorithm = algorithm; this.version = version; @@ -123,6 +128,7 @@ public MLModel(String name, this.totalChunks = totalChunks; this.planningWorkerNodeCount = planningWorkerNodeCount; this.currentWorkerNodeCount = currentWorkerNodeCount; + this.planningWorkerNodes = planningWorkerNodes; } public MLModel(StreamInput input) throws IOException{ @@ -158,6 +164,7 @@ public MLModel(StreamInput input) throws IOException{ totalChunks = input.readOptionalInt(); planningWorkerNodeCount = input.readOptionalInt(); currentWorkerNodeCount = input.readOptionalInt(); + planningWorkerNodes = input.readOptionalStringArray(); } } @@ -203,6 +210,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalInt(totalChunks); out.writeOptionalInt(planningWorkerNodeCount); out.writeOptionalInt(currentWorkerNodeCount); + out.writeOptionalStringArray(planningWorkerNodes); } @Override @@ -271,6 +279,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (currentWorkerNodeCount != null) { builder.field(CURRENT_WORKER_NODE_COUNT_FIELD, currentWorkerNodeCount); } + if (planningWorkerNodes != null && planningWorkerNodes.length > 0) { + builder.field(PLANNING_WORKER_NODES_FIELD, planningWorkerNodes); + } builder.endObject(); return builder; } @@ -300,6 +311,7 @@ public static MLModel parse(XContentParser parser) throws IOException { Integer totalChunks = null; Integer planningWorkerNodeCount = null; Integer currentWorkerNodeCount = null; + List planningWorkerNodes = new ArrayList<>(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -361,6 +373,12 @@ public static MLModel parse(XContentParser parser) throws IOException { case CURRENT_WORKER_NODE_COUNT_FIELD: currentWorkerNodeCount = parser.intValue(); break; + case PLANNING_WORKER_NODES_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + planningWorkerNodes.add(parser.text()); + } + break; case CREATED_TIME_FIELD: createdTime = Instant.ofEpochMilli(parser.longValue()); break; @@ -403,6 +421,7 @@ public static MLModel parse(XContentParser parser) throws IOException { .totalChunks(totalChunks) .planningWorkerNodeCount(planningWorkerNodeCount) .currentWorkerNodeCount(currentWorkerNodeCount) + .planningWorkerNodes(planningWorkerNodes.toArray(new String[0])) .build(); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java index 53c9e81780..9fdecd612a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java @@ -254,11 +254,19 @@ void updateModelLoadStatusAndTriggerOnNodesAction( mlModelManager.updateModel(modelId, ImmutableMap.of(MLModel.MODEL_STATE_FIELD, MLModelState.LOAD_FAILED)); }); + List workerNodes = eligibleNodes.stream().map(n -> n.getId()).collect(Collectors.toList()); mlModelManager .updateModel( modelId, ImmutableMap - .of(MLModel.MODEL_STATE_FIELD, MLModelState.LOADING, MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, eligibleNodes.size()), + .of( + MLModel.MODEL_STATE_FIELD, + MLModelState.LOADING, + MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, + eligibleNodes.size(), + MLModel.PLANNING_WORKER_NODES_FIELD, + workerNodes + ), ActionListener .wrap( r -> client.execute(MLLoadModelOnNodeAction.INSTANCE, loadModelRequest, actionListener), diff --git a/plugin/src/test/java/org/opensearch/ml/action/load/TransportLoadModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/load/TransportLoadModelActionTests.java index 033080c436..4f3af4b55d 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/load/TransportLoadModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/load/TransportLoadModelActionTests.java @@ -10,10 +10,13 @@ import java.lang.reflect.Field; import java.nio.file.Path; +import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.concurrent.ExecutorService; import org.junit.Before; +import org.mockito.ArgumentCaptor; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.Mockito; @@ -231,6 +234,8 @@ public void testUpdateModelLoadStatusAndTriggerOnNodesAction_success() throws No when(mlTaskManager.contains(anyString())).thenReturn(true); + DiscoveryNode discoveryNode = mock(DiscoveryNode.class); + when(discoveryNode.getId()).thenReturn("node1"); transportLoadModelAction .updateModelLoadStatusAndTriggerOnNodesAction( modelId, @@ -238,10 +243,16 @@ public void testUpdateModelLoadStatusAndTriggerOnNodesAction_success() throws No mlModel, localNodeId, mlTask, - eligibleNodes, + Arrays.asList(discoveryNode), FunctionName.ANOMALY_LOCALIZATION ); verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + verify(mlModelManager).updateModel(anyString(), captor.capture(), any()); + Map map = captor.getValue(); + assertNotNull(map.get(MLModel.PLANNING_WORKER_NODES_FIELD)); + assertEquals(1, (((List) map.get(MLModel.PLANNING_WORKER_NODES_FIELD)).size())); } public void testUpdateModelLoadStatusAndTriggerOnNodesAction_whenMLTaskManagerThrowException_ListenerOnFailureExecuted() {