diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java index a37ac71e8a..ddf0104c9e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java @@ -12,31 +12,37 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.MLTaskType; import java.io.IOException; @Getter public class MLDeployModelResponse extends ActionResponse implements ToXContentObject { public static final String TASK_ID_FIELD = "task_id"; + public static final String TASK_TYPE_FIELD = "task_type"; public static final String STATUS_FIELD = "status"; private String taskId; + private MLTaskType taskType; private String status; public MLDeployModelResponse(StreamInput in) throws IOException { super(in); this.taskId = in.readString(); + this.taskType = in.readEnum(MLTaskType.class); this.status = in.readString(); } - public MLDeployModelResponse(String taskId, String status) { + public MLDeployModelResponse(String taskId, MLTaskType mlTaskType, String status) { this.taskId = taskId; + this.taskType = mlTaskType; this.status= status; } @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(taskId); + out.writeEnum(taskType); out.writeString(status); } @@ -44,6 +50,9 @@ public void writeTo(StreamOutput out) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.startObject(); builder.field(TASK_ID_FIELD, taskId); + if (taskType != null) { + builder.field(TASK_TYPE_FIELD, taskType); + } builder.field(STATUS_FIELD, status); builder.endObject(); return builder; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponseTest.java index e80f3aefd6..5b4c2f2cd3 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponseTest.java @@ -6,6 +6,7 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.MLTaskType; import java.io.IOException; @@ -16,37 +17,40 @@ public class MLDeployModelResponseTest { private String taskId; private String status; + private MLTaskType taskType; @Before public void setUp() throws Exception { taskId = "test_id"; status = "test"; + taskType = MLTaskType.DEPLOY_MODEL; } @Test public void writeTo_Success() throws IOException { // Setup BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - MLDeployModelResponse response = new MLDeployModelResponse(taskId, status); + MLDeployModelResponse response = new MLDeployModelResponse(taskId, taskType, status); // Run the test response.writeTo(bytesStreamOutput); MLDeployModelResponse parsedResponse = new MLDeployModelResponse(bytesStreamOutput.bytes().streamInput()); // Verify the results assertEquals(response.getTaskId(), parsedResponse.getTaskId()); + assertEquals(response.getTaskType(), parsedResponse.getTaskType()); assertEquals(response.getStatus(), parsedResponse.getStatus()); } @Test public void testToXContent() throws IOException { // Setup - MLDeployModelResponse response = new MLDeployModelResponse(taskId, status); + MLDeployModelResponse response = new MLDeployModelResponse(taskId, taskType, status); // Run the test XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); // Verify the results - assertEquals("{\"task_id\":\"test_id\"," + + assertEquals("{\"task_id\":\"test_id\"," + "\"task_type\":\"DEPLOY_MODEL\"," + "\"status\":\"test\"}", jsonStr); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index acdfb5ec09..0fc7edf55d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -23,6 +23,8 @@ import java.util.Set; import java.util.stream.Collectors; +import javax.swing.*; + import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -213,9 +215,14 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { String taskId = response.getId(); mlTask.setTaskId(taskId); + if (algorithm == FunctionName.REMOTE) { + mlTaskManager.add(mlTask, nodeIds); + deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener); + return; + } try { mlTaskManager.add(mlTask, nodeIds); - listener.onResponse(new MLDeployModelResponse(taskId, MLTaskState.CREATED.name())); + listener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name())); threadPool .executor(DEPLOY_THREAD_POOL) .execute( @@ -260,6 +267,82 @@ protected void doExecute(Task task, ActionRequest request, ActionListener eligibleNodes, + boolean deployToAllNodes, + ActionListener listener + ) { + MLDeployModelInput deployModelInput = new MLDeployModelInput( + mlModel.getModelId(), + mlTask.getTaskId(), + mlModel.getModelContentHash(), + eligibleNodes.size(), + localNodeId, + deployToAllNodes, + mlTask + ); + + MLDeployModelNodesRequest deployModelRequest = new MLDeployModelNodesRequest( + eligibleNodes.toArray(new DiscoveryNode[0]), + deployModelInput + ); + + ActionListener actionListener = deployModelNodesResponseListener( + mlTask.getTaskId(), + mlModel.getModelId(), + listener + ); + List workerNodes = eligibleNodes.stream().map(n -> n.getId()).collect(Collectors.toList()); + mlModelManager + .updateModel( + mlModel.getModelId(), + ImmutableMap + .of( + MLModel.MODEL_STATE_FIELD, + MLModelState.DEPLOYING, + MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, + eligibleNodes.size(), + MLModel.PLANNING_WORKER_NODES_FIELD, + workerNodes, + MLModel.DEPLOY_TO_ALL_NODES_FIELD, + deployToAllNodes + ), + ActionListener + .wrap( + r -> client.execute(MLDeployModelOnNodeAction.INSTANCE, deployModelRequest, actionListener), + actionListener::onFailure + ) + ); + } + + private ActionListener deployModelNodesResponseListener( + String taskId, + String modelId, + ActionListener listener + ) { + return ActionListener.wrap(r -> { + if (mlTaskManager.contains(taskId)) { + mlTaskManager.updateMLTask(taskId, ImmutableMap.of(STATE_FIELD, MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false); + } + listener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.COMPLETED.name())); + }, e -> { + log.error("Failed to deploy model " + modelId, e); + mlTaskManager + .updateMLTask( + taskId, + ImmutableMap.of(MLTask.ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED), + TASK_SEMAPHORE_TIMEOUT, + true + ); + mlModelManager.updateModel(modelId, ImmutableMap.of(MLModel.MODEL_STATE_FIELD, MLModelState.DEPLOY_FAILED)); + listener.onFailure(e); + }); + } + @VisibleForTesting void updateModelDeployStatusAndTriggerOnNodesAction( String modelId, diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 6fcb190dfa..4556073d7d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -233,7 +233,6 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen throw new IllegalArgumentException("URL can't match trusted url regex"); } } - System.out.println("registering the model"); boolean isAsync = registerModelInput.getFunctionName() != FunctionName.REMOTE; MLTask mlTask = MLTask .builder() @@ -250,7 +249,6 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { String taskId = response.getId(); mlTask.setTaskId(taskId); - System.out.println("mlModelManager calls registerMLRemoteModel"); mlModelManager.registerMLRemoteModel(registerModelInput, mlTask, listener); }, e -> { logException("Failed to register model", e, log); diff --git a/plugin/src/test/java/org/opensearch/ml/action/profile/MLProfileModelResponseTests.java b/plugin/src/test/java/org/opensearch/ml/action/profile/MLProfileModelResponseTests.java index ca98524590..262cee551c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/profile/MLProfileModelResponseTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/profile/MLProfileModelResponseTests.java @@ -106,7 +106,6 @@ public void test_toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String xContentString = TestHelper.xContentBuilderToString(builder); - System.out.println(xContentString); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 346d210ebc..4eced40d94 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -129,7 +129,7 @@ public void testDeployRemoteModel() throws IOException, InterruptedException { String modelId = (String) responseMap.get("model_id"); response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); - assertEquals("CREATED", (String) responseMap.get("status")); + assertEquals("COMPLETED", (String) responseMap.get("status")); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java index 2886380eba..97ed37230a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java @@ -137,7 +137,6 @@ public void testPrepareRequest() throws Exception { SearchRequest searchRequest = argumentCaptor.getValue(); String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_CONNECTOR_INDEX }, indices); - System.out.println(searchRequest); assertEquals( "{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}", searchRequest.source().toString()