Skip to content

Commit

Permalink
change task worker node to list; add target worker node to cache
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Dec 31, 2022
1 parent 6fb7970 commit b68034b
Show file tree
Hide file tree
Showing 23 changed files with 107 additions and 47 deletions.
26 changes: 16 additions & 10 deletions common/src/main/java/org/opensearch/ml/common/MLTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.USER;
Expand Down Expand Up @@ -54,7 +56,7 @@ public class MLTask implements ToXContentObject, Writeable {
private Float progress;
private final String outputIndex;
@Setter
private String workerNode;
private List<String> workerNodes;
private final Instant createTime;
private Instant lastUpdateTime;
@Setter
Expand All @@ -72,7 +74,7 @@ public MLTask(
MLInputDataType inputType,
Float progress,
String outputIndex,
String workerNode,
List<String> workerNodes,
Instant createTime,
Instant lastUpdateTime,
String error,
Expand All @@ -87,7 +89,7 @@ public MLTask(
this.inputType = inputType;
this.progress = progress;
this.outputIndex = outputIndex;
this.workerNode = workerNode;
this.workerNodes = workerNodes;
this.createTime = createTime;
this.lastUpdateTime = lastUpdateTime;
this.error = error;
Expand All @@ -108,7 +110,7 @@ public MLTask(StreamInput input) throws IOException {
}
this.progress = input.readOptionalFloat();
this.outputIndex = input.readOptionalString();
this.workerNode = input.readString();
this.workerNodes = input.readStringList();
this.createTime = input.readInstant();
this.lastUpdateTime = input.readInstant();
this.error = input.readOptionalString();
Expand All @@ -135,7 +137,7 @@ public void writeTo(StreamOutput out) throws IOException {
}
out.writeOptionalFloat(progress);
out.writeOptionalString(outputIndex);
out.writeString(workerNode);
out.writeStringCollection(workerNodes);
out.writeInstant(createTime);
out.writeInstant(lastUpdateTime);
out.writeOptionalString(error);
Expand Down Expand Up @@ -174,8 +176,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
if (outputIndex != null) {
builder.field(OUTPUT_INDEX_FIELD, outputIndex);
}
if (workerNode != null) {
builder.field(WORKER_NODE_FIELD, workerNode);
if (workerNodes != null) {
builder.field(WORKER_NODE_FIELD, workerNodes);
}
if (createTime != null) {
builder.field(CREATE_TIME_FIELD, createTime.toEpochMilli());
Expand Down Expand Up @@ -207,7 +209,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
MLInputDataType inputType = null;
Float progress = null;
String outputIndex = null;
String workerNode = null;
List<String> workerNodes = null;
Instant createTime = null;
Instant lastUpdateTime = null;
String error = null;
Expand Down Expand Up @@ -245,7 +247,11 @@ public static MLTask parse(XContentParser parser) throws IOException {
outputIndex = parser.text();
break;
case WORKER_NODE_FIELD:
workerNode = parser.text();
workerNodes = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
workerNodes.add(parser.text());
}
break;
case CREATE_TIME_FIELD:
createTime = Instant.ofEpochMilli(parser.longValue());
Expand Down Expand Up @@ -276,7 +282,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
.inputType(inputType)
.progress(progress)
.outputIndex(outputIndex)
.workerNode(workerNode)
.workerNodes(workerNodes)
.createTime(createTime)
.lastUpdateTime(lastUpdateTime)
.error(error)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;

import org.junit.Assert;
import org.junit.Before;
Expand All @@ -32,7 +33,7 @@ public void setup() {
.functionName(FunctionName.KMEANS)
.state(MLTaskState.RUNNING)
.inputType(MLInputDataType.DATA_FRAME)
.workerNode("node1")
.workerNodes(Arrays.asList("node1"))
.progress(0.0f)
.outputIndex("test_index")
.error("test_error")
Expand All @@ -57,7 +58,7 @@ public void toXContent() throws IOException {
Assert.assertEquals(
"{\"task_id\":\"dummy taskId\",\"model_id\":\"test_model_id\",\"task_type\":\"PREDICTION\","
+ "\"function_name\":\"KMEANS\",\"state\":\"RUNNING\",\"input_type\":\"DATA_FRAME\",\"progress\":0.0,"
+ "\"output_index\":\"test_index\",\"worker_node\":\"node1\",\"create_time\":1641599940000,"
+ "\"output_index\":\"test_index\",\"worker_node\":[\"node1\"],\"create_time\":1641599940000,"
+ "\"last_update_time\":1641600000000,\"error\":\"test_error\",\"is_async\":false}",
taskContent
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.function.Consumer;
Expand All @@ -48,7 +49,7 @@ public void setUp() throws Exception {
.functionName(functionName)
.state(MLTaskState.RUNNING)
.inputType(MLInputDataType.DATA_FRAME)
.workerNode("mlTaskNode1")
.workerNodes(Arrays.asList("mlTaskNode1"))
.progress(0.0f)
.outputIndex("test_index")
.error("test_error")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.io.UncheckedIOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;

Expand All @@ -52,7 +53,7 @@ public void setUp() throws Exception {
.functionName(functionName)
.state(MLTaskState.RUNNING)
.inputType(MLInputDataType.DATA_FRAME)
.workerNode("mlTaskNode1")
.workerNodes(Arrays.asList("mlTaskNode1"))
.progress(0.0f)
.outputIndex("test_index")
.error("test_error")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;

Expand All @@ -44,7 +45,7 @@ public void setUp() throws Exception {
.functionName(FunctionName.LINEAR_REGRESSION)
.state(MLTaskState.RUNNING)
.inputType(MLInputDataType.DATA_FRAME)
.workerNode("node1")
.workerNodes(Arrays.asList("node1"))
.progress(0.0f)
.outputIndex("test_index")
.error("test_error")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.net.InetAddress;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Collections;

import static org.junit.Assert.*;
Expand Down Expand Up @@ -70,7 +71,7 @@ public void setUp() throws Exception {
.functionName(FunctionName.LINEAR_REGRESSION)
.state(MLTaskState.RUNNING)
.inputType(MLInputDataType.DATA_FRAME)
.workerNode("node1")
.workerNodes(Arrays.asList("node1"))
.progress(0.0f)
.outputIndex("test_index")
.error("test_error")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;

import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
Expand All @@ -35,7 +36,7 @@ public void setUp() {
.inputType(MLInputDataType.DATA_FRAME)
.progress(1.3f)
.outputIndex("some index")
.workerNode("some node")
.workerNodes(Arrays.asList("some node"))
.createTime(Instant.ofEpochMilli(123))
.lastUpdateTime(Instant.ofEpochMilli(123))
.error("error")
Expand All @@ -59,7 +60,7 @@ public void writeTo_Success() throws IOException {
assertEquals(response.mlTask.getInputType(), parsedResponse.mlTask.getInputType());
assertEquals(response.mlTask.getProgress(), parsedResponse.mlTask.getProgress());
assertEquals(response.mlTask.getOutputIndex(), parsedResponse.mlTask.getOutputIndex());
assertEquals(response.mlTask.getWorkerNode(), parsedResponse.mlTask.getWorkerNode());
assertEquals(response.mlTask.getWorkerNodes(), parsedResponse.mlTask.getWorkerNodes());
assertEquals(response.mlTask.getCreateTime(), parsedResponse.mlTask.getCreateTime());
assertEquals(response.mlTask.getLastUpdateTime(), parsedResponse.mlTask.getLastUpdateTime());
assertEquals(response.mlTask.getError(), parsedResponse.mlTask.getError());
Expand All @@ -80,7 +81,7 @@ public void toXContentTest() throws IOException {
"\"input_type\":\"DATA_FRAME\"," +
"\"progress\":1.3," +
"\"output_index\":\"some index\"," +
"\"worker_node\":\"some node\"," +
"\"worker_node\":[\"some node\"]," +
"\"create_time\":123," +
"\"last_update_time\":123," +
"\"error\":\"error\"," +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<LoadMo
return;
}

String workerNodes = String.join(",", nodeIds);
log.warn("Will load model on these nodes: {}", workerNodes);
log.info("Will load model on these nodes: {}", String.join(",", nodeIds));
String localNodeId = clusterService.localNode().getId();

String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };
Expand All @@ -156,7 +155,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<LoadMo
.createTime(Instant.now())
.lastUpdateTime(Instant.now())
.state(MLTaskState.CREATED)
.workerNode(workerNodes)
.workerNodes(nodeIds)
.build();
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

@Log4j2
Expand Down Expand Up @@ -127,12 +128,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Upload
.createTime(Instant.now())
.lastUpdateTime(Instant.now())
.state(MLTaskState.CREATED)
.workerNode(clusterService.localNode().getId())
.workerNodes(ImmutableList.of(clusterService.localNode().getId()))
.build();

mlTaskDispatcher.dispatch(ActionListener.wrap(node -> {
String nodeId = node.getId();
mlTask.setWorkerNode(nodeId);
mlTask.setWorkerNodes(ImmutableList.of(nodeId));

mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
Expand Down
15 changes: 15 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.model;

import java.util.DoubleSummaryStatistics;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -29,16 +30,30 @@ public class MLModelCache {
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) MLModelState modelState;
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) FunctionName functionName;
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Predictable predictor;
private @Getter(AccessLevel.PROTECTED) Set<String> targetWorkerNodes;
private final Set<String> workerNodes;
private final Queue<Double> modelInferenceDurationQueue;
private final Queue<Double> predictRequestDurationQueue;

public MLModelCache() {
targetWorkerNodes = ConcurrentHashMap.newKeySet();
workerNodes = ConcurrentHashMap.newKeySet();
modelInferenceDurationQueue = new ConcurrentLinkedQueue<>();
predictRequestDurationQueue = new ConcurrentLinkedQueue<>();
}

public void setTargetWorkerNodes(List<String> targetWorkerNodes) {
if (targetWorkerNodes == null || targetWorkerNodes.size() == 0) {
throw new IllegalArgumentException("Null or empty target worker nodes");
}
this.targetWorkerNodes.clear();
this.targetWorkerNodes.addAll(targetWorkerNodes);
}

public String[] getTargetWorkerNodes() {
return targetWorkerNodes.toArray(new String[0]);
}

public void removeWorkerNode(String nodeId) {
workerNodes.remove(nodeId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -41,14 +42,15 @@ public MLModelCacheHelper(ClusterService clusterService, Settings settings) {
* @param state model state
* @param functionName function name
*/
public synchronized void initModelState(String modelId, MLModelState state, FunctionName functionName) {
public synchronized void initModelState(String modelId, MLModelState state, FunctionName functionName, List<String> targetWorkerNodes) {
if (isModelRunningOnNode(modelId)) {
throw new MLLimitExceededException("Duplicate load model task");
}
log.debug("init model state for model {}, state: {}", modelId, state);
MLModelCache modelCache = new MLModelCache();
modelCache.setModelState(state);
modelCache.setFunctionName(functionName);
modelCache.setTargetWorkerNodes(targetWorkerNodes);
modelCaches.put(modelId, modelCache);
}

Expand Down Expand Up @@ -254,6 +256,10 @@ public MLModelProfile getModelProfile(String modelId) {
if (modelCache.getPredictor() != null) {
builder.predictor(modelCache.getPredictor().toString());
}
String[] targetWorkerNodes = modelCache.getTargetWorkerNodes();
if (targetWorkerNodes.length > 0) {
builder.targetWorkerNodes(targetWorkerNodes);
}
String[] workerNodes = modelCache.getWorkerNodes();
if (workerNodes.length > 0) {
builder.workerNodes(workerNodes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ public void loadModel(
listener.onFailure(new IllegalArgumentException("Exceed max model per node limit"));
return;
}
modelCacheHelper.initModelState(modelId, MLModelState.LOADING, functionName);
modelCacheHelper.initModelState(modelId, MLModelState.LOADING, functionName, mlTask.getWorkerNodes());
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
checkAndAddRunningTask(mlTask, maxLoadTasksPerNode);
this.getModel(modelId, threadedActionListener(LOAD_THREAD_POOL, ActionListener.wrap(mlModel -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class MLModelProfile implements ToXContentFragment, Writeable {

private final MLModelState modelState;
private final String predictor;
private final String[] targetWorkerNodes;
private final String[] workerNodes;
private final MLPredictRequestStats modelInferenceStats;
private final MLPredictRequestStats predictRequestStats;
Expand All @@ -32,12 +33,14 @@ public class MLModelProfile implements ToXContentFragment, Writeable {
public MLModelProfile(
MLModelState modelState,
String predictor,
String[] targetWorkerNodes,
String[] workerNodes,
MLPredictRequestStats modelInferenceStats,
MLPredictRequestStats predictRequestStats
) {
this.modelState = modelState;
this.predictor = predictor;
this.targetWorkerNodes = targetWorkerNodes;
this.workerNodes = workerNodes;
this.modelInferenceStats = modelInferenceStats;
this.predictRequestStats = predictRequestStats;
Expand All @@ -52,6 +55,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (predictor != null) {
builder.field("predictor", predictor);
}
if (targetWorkerNodes != null) {
builder.field("worker_nodes", targetWorkerNodes);
}
if (workerNodes != null) {
builder.field("worker_nodes", workerNodes);
}
Expand All @@ -72,6 +78,7 @@ public MLModelProfile(StreamInput in) throws IOException {
this.modelState = null;
}
this.predictor = in.readOptionalString();
this.targetWorkerNodes = in.readOptionalStringArray();
this.workerNodes = in.readOptionalStringArray();
if (in.readBoolean()) {
this.modelInferenceStats = new MLPredictRequestStats(in);
Expand All @@ -94,6 +101,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalString(predictor);
out.writeOptionalStringArray(targetWorkerNodes);
out.writeOptionalStringArray(workerNodes);
if (modelInferenceStats != null) {
out.writeBoolean(true);
Expand Down
Loading

0 comments on commit b68034b

Please sign in to comment.