Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Profile API enhancement #653

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.indices.MLIndicesHandler',
'org.opensearch.ml.rest.RestMLPredictionAction',
'org.opensearch.ml.profile.MLModelProfile',
'org.opensearch.ml.profile.MLDeploymentProfile',
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why exclude this class?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried to cover this file by UT but it keeps showing the branch coverage is not above the threshold, and I see the MLModelProfile is also been excluded here so I thought this is a workaround of the code coverage.

'org.opensearch.ml.profile.MLPredictRequestStats',
'org.opensearch.ml.action.load.TransportLoadModelAction',
'org.opensearch.ml.model.MLModelManager',
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.factory;

import org.opensearch.action.get.MultiGetAction;
import org.opensearch.action.get.MultiGetRequestBuilder;
import org.opensearch.client.Client;

public class MultiGetRequestBuilderFactory {

public MultiGetRequestBuilder createMultiGetRequestBuilder(Client client) {
return new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems over design by adding a new method for just one line of code.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should always avoid new in code since the concept of IoC, also new is a blocker of UT, using factory method is a preferred approach to avoid these.

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.algorithms.anomalylocalization.AnomalyLocalizerImpl;
import org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator;
import org.opensearch.ml.factory.MultiGetRequestBuilderFactory;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.indices.MLInputDatasetHandler;
import org.opensearch.ml.model.MLModelCacheHelper;
Expand Down Expand Up @@ -180,6 +181,8 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin {
public static final String ML_ROLE_NAME = "ml";
private NamedXContentRegistry xContentRegistry;

private MultiGetRequestBuilderFactory multiGetRequestBuilderFactory;

@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return ImmutableList
Expand Down Expand Up @@ -268,7 +271,7 @@ public Collection<Object> createComponents(

mlModelMetaCreate = new MLModelMetaCreate(mlIndicesHandler, threadPool, client);
mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, xContentRegistry);

multiGetRequestBuilderFactory = new MultiGetRequestBuilderFactory();
MLTaskDispatcher mlTaskDispatcher = new MLTaskDispatcher(clusterService, client, settings, nodeHelper);
mlTrainingTaskRunner = new MLTrainingTaskRunner(
threadPool,
Expand Down Expand Up @@ -391,7 +394,7 @@ public List<RestHandler> getRestHandlers(
RestMLGetTaskAction restMLGetTaskAction = new RestMLGetTaskAction();
RestMLDeleteTaskAction restMLDeleteTaskAction = new RestMLDeleteTaskAction();
RestMLSearchTaskAction restMLSearchTaskAction = new RestMLSearchTaskAction();
RestMLProfileAction restMLProfileAction = new RestMLProfileAction(clusterService);
RestMLProfileAction restMLProfileAction = new RestMLProfileAction(clusterService, multiGetRequestBuilderFactory);
RestMLUploadModelAction restMLUploadModelAction = new RestMLUploadModelAction();
RestMLLoadModelAction restMLLoadModelAction = new RestMLLoadModelAction();
RestMLUnloadModelAction restMLUnloadModelAction = new RestMLUnloadModelAction(clusterService);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.profile;

import java.io.IOException;
import java.util.List;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

import org.opensearch.common.Strings;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;
import org.opensearch.common.xcontent.ToXContentFragment;
import org.opensearch.common.xcontent.XContentBuilder;

@Getter
@Log4j2
public class MLDeploymentProfile implements ToXContentFragment, Writeable {

private String modelId;

private String modelName;

@Setter
private List<String> targetNodeIds;

@Setter
private List<String> notDeployedNodeIds;

public MLDeploymentProfile(String modelName, String modelId, List<String> targetNodeIds, List<String> notDeployedNodeIds) {
this.modelName = modelName;
this.modelId = modelId;
this.targetNodeIds = targetNodeIds;
this.notDeployedNodeIds = notDeployedNodeIds;
}

public MLDeploymentProfile(String modelName, String modelId) {
this.modelName = modelName;
this.modelId = modelId;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (!Strings.isNullOrEmpty(modelName)) {
builder.field("model_name", modelName);
}
if (targetNodeIds != null && targetNodeIds.size() > 0) {
builder.field("target_node_ids", targetNodeIds);
}
if (notDeployedNodeIds != null && notDeployedNodeIds.size() > 0) {
builder.field("not_deployed_node_ids", notDeployedNodeIds);
}
builder.endObject();
return builder;
}

public MLDeploymentProfile(StreamInput in) throws IOException {
this.modelName = in.readOptionalString();
this.targetNodeIds = in.readOptionalStringList();
this.notDeployedNodeIds = in.readOptionalStringList();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(modelName);
out.writeOptionalStringCollection(targetNodeIds);
out.writeOptionalStringCollection(notDeployedNodeIds);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public class MLProfileInput implements ToXContentObject, Writeable {
public static final String RETURN_ALL_TASKS = "return_all_tasks";
public static final String RETURN_ALL_MODELS = "return_all_models";

public static final String PROFILE_AND_DEPLOYMENT = "profileAndDeployment";

/**
* Which models profiles will be retrieved
*/
Expand All @@ -52,18 +54,29 @@ public class MLProfileInput implements ToXContentObject, Writeable {
@Setter
private boolean returnAllModels;

@Setter
private String profileAndDeployment;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible that we add new content in profile response. How about renaming this as target_response?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense, I can make the change, what's your opinion regarding the path parameter? Do we need to change as well? /_plugins/_ml/profile?profile_and_deployment=all.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should change the path parameter too


/**
* Constructor
* @param modelIds
* @param taskIds
*/
@Builder
public MLProfileInput(Set<String> modelIds, Set<String> taskIds, Set<String> nodeIds, boolean returnAllTasks, boolean returnAllModels) {
public MLProfileInput(
Set<String> modelIds,
Set<String> taskIds,
Set<String> nodeIds,
boolean returnAllTasks,
boolean returnAllModels,
String profileAndDeployment
) {
this.modelIds = modelIds;
this.taskIds = taskIds;
this.nodeIds = nodeIds;
this.returnAllTasks = returnAllTasks;
this.returnAllModels = returnAllModels;
this.profileAndDeployment = profileAndDeployment;
}

public MLProfileInput() {
Expand All @@ -81,6 +94,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalStringCollection(nodeIds);
out.writeBoolean(returnAllTasks);
out.writeBoolean(returnAllModels);
out.writeOptionalString(profileAndDeployment);
}

public MLProfileInput(StreamInput input) throws IOException {
Expand All @@ -89,6 +103,7 @@ public MLProfileInput(StreamInput input) throws IOException {
nodeIds = input.readBoolean() ? new HashSet<>(input.readStringList()) : new HashSet<>();
this.returnAllTasks = input.readBoolean();
this.returnAllModels = input.readBoolean();
this.profileAndDeployment = input.readOptionalString();
}

public static MLProfileInput parse(XContentParser parser) throws IOException {
Expand All @@ -97,6 +112,7 @@ public static MLProfileInput parse(XContentParser parser) throws IOException {
Set<String> nodeIds = new HashSet<>();
boolean returnALlTasks = false;
boolean returnAllModels = false;
String profileAndDeployment = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);

Expand All @@ -120,6 +136,9 @@ public static MLProfileInput parse(XContentParser parser) throws IOException {
case RETURN_ALL_MODELS:
returnAllModels = parser.booleanValue();
break;
case PROFILE_AND_DEPLOYMENT:
profileAndDeployment = parser.textOrNull();
break;
default:
parser.skipChildren();
break;
Expand All @@ -133,6 +152,7 @@ public static MLProfileInput parse(XContentParser parser) throws IOException {
.nodeIds(nodeIds)
.returnAllTasks(returnALlTasks)
.returnAllModels(returnAllModels)
.profileAndDeployment(profileAndDeployment)
.build();
}

Expand All @@ -150,6 +170,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
builder.field(RETURN_ALL_TASKS, returnAllTasks);
builder.field(RETURN_ALL_MODELS, returnAllModels);
if (profileAndDeployment != null) {
builder.field(PROFILE_AND_DEPLOYMENT, profileAndDeployment);
}
builder.endObject();
return builder;
}
Expand Down
Loading