diff --git a/plugin/src/main/java/org/opensearch/ml/action/profile/MLProfileModelResponse.java b/plugin/src/main/java/org/opensearch/ml/action/profile/MLProfileModelResponse.java new file mode 100644 index 0000000000..80ed0b2bcb --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/profile/MLProfileModelResponse.java @@ -0,0 +1,100 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.action.profile; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.xcontent.ToXContentFragment; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.profile.MLModelProfile; + +@Getter +@NoArgsConstructor +public class MLProfileModelResponse implements ToXContentFragment, Writeable { + @Setter + private String[] targetWorkerNodes; + + @Setter + private String[] workerNodes; + + private Map mlModelProfileMap = new HashMap<>(); + + private Map mlTaskMap = new HashMap<>(); + + public MLProfileModelResponse(String[] targetWorkerNodes, String[] workerNodes) { + this.targetWorkerNodes = targetWorkerNodes; + this.workerNodes = workerNodes; + } + + public MLProfileModelResponse(StreamInput in) throws IOException { + this.workerNodes = in.readOptionalStringArray(); + this.targetWorkerNodes = in.readOptionalStringArray(); + if (in.readBoolean()) { + this.mlModelProfileMap = in.readMap(StreamInput::readString, MLModelProfile::new); + } + if (in.readBoolean()) { + this.mlTaskMap = in.readMap(StreamInput::readString, MLTask::new); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (targetWorkerNodes != null) { + builder.field("target_worker_nodes", targetWorkerNodes); + } + if (workerNodes != null) { + builder.field("worker_nodes", workerNodes); + } + if (mlModelProfileMap.size() > 0) { + builder.startObject("nodes"); + for (Map.Entry entry : mlModelProfileMap.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + } + if (mlTaskMap.size() > 0) { + builder.startObject("tasks"); + for (Map.Entry entry : mlTaskMap.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + streamOutput.writeOptionalStringArray(workerNodes); + streamOutput.writeOptionalStringArray(targetWorkerNodes); + if (mlModelProfileMap.size() > 0) { + streamOutput.writeBoolean(true); + streamOutput.writeMap(mlModelProfileMap, StreamOutput::writeString, (o, r) -> r.writeTo(o)); + } else { + streamOutput.writeBoolean(false); + } + if (mlTaskMap.size() > 0) { + streamOutput.writeBoolean(true); + streamOutput.writeMap(mlTaskMap, StreamOutput::writeString, (o, r) -> r.writeTo(o)); + } else { + streamOutput.writeBoolean(false); + } + + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLProfileAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLProfileAction.java index 9a64fb0c7f..5de27ce8dc 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLProfileAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLProfileAction.java @@ -15,7 +15,9 @@ import java.io.IOException; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; @@ -28,20 +30,29 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.action.profile.MLProfileAction; +import org.opensearch.ml.action.profile.MLProfileModelResponse; import org.opensearch.ml.action.profile.MLProfileNodeResponse; import org.opensearch.ml.action.profile.MLProfileRequest; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.profile.MLModelProfile; import org.opensearch.ml.profile.MLProfileInput; +import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestStatus; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; @Log4j2 public class RestMLProfileAction extends BaseRestHandler { private static final String PROFILE_ML_ACTION = "profile_ml"; + private static final String VIEW = "view"; + private static final String MODEL_VIEW = "model"; + private static final String NODE_VIEW = "node"; + private ClusterService clusterService; /** @@ -80,6 +91,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } else { mlProfileInput = createMLProfileInputFromRequestParams(request); } + String view = RestActionUtils.getStringParam(request, VIEW).orElse(NODE_VIEW); String[] nodeIds = mlProfileInput.retrieveProfileOnAllNodes() ? getAllNodes(clusterService) : mlProfileInput.getNodeIds().toArray(new String[0]); @@ -93,7 +105,16 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli List nodeProfiles = r.getNodes().stream().filter(s -> !s.isEmpty()).collect(Collectors.toList()); log.debug("Build MLProfileNodeResponse for size of {}", nodeProfiles.size()); if (nodeProfiles.size() > 0) { - r.toXContent(builder, ToXContent.EMPTY_PARAMS); + if (NODE_VIEW.equals(view)) { + r.toXContent(builder, ToXContent.EMPTY_PARAMS); + } else if (MODEL_VIEW.equals(view)) { + Map modelCentricProfileMap = buildModelCentricResult(nodeProfiles); + builder.startObject("models"); + for (Map.Entry entry : modelCentricProfileMap.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + } } builder.endObject(); channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); @@ -105,6 +126,59 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli }; } + /** + * The data structure for node centric is: + * MLProfileNodeResponse: + * taskMap: Map + * modelMap: Map model_id, MLModelProfile + * And we need to convert to format like this: + * modelMap: Map> + */ + private Map buildModelCentricResult(List nodeResponses) { + // aggregate model information into one final map. + Map modelCentricMap = new HashMap<>(); + for (MLProfileNodeResponse mlProfileNodeResponse : nodeResponses) { + String nodeId = mlProfileNodeResponse.getNode().getId(); + Map modelProfileMap = mlProfileNodeResponse.getMlNodeModels(); + Map taskProfileMap = mlProfileNodeResponse.getMlNodeTasks(); + for (Map.Entry entry : modelProfileMap.entrySet()) { + MLProfileModelResponse mlProfileModelResponse = modelCentricMap.get(entry.getKey()); + if (mlProfileModelResponse == null) { + mlProfileModelResponse = new MLProfileModelResponse( + entry.getValue().getTargetWorkerNodes(), + entry.getValue().getWorkerNodes() + ); + modelCentricMap.put(entry.getKey(), mlProfileModelResponse); + } + if (mlProfileModelResponse.getTargetWorkerNodes() == null || mlProfileModelResponse.getWorkerNodes() == null) { + mlProfileModelResponse.setTargetWorkerNodes(entry.getValue().getTargetWorkerNodes()); + mlProfileModelResponse.setWorkerNodes(entry.getValue().getWorkerNodes()); + } + // Create a new object and remove targetWorkerNodes and workerNodes. + MLModelProfile modelProfile = new MLModelProfile( + entry.getValue().getModelState(), + entry.getValue().getPredictor(), + null, + null, + entry.getValue().getModelInferenceStats(), + entry.getValue().getPredictRequestStats() + ); + mlProfileModelResponse.getMlModelProfileMap().putAll(ImmutableMap.of(nodeId, modelProfile)); + } + + for (Map.Entry entry : taskProfileMap.entrySet()) { + String modelId = entry.getValue().getModelId(); + MLProfileModelResponse mlProfileModelResponse = modelCentricMap.get(modelId); + if (mlProfileModelResponse == null) { + mlProfileModelResponse = new MLProfileModelResponse(); + modelCentricMap.put(modelId, mlProfileModelResponse); + } + mlProfileModelResponse.getMlTaskMap().putAll(ImmutableMap.of(entry.getKey(), entry.getValue())); + } + } + return modelCentricMap; + } + MLProfileInput createMLProfileInputFromRequestParams(RestRequest request) { MLProfileInput mlProfileInput = new MLProfileInput(); Optional modelIds = splitCommaSeparatedParam(request, PARAMETER_MODEL_ID); diff --git a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java index 94835b892e..00c86bbfda 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java @@ -169,4 +169,9 @@ private static String coalesceToEmpty(@Nullable String s) { private static boolean isNullOrEmpty(@Nullable String s) { return s == null || s.isEmpty(); } + + public static Optional getStringParam(RestRequest request, String paramName) { + return Optional.ofNullable(request.param(paramName)); + } + } diff --git a/plugin/src/test/java/org/opensearch/ml/action/profile/MLProfileModelResponseTests.java b/plugin/src/test/java/org/opensearch/ml/action/profile/MLProfileModelResponseTests.java new file mode 100644 index 0000000000..f63104dfd1 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/profile/MLProfileModelResponseTests.java @@ -0,0 +1,112 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.action.profile; + +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.dataset.MLInputDataType; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.profile.MLModelProfile; +import org.opensearch.ml.profile.MLPredictRequestStats; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.test.OpenSearchTestCase; + +public class MLProfileModelResponseTests extends OpenSearchTestCase { + + MLTask mlTask; + MLModelProfile mlModelProfile; + + @Before + public void setup() { + mlTask = MLTask + .builder() + .taskId("test_id") + .modelId("model_id") + .taskType(MLTaskType.TRAINING) + .functionName(FunctionName.AD_LIBSVM) + .state(MLTaskState.CREATED) + .inputType(MLInputDataType.DATA_FRAME) + .progress(0.4f) + .outputIndex("test_index") + .workerNodes(Arrays.asList("test_node")) + .createTime(Instant.ofEpochMilli(123)) + .lastUpdateTime(Instant.ofEpochMilli(123)) + .error("error") + .user(new User()) + .async(false) + .build(); + mlModelProfile = MLModelProfile + .builder() + .predictor("test_predictor") + .workerNodes(new String[] { "node1", "node2" }) + .modelState(MLModelState.LOADED) + .modelInferenceStats(MLPredictRequestStats.builder().count(10L).average(11.0).max(20.0).min(5.0).build()) + .build(); + } + + public void test_create_MLProfileModelResponse_withArgs() throws IOException { + String[] targetWorkerNodes = new String[] { "node1", "node2" }; + String[] workerNodes = new String[] { "node1" }; + Map profileMap = new HashMap<>(); + Map taskMap = new HashMap<>(); + profileMap.put("node1", mlModelProfile); + taskMap.put("node1", mlTask); + MLProfileModelResponse response = new MLProfileModelResponse(targetWorkerNodes, workerNodes); + response.getMlModelProfileMap().putAll(profileMap); + response.getMlTaskMap().putAll(taskMap); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + MLProfileModelResponse newResponse = new MLProfileModelResponse(output.bytes().streamInput()); + assertNotNull(newResponse.getTargetWorkerNodes()); + assertNotNull(response.getTargetWorkerNodes()); + assertEquals(newResponse.getTargetWorkerNodes().length, response.getTargetWorkerNodes().length); + assertEquals(newResponse.getMlModelProfileMap().size(), response.getMlModelProfileMap().size()); + assertEquals(newResponse.getMlTaskMap().size(), response.getMlTaskMap().size()); + } + + public void test_create_MLProfileModelResponse_NoArgs() throws IOException { + MLProfileModelResponse response = new MLProfileModelResponse(); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + MLProfileModelResponse newResponse = new MLProfileModelResponse(output.bytes().streamInput()); + assertNull(response.getWorkerNodes()); + assertNull(newResponse.getWorkerNodes()); + } + + public void test_toXContent() throws IOException { + String[] targetWorkerNodes = new String[] { "node1", "node2" }; + String[] workerNodes = new String[] { "node1" }; + Map profileMap = new HashMap<>(); + Map taskMap = new HashMap<>(); + profileMap.put("node1", mlModelProfile); + taskMap.put("node1", mlTask); + MLProfileModelResponse response = new MLProfileModelResponse(targetWorkerNodes, workerNodes); + response.getMlModelProfileMap().putAll(profileMap); + response.getMlTaskMap().putAll(taskMap); + + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String xContentString = TestHelper.xContentBuilderToString(builder); + System.out.println(xContentString); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java index 76a63441f0..6a702e8277 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java @@ -12,8 +12,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.ml.utils.TestHelper.getProfileRestRequest; -import static org.opensearch.ml.utils.TestHelper.setupTestClusterState; +import static org.opensearch.ml.utils.TestHelper.*; import java.io.IOException; import java.time.Instant; @@ -68,6 +67,7 @@ import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; public class RestMLProfileActionTests extends OpenSearchTestCase { @Rule @@ -286,6 +286,14 @@ public void test_PrepareRequest_Failure() throws Exception { verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any()); } + public void test_WhenViewIsModel_ReturnModelViewResult() throws Exception { + MLProfileInput mlProfileInput = new MLProfileInput(); + RestRequest request = getProfileRestRequestWithQueryParams(mlProfileInput, ImmutableMap.of("view", "model")); + profileAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLProfileRequest.class); + verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any()); + } + private RestRequest getRestRequest() { Map params = new HashMap<>(); params.put("task_id", "test_id"); diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index e7b03ac1df..5dcfe1f0d4 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -213,14 +213,22 @@ public static RestRequest getStatsRestRequest(MLStatsInput input) throws IOExcep } public static RestRequest getProfileRestRequest(MLProfileInput input) throws IOException { - XContentBuilder builder = XContentFactory.jsonBuilder(); - input.toXContent(builder, ToXContent.EMPTY_PARAMS); - String requestContent = TestHelper.xContentBuilderToString(builder); + return new FakeRestRequest.Builder(getXContentRegistry()) + .withContent(new BytesArray(buildRequestContent(input)), XContentType.JSON) + .build(); + } - RestRequest request = new FakeRestRequest.Builder(getXContentRegistry()) - .withContent(new BytesArray(requestContent), XContentType.JSON) + public static RestRequest getProfileRestRequestWithQueryParams(MLProfileInput input, Map params) throws IOException { + return new FakeRestRequest.Builder(getXContentRegistry()) + .withContent(new BytesArray(buildRequestContent(input)), XContentType.JSON) + .withParams(params) .build(); - return request; + } + + private static String buildRequestContent(MLProfileInput input) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + return TestHelper.xContentBuilderToString(builder); } public static RestRequest getStatsRestRequest() {