Skip to content

Commit

Permalink
support dispatching execute task; don't dispatch ML task again (opensearch-project#279)
Browse files Browse the repository at this point in the history

* support dispatching execute task; don't dispatch ML task again

Signed-off-by: Yaliang Wu <[email protected]>

* remove MLPredictTaskRunner from jacoco exclusion list

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Apr 25, 2022
1 parent 8c21c7c commit a032707
Show file tree
Hide file tree
Showing 15 changed files with 267 additions and 145 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport;

import lombok.Getter;
import lombok.Setter;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;

import java.io.IOException;
import java.util.UUID;

/**
 * Base transport request shared by ML task requests (execute, predict, train).
 *
 * <p>Carries two pieces of state that every ML task request needs:
 * <ul>
 *   <li>{@code dispatchTask} — whether the receiving node should still dispatch
 *       this task to a worker node; set to {@code false} once dispatched so the
 *       task is not dispatched again.</li>
 *   <li>{@code requestID} — a randomly generated UUID identifying this request
 *       across nodes; immutable for the lifetime of the request.</li>
 * </ul>
 */
public class MLTaskRequest extends ActionRequest {

    protected boolean dispatchTask;
    protected final String requestID;

    /**
     * Creates a new request with a freshly generated random request id.
     *
     * @param dispatchTask whether this task should be dispatched to a worker node
     */
    public MLTaskRequest(boolean dispatchTask) {
        this.requestID = UUID.randomUUID().toString();
        this.dispatchTask = dispatchTask;
    }

    /**
     * Deserializes a request from the stream. Field order must mirror
     * {@link #writeTo(StreamOutput)}: request id first, then the dispatch flag.
     *
     * @param in stream to read from
     * @throws IOException if reading from the stream fails
     */
    public MLTaskRequest(StreamInput in) throws IOException {
        super(in);
        this.requestID = in.readString();
        this.dispatchTask = in.readBoolean();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeString(requestID);
        out.writeBoolean(dispatchTask);
    }

    /**
     * No request-level validation is required for the base class.
     *
     * @return always {@code null}
     */
    @Override
    public ActionRequestValidationException validate() {
        return null;
    }

    /** @return whether this task should still be dispatched to a worker node */
    public boolean isDispatchTask() {
        return dispatchTask;
    }

    /** @param dispatchTask new value of the dispatch flag */
    public void setDispatchTask(boolean dispatchTask) {
        this.dispatchTask = dispatchTask;
    }

