Skip to content

Commit

Permalink
add planning work nodes to model (opensearch-project#715)
Browse files Browse the repository at this point in the history
* add planning work nodes to model

Signed-off-by: Yaliang Wu <[email protected]>

* add test

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Mar 4, 2023
1 parent 4d9c571 commit 9a657fa
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ public class CommonValue {
+ MLModel.CURRENT_WORKER_NODE_COUNT_FIELD
+ "\" : {\"type\": \"integer\"},\n"
+ " \""
+ MLModel.PLANNING_WORKER_NODES_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ MLModel.MODEL_CONFIG_FIELD
+ "\" : {\"properties\":{\""
+ MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\""
Expand Down
21 changes: 20 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.USER;
Expand Down Expand Up @@ -54,6 +56,7 @@ public class MLModel implements ToXContentObject {
public static final String TOTAL_CHUNKS_FIELD = "total_chunks";
public static final String PLANNING_WORKER_NODE_COUNT_FIELD = "planning_worker_node_count";
public static final String CURRENT_WORKER_NODE_COUNT_FIELD = "current_worker_node_count";
public static final String PLANNING_WORKER_NODES_FIELD = "planning_worker_nodes";

private String name;
private FunctionName algorithm;
Expand Down Expand Up @@ -81,6 +84,7 @@ public class MLModel implements ToXContentObject {
private Integer planningWorkerNodeCount; // plan to deploy model to how many nodes
private Integer currentWorkerNodeCount; // model is deployed to how many nodes

private String[] planningWorkerNodes; // plan to deploy model to these nodes
@Builder(toBuilder = true)
public MLModel(String name,
FunctionName algorithm,
Expand All @@ -101,7 +105,8 @@ public MLModel(String name,
String modelId, Integer chunkNumber,
Integer totalChunks,
Integer planningWorkerNodeCount,
Integer currentWorkerNodeCount) {
Integer currentWorkerNodeCount,
String[] planningWorkerNodes) {
this.name = name;
this.algorithm = algorithm;
this.version = version;
Expand All @@ -123,6 +128,7 @@ public MLModel(String name,
this.totalChunks = totalChunks;
this.planningWorkerNodeCount = planningWorkerNodeCount;
this.currentWorkerNodeCount = currentWorkerNodeCount;
this.planningWorkerNodes = planningWorkerNodes;
}

public MLModel(StreamInput input) throws IOException{
Expand Down Expand Up @@ -158,6 +164,7 @@ public MLModel(StreamInput input) throws IOException{
totalChunks = input.readOptionalInt();
planningWorkerNodeCount = input.readOptionalInt();
currentWorkerNodeCount = input.readOptionalInt();
planningWorkerNodes = input.readOptionalStringArray();
}
}

Expand Down Expand Up @@ -203,6 +210,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInt(totalChunks);
out.writeOptionalInt(planningWorkerNodeCount);
out.writeOptionalInt(currentWorkerNodeCount);
out.writeOptionalStringArray(planningWorkerNodes);
}

@Override
Expand Down Expand Up @@ -271,6 +279,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (currentWorkerNodeCount != null) {
builder.field(CURRENT_WORKER_NODE_COUNT_FIELD, currentWorkerNodeCount);
}
if (planningWorkerNodes != null && planningWorkerNodes.length > 0) {
builder.field(PLANNING_WORKER_NODES_FIELD, planningWorkerNodes);
}
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -300,6 +311,7 @@ public static MLModel parse(XContentParser parser) throws IOException {
Integer totalChunks = null;
Integer planningWorkerNodeCount = null;
Integer currentWorkerNodeCount = null;
List<String> planningWorkerNodes = new ArrayList<>();

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -361,6 +373,12 @@ public static MLModel parse(XContentParser parser) throws IOException {
case CURRENT_WORKER_NODE_COUNT_FIELD:
currentWorkerNodeCount = parser.intValue();
break;
case PLANNING_WORKER_NODES_FIELD:
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
planningWorkerNodes.add(parser.text());
}
break;
case CREATED_TIME_FIELD:
createdTime = Instant.ofEpochMilli(parser.longValue());
break;
Expand Down Expand Up @@ -403,6 +421,7 @@ public static MLModel parse(XContentParser parser) throws IOException {
.totalChunks(totalChunks)
.planningWorkerNodeCount(planningWorkerNodeCount)
.currentWorkerNodeCount(currentWorkerNodeCount)
.planningWorkerNodes(planningWorkerNodes.toArray(new String[0]))
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,19 @@ void updateModelLoadStatusAndTriggerOnNodesAction(
mlModelManager.updateModel(modelId, ImmutableMap.of(MLModel.MODEL_STATE_FIELD, MLModelState.LOAD_FAILED));
});

List<String> workerNodes = eligibleNodes.stream().map(n -> n.getId()).collect(Collectors.toList());
mlModelManager
.updateModel(
modelId,
ImmutableMap
.of(MLModel.MODEL_STATE_FIELD, MLModelState.LOADING, MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, eligibleNodes.size()),
.of(
MLModel.MODEL_STATE_FIELD,
MLModelState.LOADING,
MLModel.PLANNING_WORKER_NODE_COUNT_FIELD,
eligibleNodes.size(),
MLModel.PLANNING_WORKER_NODES_FIELD,
workerNodes
),
ActionListener
.wrap(
r -> client.execute(MLLoadModelOnNodeAction.INSTANCE, loadModelRequest, actionListener),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@

import java.lang.reflect.Field;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Mockito;
Expand Down Expand Up @@ -231,17 +234,25 @@ public void testUpdateModelLoadStatusAndTriggerOnNodesAction_success() throws No

when(mlTaskManager.contains(anyString())).thenReturn(true);

DiscoveryNode discoveryNode = mock(DiscoveryNode.class);
when(discoveryNode.getId()).thenReturn("node1");
transportLoadModelAction
.updateModelLoadStatusAndTriggerOnNodesAction(
modelId,
"mock_task_id",
mlModel,
localNodeId,
mlTask,
eligibleNodes,
Arrays.asList(discoveryNode),
FunctionName.ANOMALY_LOCALIZATION
);
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());

ArgumentCaptor<Map<String, Object>> captor = ArgumentCaptor.forClass(Map.class);
verify(mlModelManager).updateModel(anyString(), captor.capture(), any());
Map<String, Object> map = captor.getValue();
assertNotNull(map.get(MLModel.PLANNING_WORKER_NODES_FIELD));
assertEquals(1, (((List) map.get(MLModel.PLANNING_WORKER_NODES_FIELD)).size()));
}

public void testUpdateModelLoadStatusAndTriggerOnNodesAction_whenMLTaskManagerThrowException_ListenerOnFailureExecuted() {
Expand Down

0 comments on commit 9a657fa

Please sign in to comment.