Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport to 1.3] support dispatching execute task; don't dispatch ML task again #298

Merged
merged 2 commits into from
Apr 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport;

import lombok.Getter;
import lombok.Setter;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;

import java.io.IOException;
import java.util.UUID;

@Getter
@Setter
public class MLTaskRequest extends ActionRequest {

    // Whether this task still needs to be dispatched to a worker node;
    // cleared so a forwarded task is not dispatched again (see PR title).
    protected boolean dispatchTask;
    // Unique id generated once per request.
    protected final String requestID;

    /**
     * Creates a new task request with a freshly generated request id.
     *
     * @param dispatchTask true if the task should still be dispatched to a node
     */
    public MLTaskRequest(boolean dispatchTask) {
        this.requestID = UUID.randomUUID().toString();
        this.dispatchTask = dispatchTask;
    }

    /**
     * Deserialization constructor. Reads fields in the same order
     * {@link #writeTo(StreamOutput)} writes them: request id first, then the flag.
     *
     * @param in stream to read from
     * @throws IOException if the stream cannot be read
     */
    public MLTaskRequest(StreamInput in) throws IOException {
        super(in);
        this.requestID = in.readString();
        this.dispatchTask = in.readBoolean();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeString(requestID);
        out.writeBoolean(dispatchTask);
    }

