Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

include deployment status in deploy API response #1336

Merged
merged 1 commit into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,47 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLTaskType;

import java.io.IOException;

@Getter
public class MLDeployModelResponse extends ActionResponse implements ToXContentObject {
public static final String TASK_ID_FIELD = "task_id";
public static final String TASK_TYPE_FIELD = "task_type";
public static final String STATUS_FIELD = "status";

private String taskId;
private MLTaskType taskType;
private String status;

public MLDeployModelResponse(StreamInput in) throws IOException {
super(in);
this.taskId = in.readString();
this.taskType = in.readEnum(MLTaskType.class);
this.status = in.readString();
}

public MLDeployModelResponse(String taskId, String status) {
public MLDeployModelResponse(String taskId, MLTaskType mlTaskType, String status) {
this.taskId = taskId;
this.taskType = mlTaskType;
this.status= status;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(taskId);
out.writeEnum(taskType);
out.writeString(status);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(TASK_ID_FIELD, taskId);
if (taskType != null) {
builder.field(TASK_TYPE_FIELD, taskType);
}
builder.field(STATUS_FIELD, status);
builder.endObject();
return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLTaskType;

import java.io.IOException;

Expand All @@ -16,37 +17,40 @@ public class MLDeployModelResponseTest {

private String taskId;
private String status;
private MLTaskType taskType;

@Before
public void setUp() throws Exception {
taskId = "test_id";
status = "test";
taskType = MLTaskType.DEPLOY_MODEL;
}

@Test
public void writeTo_Success() throws IOException {
// Setup
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
MLDeployModelResponse response = new MLDeployModelResponse(taskId, status);
MLDeployModelResponse response = new MLDeployModelResponse(taskId, taskType, status);
// Run the test
response.writeTo(bytesStreamOutput);
MLDeployModelResponse parsedResponse = new MLDeployModelResponse(bytesStreamOutput.bytes().streamInput());
// Verify the results
assertEquals(response.getTaskId(), parsedResponse.getTaskId());
assertEquals(response.getTaskType(), parsedResponse.getTaskType());
assertEquals(response.getStatus(), parsedResponse.getStatus());
}

@Test
public void testToXContent() throws IOException {
// Setup
MLDeployModelResponse response = new MLDeployModelResponse(taskId, status);
MLDeployModelResponse response = new MLDeployModelResponse(taskId, taskType, status);
// Run the test
XContentBuilder builder = XContentFactory.jsonBuilder();
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();
// Verify the results
assertEquals("{\"task_id\":\"test_id\"," +
assertEquals("{\"task_id\":\"test_id\"," + "\"task_type\":\"DEPLOY_MODEL\"," +
"\"status\":\"test\"}", jsonStr);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,14 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
mlTask.setTaskId(taskId);
if (algorithm == FunctionName.REMOTE) {
mlTaskManager.add(mlTask, nodeIds);
deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener);
return;
}
try {
mlTaskManager.add(mlTask, nodeIds);
listener.onResponse(new MLDeployModelResponse(taskId, MLTaskState.CREATED.name()));
listener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name()));
threadPool
.executor(DEPLOY_THREAD_POOL)
.execute(
Expand Down Expand Up @@ -260,6 +265,82 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl

}

@VisibleForTesting
void deployRemoteModel(
MLModel mlModel,
MLTask mlTask,
String localNodeId,
List<DiscoveryNode> eligibleNodes,
boolean deployToAllNodes,
ActionListener<MLDeployModelResponse> listener
) {
MLDeployModelInput deployModelInput = new MLDeployModelInput(
mlModel.getModelId(),
mlTask.getTaskId(),
mlModel.getModelContentHash(),
eligibleNodes.size(),
localNodeId,
deployToAllNodes,
mlTask
);

MLDeployModelNodesRequest deployModelRequest = new MLDeployModelNodesRequest(
eligibleNodes.toArray(new DiscoveryNode[0]),
deployModelInput
);

ActionListener<MLDeployModelNodesResponse> actionListener = deployModelNodesResponseListener(
mlTask.getTaskId(),
mlModel.getModelId(),
listener
);
List<String> workerNodes = eligibleNodes.stream().map(n -> n.getId()).collect(Collectors.toList());
mlModelManager
.updateModel(
mlModel.getModelId(),
ImmutableMap
.of(
MLModel.MODEL_STATE_FIELD,
MLModelState.DEPLOYING,
MLModel.PLANNING_WORKER_NODE_COUNT_FIELD,
eligibleNodes.size(),
MLModel.PLANNING_WORKER_NODES_FIELD,
workerNodes,
MLModel.DEPLOY_TO_ALL_NODES_FIELD,
deployToAllNodes
),
ActionListener
.wrap(
r -> client.execute(MLDeployModelOnNodeAction.INSTANCE, deployModelRequest, actionListener),
actionListener::onFailure
)
);
}

private ActionListener<MLDeployModelNodesResponse> deployModelNodesResponseListener(
String taskId,
String modelId,
ActionListener<MLDeployModelResponse> listener
) {
return ActionListener.wrap(r -> {
if (mlTaskManager.contains(taskId)) {
mlTaskManager.updateMLTask(taskId, ImmutableMap.of(STATE_FIELD, MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false);
}
listener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.COMPLETED.name()));
}, e -> {
log.error("Failed to deploy model " + modelId, e);
mlTaskManager
.updateMLTask(
taskId,
ImmutableMap.of(MLTask.ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED),
TASK_SEMAPHORE_TIMEOUT,
true
);
mlModelManager.updateModel(modelId, ImmutableMap.of(MLModel.MODEL_STATE_FIELD, MLModelState.DEPLOY_FAILED));
listener.onFailure(e);
});
}

@VisibleForTesting
void updateModelDeployStatusAndTriggerOnNodesAction(
String modelId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen
throw new IllegalArgumentException("URL can't match trusted url regex");
}
}
System.out.println("registering the model");
boolean isAsync = registerModelInput.getFunctionName() != FunctionName.REMOTE;
MLTask mlTask = MLTask
.builder()
Expand All @@ -250,7 +249,6 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
mlTask.setTaskId(taskId);
System.out.println("mlModelManager calls registerMLRemoteModel");
mlModelManager.registerMLRemoteModel(registerModelInput, mlTask, listener);
}, e -> {
logException("Failed to register model", e, log);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ public void test_toXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String xContentString = TestHelper.xContentBuilderToString(builder);
System.out.println(xContentString);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public void testDeployRemoteModel() throws IOException, InterruptedException {
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
assertEquals("CREATED", (String) responseMap.get("status"));
assertEquals("COMPLETED", (String) responseMap.get("status"));
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ public void testPrepareRequest() throws Exception {
SearchRequest searchRequest = argumentCaptor.getValue();
String[] indices = searchRequest.indices();
assertArrayEquals(new String[] { ML_CONNECTOR_INDEX }, indices);
System.out.println(searchRequest);
assertEquals(
"{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}",
searchRequest.source().toString()
Expand Down
Loading