    /** @return the unique id generated for this request */
    public String getRequestID() {
        return requestID;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.Input;
import org.opensearch.ml.common.transport.MLTaskRequest;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand All @@ -31,17 +32,22 @@
@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLExecuteTaskRequest extends ActionRequest {
public class MLExecuteTaskRequest extends MLTaskRequest {

FunctionName functionName;
Input input;

@Builder
public MLExecuteTaskRequest(@NonNull FunctionName functionName, Input input) {
public MLExecuteTaskRequest(@NonNull FunctionName functionName, Input input, boolean dispatchTask) {
super(dispatchTask);
this.functionName = functionName;
this.input = input;
}

public MLExecuteTaskRequest(@NonNull FunctionName functionName, Input input) {
this(functionName, input, true);
}

public MLExecuteTaskRequest(StreamInput in) throws IOException {
super(in);
this.functionName = in.readEnum(FunctionName.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,29 @@
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.opensearch.ml.common.transport.MLTaskRequest;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLPredictionTaskRequest extends ActionRequest {
public class MLPredictionTaskRequest extends MLTaskRequest {

String modelId;
MLInput mlInput;

@Builder
public MLPredictionTaskRequest(String modelId, MLInput mlInput) {
public MLPredictionTaskRequest(String modelId, MLInput mlInput, boolean dispatchTask) {
super(dispatchTask);
this.mlInput = mlInput;
this.modelId = modelId;
}

public MLPredictionTaskRequest(String modelId, MLInput mlInput) {
this(modelId, mlInput, true);
}

public MLPredictionTaskRequest(StreamInput in) throws IOException {
super(in);
this.modelId = in.readOptionalString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.ml.common.parameter.MLInput;
import org.opensearch.ml.common.transport.MLTaskRequest;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand All @@ -29,7 +30,7 @@
@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLTrainingTaskRequest extends ActionRequest {
public class MLTrainingTaskRequest extends MLTaskRequest {

/**
* the name of algorithm
Expand All @@ -38,11 +39,16 @@ public class MLTrainingTaskRequest extends ActionRequest {
boolean async;

@Builder
public MLTrainingTaskRequest(MLInput mlInput, boolean async) {
public MLTrainingTaskRequest(MLInput mlInput, boolean async, boolean dispatchTask) {
super(dispatchTask);
this.mlInput = mlInput;
this.async = async;
}

public MLTrainingTaskRequest(MLInput mlInput, boolean async) {
this(mlInput, async, true);
}

public MLTrainingTaskRequest(StreamInput in) throws IOException {
super(in);
this.mlInput = new MLInput(in);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ public Collection<Object> createComponents(
mlStats,
mlInputDatasetHandler,
mlTaskDispatcher,
mlCircuitBreakerService
mlCircuitBreakerService,
xContentRegistry
);
mlTrainAndPredictTaskRunner = new MLTrainAndPredictTaskRunner(
threadPool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,21 @@
import lombok.extern.log4j.Log4j2;

import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ml.common.breaker.MLCircuitBreakerService;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.Input;
import org.opensearch.ml.common.parameter.Output;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.indices.MLInputDatasetHandler;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.TransportResponseHandler;

/**
* MLExecuteTaskRunner is responsible for running execute tasks.
Expand All @@ -44,26 +46,30 @@ public MLExecuteTaskRunner(
MLTaskDispatcher mlTaskDispatcher,
MLCircuitBreakerService mlCircuitBreakerService
) {
super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService);
super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService);
this.threadPool = threadPool;
this.clusterService = clusterService;
this.client = client;
this.mlInputDatasetHandler = mlInputDatasetHandler;
}

@Override
// Transport action name used by the base runner when forwarding an execute
// task request to a remote node.
protected String getTransportActionName() {
return MLExecuteTaskAction.NAME;
}

@Override
// Wraps the caller's listener in a handler that deserializes the remote
// node's reply via MLExecuteTaskResponse's StreamInput constructor.
protected TransportResponseHandler<MLExecuteTaskResponse> getResponseHandler(ActionListener<MLExecuteTaskResponse> listener) {
return new ActionListenerResponseHandler<>(listener, MLExecuteTaskResponse::new);
}

/**
* Execute algorithm and return result.
* TODO: 1. support backend task run; 2. support dispatch task to remote node
* @param request MLExecuteTaskRequest
* @param transportService transport service
* @param listener Action listener
*/
@Override
public void executeTask(
MLExecuteTaskRequest request,
TransportService transportService,
ActionListener<MLExecuteTaskResponse> listener
) {
protected void executeTask(MLExecuteTaskRequest request, ActionListener<MLExecuteTaskResponse> listener) {
threadPool.executor(TASK_THREAD_POOL).execute(() -> {
try {
Input input = request.getInput();
Expand Down
110 changes: 60 additions & 50 deletions plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.task;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX;
import static org.opensearch.ml.permission.AccessController.checkUserPermissions;
import static org.opensearch.ml.permission.AccessController.getUserContext;
Expand All @@ -17,7 +18,6 @@

import java.time.Instant;
import java.util.Base64;
import java.util.Map;
import java.util.UUID;

import lombok.extern.log4j.Log4j2;
Expand All @@ -32,6 +32,10 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
import org.opensearch.ml.common.breaker.MLCircuitBreakerService;
import org.opensearch.ml.common.dataframe.DataFrame;
Expand All @@ -53,7 +57,7 @@
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.TransportResponseHandler;

/**
* MLPredictTaskRunner is responsible for running predict tasks.
Expand All @@ -64,6 +68,7 @@ public class MLPredictTaskRunner extends MLTaskRunner<MLPredictionTaskRequest, M
private final ClusterService clusterService;
private final Client client;
private final MLInputDatasetHandler mlInputDatasetHandler;
private final NamedXContentRegistry xContentRegistry;

public MLPredictTaskRunner(
ThreadPool threadPool,
Expand All @@ -73,42 +78,34 @@ public MLPredictTaskRunner(
MLStats mlStats,
MLInputDatasetHandler mlInputDatasetHandler,
MLTaskDispatcher mlTaskDispatcher,
MLCircuitBreakerService mlCircuitBreakerService
MLCircuitBreakerService mlCircuitBreakerService,
NamedXContentRegistry xContentRegistry
) {
super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService);
super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService);
this.threadPool = threadPool;
this.clusterService = clusterService;
this.client = client;
this.mlInputDatasetHandler = mlInputDatasetHandler;
this.xContentRegistry = xContentRegistry;
}

@Override
public void executeTask(MLPredictionTaskRequest request, TransportService transportService, ActionListener<MLTaskResponse> listener) {
mlTaskDispatcher.dispatchTask(ActionListener.wrap(node -> {
if (clusterService.localNode().getId().equals(node.getId())) {
// Execute prediction task locally
log.info("execute ML prediction request {} locally on node {}", request.toString(), node.getId());
startPredictionTask(request, listener);
} else {
// Execute batch task remotely
log.info("execute ML prediction request {} remotely on node {}", request.toString(), node.getId());
transportService
.sendRequest(
node,
MLPredictionTaskAction.NAME,
request,
new ActionListenerResponseHandler<>(listener, MLTaskResponse::new)
);
}
}, e -> listener.onFailure(e)));
protected String getTransportActionName() {
return MLPredictionTaskAction.NAME;
}

@Override
protected TransportResponseHandler<MLTaskResponse> getResponseHandler(ActionListener<MLTaskResponse> listener) {
return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new);
}

/**
* Start prediction task
* @param request MLPredictionTaskRequest
* @param listener Action listener
*/
public void startPredictionTask(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
@Override
protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
Instant now = Instant.now();
MLTask mlTask = MLTask
Expand Down Expand Up @@ -166,36 +163,49 @@ private void predict(
internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId."));
return;
}
Map<String, Object> source = r.getSourceAsMap();
User requestUser = getUserContext(client);
User resourceUser = User.parse((String) source.get(USER));
if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) {
// The backend roles of request user and resource user doesn't have intersection
OpenSearchException e = new OpenSearchException(
"User: " + requestUser.getName() + " does not have permissions to run predict by model: " + request.getModelId()
);
handlePredictFailure(mlTask, internalListener, e, false);
return;
}
try (
XContentParser xContentParser = XContentType.JSON
.xContent()
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, r.getSourceAsString())
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, xContentParser.nextToken(), xContentParser);
MLModel mlModel = MLModel.parse(xContentParser);
User resourceUser = mlModel.getUser();
User requestUser = getUserContext(client);
if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) {
// The backend roles of request user and resource user doesn't have intersection
OpenSearchException e = new OpenSearchException(
"User: "
+ requestUser.getName()
+ " does not have permissions to run predict by model: "
+ request.getModelId()
);
handlePredictFailure(mlTask, internalListener, e, false);
return;
}
Model model = new Model();
model.setName(mlModel.getName());
model.setVersion(mlModel.getVersion());
byte[] decoded = Base64.getDecoder().decode(mlModel.getContent());
model.setContent(decoded);

// run predict
mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
MLOutput output = MLEngine
.predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model);
if (output instanceof MLPredictionOutput) {
((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
}

Model model = new Model();
model.setName((String) source.get(MLModel.MODEL_NAME));
model.setVersion((Integer) source.get(MLModel.MODEL_VERSION));
byte[] decoded = Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT));
model.setContent(decoded);

// run predict
mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
MLOutput output = MLEngine
.predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model);
if (output instanceof MLPredictionOutput) {
((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
// Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
handleAsyncMLTaskComplete(mlTask);
MLTaskResponse response = MLTaskResponse.builder().output(output).build();
internalListener.onResponse(response);
} catch (Exception e) {
log.error("Failed to predict model " + request.getModelId(), e);
internalListener.onFailure(e);
}

// Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
handleAsyncMLTaskComplete(mlTask);
MLTaskResponse response = MLTaskResponse.builder().output(output).build();
internalListener.onResponse(response);
}, e -> {
log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e);
handlePredictFailure(mlTask, internalListener, e, true);
Expand Down
Loading

0 comments on commit a032707

Please sign in to comment.