Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add planning work nodes to model #715

Merged
merged 2 commits into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ public class CommonValue {
+ MLModel.CURRENT_WORKER_NODE_COUNT_FIELD
+ "\" : {\"type\": \"integer\"},\n"
+ " \""
+ MLModel.PLANNING_WORKER_NODES_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ MLModel.MODEL_CONFIG_FIELD
+ "\" : {\"properties\":{\""
+ MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\""
Expand Down
21 changes: 20 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.USER;
Expand Down Expand Up @@ -54,6 +56,7 @@ public class MLModel implements ToXContentObject {
public static final String TOTAL_CHUNKS_FIELD = "total_chunks";
public static final String PLANNING_WORKER_NODE_COUNT_FIELD = "planning_worker_node_count";
public static final String CURRENT_WORKER_NODE_COUNT_FIELD = "current_worker_node_count";
public static final String PLANNING_WORKER_NODES_FIELD = "planning_worker_nodes";

private String name;
private FunctionName algorithm;
Expand Down Expand Up @@ -81,6 +84,7 @@ public class MLModel implements ToXContentObject {
private Integer planningWorkerNodeCount; // plan to deploy model to how many nodes
private Integer currentWorkerNodeCount; // model is deployed to how many nodes

private String[] planningWorkerNodes; // plan to deploy model to these nodes
@Builder(toBuilder = true)
public MLModel(String name,
FunctionName algorithm,
Expand All @@ -101,7 +105,8 @@ public MLModel(String name,
String modelId, Integer chunkNumber,
Integer totalChunks,
Integer planningWorkerNodeCount,
Integer currentWorkerNodeCount) {
Integer currentWorkerNodeCount,
String[] planningWorkerNodes) {
this.name = name;
this.algorithm = algorithm;
this.version = version;
Expand All @@ -123,6 +128,7 @@ public MLModel(String name,
this.totalChunks = totalChunks;
this.planningWorkerNodeCount = planningWorkerNodeCount;
this.currentWorkerNodeCount = currentWorkerNodeCount;
this.planningWorkerNodes = planningWorkerNodes;
}

public MLModel(StreamInput input) throws IOException{
Expand Down Expand Up @@ -158,6 +164,7 @@ public MLModel(StreamInput input) throws IOException{
totalChunks = input.readOptionalInt();
planningWorkerNodeCount = input.readOptionalInt();
currentWorkerNodeCount = input.readOptionalInt();
planningWorkerNodes = input.readOptionalStringArray();
}
}

Expand Down Expand Up @@ -203,6 +210,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInt(totalChunks);
out.writeOptionalInt(planningWorkerNodeCount);
out.writeOptionalInt(currentWorkerNodeCount);
out.writeOptionalStringArray(planningWorkerNodes);
}

@Override
Expand Down Expand Up @@ -271,6 +279,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (currentWorkerNodeCount != null) {
builder.field(CURRENT_WORKER_NODE_COUNT_FIELD, currentWorkerNodeCount);
}
if (planningWorkerNodes != null && planningWorkerNodes.length > 0) {
builder.field(PLANNING_WORKER_NODES_FIELD, planningWorkerNodes);
}
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -300,6 +311,7 @@ public static MLModel parse(XContentParser parser) throws IOException {
Integer totalChunks = null;
Integer planningWorkerNodeCount = null;
Integer currentWorkerNodeCount = null;
List<String> planningWorkerNodes = new ArrayList<>();

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -361,6 +373,12 @@ public static MLModel parse(XContentParser parser) throws IOException {
case CURRENT_WORKER_NODE_COUNT_FIELD:
currentWorkerNodeCount = parser.intValue();
break;
case PLANNING_WORKER_NODES_FIELD:
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
planningWorkerNodes.add(parser.text());
}
break;
case CREATED_TIME_FIELD:
createdTime = Instant.ofEpochMilli(parser.longValue());
break;
Expand Down Expand Up @@ -403,6 +421,7 @@ public static MLModel parse(XContentParser parser) throws IOException {
.totalChunks(totalChunks)
.planningWorkerNodeCount(planningWorkerNodeCount)
.currentWorkerNodeCount(currentWorkerNodeCount)
.planningWorkerNodes(planningWorkerNodes.toArray(new String[0]))
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,19 @@ void updateModelLoadStatusAndTriggerOnNodesAction(
mlModelManager.updateModel(modelId, ImmutableMap.of(MLModel.MODEL_STATE_FIELD, MLModelState.LOAD_FAILED));
});

List<String> workerNodes = eligibleNodes.stream().map(n -> n.getId()).collect(Collectors.toList());
mlModelManager
.updateModel(
modelId,
ImmutableMap
.of(MLModel.MODEL_STATE_FIELD, MLModelState.LOADING, MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, eligibleNodes.size()),
.of(
MLModel.MODEL_STATE_FIELD,
MLModelState.LOADING,
MLModel.PLANNING_WORKER_NODE_COUNT_FIELD,
eligibleNodes.size(),
MLModel.PLANNING_WORKER_NODES_FIELD,
workerNodes
),
ActionListener
.wrap(
r -> client.execute(MLLoadModelOnNodeAction.INSTANCE, loadModelRequest, actionListener),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@

import java.lang.reflect.Field;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Mockito;
Expand Down Expand Up @@ -231,17 +234,25 @@ public void testUpdateModelLoadStatusAndTriggerOnNodesAction_success() throws No

when(mlTaskManager.contains(anyString())).thenReturn(true);

DiscoveryNode discoveryNode = mock(DiscoveryNode.class);
when(discoveryNode.getId()).thenReturn("node1");
transportLoadModelAction
.updateModelLoadStatusAndTriggerOnNodesAction(
modelId,
"mock_task_id",
mlModel,
localNodeId,
mlTask,
eligibleNodes,
Arrays.asList(discoveryNode),
FunctionName.ANOMALY_LOCALIZATION
);
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());

ArgumentCaptor<Map<String, Object>> captor = ArgumentCaptor.forClass(Map.class);
verify(mlModelManager).updateModel(anyString(), captor.capture(), any());
Map<String, Object> map = captor.getValue();
assertNotNull(map.get(MLModel.PLANNING_WORKER_NODES_FIELD));
assertEquals(1, (((List) map.get(MLModel.PLANNING_WORKER_NODES_FIELD)).size()));
}

public void testUpdateModelLoadStatusAndTriggerOnNodesAction_whenMLTaskManagerThrowException_ListenerOnFailureExecuted() {
Expand Down