    /**
     * This base request carries no user input of its own, so there is nothing to validate.
     *
     * @return always null (no validation errors)
     */
    @Override
    public ActionRequestValidationException validate() {
        return null;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.ml.common.parameter.Input;
import org.opensearch.ml.common.parameter.MLInput;
import org.opensearch.ml.common.transport.MLTaskRequest;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand All @@ -29,12 +30,18 @@
@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLExecuteTaskRequest extends ActionRequest {
public class MLExecuteTaskRequest extends MLTaskRequest {

Input input;

/**
 * Convenience constructor: dispatchTask defaults to true.
 * Note: no {@code @Builder} here — the annotation belongs only on the full
 * constructor below; putting it on both overloads makes Lombok generate two
 * conflicting builder classes.
 *
 * @param input the algorithm input to execute
 */
public MLExecuteTaskRequest(Input input) {
    this(input, true);
}

/**
 * Full constructor backing the Lombok builder.
 *
 * @param input the algorithm input to execute
 * @param dispatchTask whether the task should still be dispatched to a worker node
 */
@Builder
public MLExecuteTaskRequest(Input input, boolean dispatchTask) {
super(dispatchTask);
this.input = input;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,29 @@
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.opensearch.ml.common.transport.MLTaskRequest;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLPredictionTaskRequest extends ActionRequest {
public class MLPredictionTaskRequest extends MLTaskRequest {

String modelId;
MLInput mlInput;

/**
 * Full constructor backing the Lombok builder.
 *
 * @param modelId id of the model to run prediction with
 * @param mlInput prediction input (algorithm plus data set)
 * @param dispatchTask whether the task should still be dispatched to a worker node
 */
@Builder
public MLPredictionTaskRequest(String modelId, MLInput mlInput, boolean dispatchTask) {
    super(dispatchTask);
    this.mlInput = mlInput;
    this.modelId = modelId;
}

/**
 * Convenience constructor: dispatchTask defaults to true.
 *
 * @param modelId id of the model to run prediction with
 * @param mlInput prediction input (algorithm plus data set)
 */
public MLPredictionTaskRequest(String modelId, MLInput mlInput) {
this(modelId, mlInput, true);
}

public MLPredictionTaskRequest(StreamInput in) throws IOException {
super(in);
this.modelId = in.readOptionalString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.ml.common.parameter.MLInput;
import org.opensearch.ml.common.transport.MLTaskRequest;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand All @@ -29,7 +30,7 @@
@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLTrainingTaskRequest extends ActionRequest {
public class MLTrainingTaskRequest extends MLTaskRequest {

/**
* the name of algorithm
Expand All @@ -38,11 +39,16 @@ public class MLTrainingTaskRequest extends ActionRequest {
boolean async;

/**
 * Full constructor backing the Lombok builder.
 *
 * @param mlInput training input (algorithm plus data set)
 * @param async whether to run the training task asynchronously
 * @param dispatchTask whether the task should still be dispatched to a worker node
 */
@Builder
public MLTrainingTaskRequest(MLInput mlInput, boolean async, boolean dispatchTask) {
    super(dispatchTask);
    this.mlInput = mlInput;
    this.async = async;
}

/**
 * Convenience constructor: dispatchTask defaults to true.
 *
 * @param mlInput training input (algorithm plus data set)
 * @param async whether to run the training task asynchronously
 */
public MLTrainingTaskRequest(MLInput mlInput, boolean async) {
this(mlInput, async, true);
}

public MLTrainingTaskRequest(StreamInput in) throws IOException {
super(in);
this.mlInput = new MLInput(in);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ public Collection<Object> createComponents(
mlStats,
mlInputDatasetHandler,
mlTaskDispatcher,
mlCircuitBreakerService
mlCircuitBreakerService,
xContentRegistry
);
mlTrainAndPredictTaskRunner = new MLTrainAndPredictTaskRunner(
threadPool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,20 @@
import lombok.extern.log4j.Log4j2;

import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ml.common.breaker.MLCircuitBreakerService;
import org.opensearch.ml.common.parameter.Input;
import org.opensearch.ml.common.parameter.Output;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.indices.MLInputDatasetHandler;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.TransportResponseHandler;

/**
* MLExecuteTaskRunner is responsible for running execute tasks.
Expand All @@ -43,26 +45,30 @@ public MLExecuteTaskRunner(
MLTaskDispatcher mlTaskDispatcher,
MLCircuitBreakerService mlCircuitBreakerService
) {
super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService);
super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService);
this.threadPool = threadPool;
this.clusterService = clusterService;
this.client = client;
this.mlInputDatasetHandler = mlInputDatasetHandler;
}

/**
 * Transport action name used when this execute task is forwarded to another node.
 */
@Override
protected String getTransportActionName() {
return MLExecuteTaskAction.NAME;
}

/**
 * Wraps the listener so responses from a remote node are deserialized
 * as {@link MLExecuteTaskResponse}.
 */
@Override
protected TransportResponseHandler<MLExecuteTaskResponse> getResponseHandler(ActionListener<MLExecuteTaskResponse> listener) {
return new ActionListenerResponseHandler<>(listener, MLExecuteTaskResponse::new);
}

/**
* Execute algorithm and return result.
* TODO: 1. support backend task run; 2. support dispatch task to remote node
* @param request MLExecuteTaskRequest
* @param transportService transport service
* @param listener Action listener
*/
@Override
public void executeTask(
MLExecuteTaskRequest request,
TransportService transportService,
ActionListener<MLExecuteTaskResponse> listener
) {
protected void executeTask(MLExecuteTaskRequest request, ActionListener<MLExecuteTaskResponse> listener) {
threadPool.executor(TASK_THREAD_POOL).execute(() -> {
Input input = request.getInput();
Output output = MLEngine.execute(input);
Expand Down
110 changes: 60 additions & 50 deletions plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.task;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX;
import static org.opensearch.ml.permission.AccessController.checkUserPermissions;
import static org.opensearch.ml.permission.AccessController.getUserContext;
Expand All @@ -17,7 +18,6 @@

import java.time.Instant;
import java.util.Base64;
import java.util.Map;
import java.util.UUID;

import lombok.extern.log4j.Log4j2;
Expand All @@ -32,6 +32,10 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
import org.opensearch.ml.common.breaker.MLCircuitBreakerService;
import org.opensearch.ml.common.dataframe.DataFrame;
Expand All @@ -53,7 +57,7 @@
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.TransportResponseHandler;

/**
* MLPredictTaskRunner is responsible for running predict tasks.
Expand All @@ -64,6 +68,7 @@ public class MLPredictTaskRunner extends MLTaskRunner<MLPredictionTaskRequest, M
private final ClusterService clusterService;
private final Client client;
private final MLInputDatasetHandler mlInputDatasetHandler;
private final NamedXContentRegistry xContentRegistry;

public MLPredictTaskRunner(
ThreadPool threadPool,
Expand All @@ -73,42 +78,34 @@ public MLPredictTaskRunner(
MLStats mlStats,
MLInputDatasetHandler mlInputDatasetHandler,
MLTaskDispatcher mlTaskDispatcher,
MLCircuitBreakerService mlCircuitBreakerService
MLCircuitBreakerService mlCircuitBreakerService,
NamedXContentRegistry xContentRegistry
) {
super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService);
super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService);
this.threadPool = threadPool;
this.clusterService = clusterService;
this.client = client;
this.mlInputDatasetHandler = mlInputDatasetHandler;
this.xContentRegistry = xContentRegistry;
}

@Override
public void executeTask(MLPredictionTaskRequest request, TransportService transportService, ActionListener<MLTaskResponse> listener) {
mlTaskDispatcher.dispatchTask(ActionListener.wrap(node -> {
if (clusterService.localNode().getId().equals(node.getId())) {
// Execute prediction task locally
log.info("execute ML prediction request {} locally on node {}", request.toString(), node.getId());
startPredictionTask(request, listener);
} else {
// Execute batch task remotely
log.info("execute ML prediction request {} remotely on node {}", request.toString(), node.getId());
transportService
.sendRequest(
node,
MLPredictionTaskAction.NAME,
request,
new ActionListenerResponseHandler<>(listener, MLTaskResponse::new)
);
}
}, e -> listener.onFailure(e)));
protected String getTransportActionName() {
return MLPredictionTaskAction.NAME;
}

/**
 * Wraps the listener so responses from a remote node are deserialized
 * as {@link MLTaskResponse}.
 */
@Override
protected TransportResponseHandler<MLTaskResponse> getResponseHandler(ActionListener<MLTaskResponse> listener) {
return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new);
}

/**
* Start prediction task
* @param request MLPredictionTaskRequest
* @param listener Action listener
*/
public void startPredictionTask(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
@Override
protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
Instant now = Instant.now();
MLTask mlTask = MLTask
Expand Down Expand Up @@ -166,36 +163,49 @@ private void predict(
internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId."));
return;
}
Map<String, Object> source = r.getSourceAsMap();
User requestUser = getUserContext(client);
User resourceUser = User.parse((String) source.get(USER));
if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) {
// The backend roles of request user and resource user doesn't have intersection
OpenSearchException e = new OpenSearchException(
"User: " + requestUser.getName() + " does not have permissions to run predict by model: " + request.getModelId()
);
handlePredictFailure(mlTask, internalListener, e, false);
return;
}
try (
XContentParser xContentParser = XContentType.JSON
.xContent()
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, r.getSourceAsString())
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, xContentParser.nextToken(), xContentParser);
MLModel mlModel = MLModel.parse(xContentParser);
User resourceUser = mlModel.getUser();
User requestUser = getUserContext(client);
if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) {
// The backend roles of request user and resource user doesn't have intersection
OpenSearchException e = new OpenSearchException(
"User: "
+ requestUser.getName()
+ " does not have permissions to run predict by model: "
+ request.getModelId()
);
handlePredictFailure(mlTask, internalListener, e, false);
return;
}
Model model = new Model();
model.setName(mlModel.getName());
model.setVersion(mlModel.getVersion());
byte[] decoded = Base64.getDecoder().decode(mlModel.getContent());
model.setContent(decoded);

// run predict
mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
MLOutput output = MLEngine
.predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model);
if (output instanceof MLPredictionOutput) {
((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
}

Model model = new Model();
model.setName((String) source.get(MLModel.MODEL_NAME));
model.setVersion((Integer) source.get(MLModel.MODEL_VERSION));
byte[] decoded = Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT));
model.setContent(decoded);

// run predict
mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
MLOutput output = MLEngine
.predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model);
if (output instanceof MLPredictionOutput) {
((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
// Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
handleAsyncMLTaskComplete(mlTask);
MLTaskResponse response = MLTaskResponse.builder().output(output).build();
internalListener.onResponse(response);
} catch (Exception e) {
log.error("Failed to predict model " + request.getModelId(), e);
internalListener.onFailure(e);
}

// Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
handleAsyncMLTaskComplete(mlTask);
MLTaskResponse response = MLTaskResponse.builder().output(output).build();
internalListener.onResponse(response);
}, e -> {
log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e);
handlePredictFailure(mlTask, internalListener, e, true);
Expand Down
Loading