diff --git a/common/src/main/java/org/opensearch/ml/common/transport/MLTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/MLTaskRequest.java new file mode 100644 index 0000000000..900d7728f8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/MLTaskRequest.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport; + +import lombok.Getter; +import lombok.Setter; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.util.UUID; + +@Getter +@Setter +public class MLTaskRequest extends ActionRequest { + + protected boolean dispatchTask; + protected final String requestID; + + public MLTaskRequest(boolean dispatchTask) { + this.dispatchTask = dispatchTask; + this.requestID = UUID.randomUUID().toString(); + } + + public MLTaskRequest(StreamInput in) throws IOException { + super(in); + this.requestID = in.readString(); + this.dispatchTask = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(requestID); + out.writeBoolean(dispatchTask); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java index 2b4f4f2d24..fa0db53b21 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java @@ -20,6 +20,7 @@ import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.Input; +import org.opensearch.ml.common.transport.MLTaskRequest; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -31,17 +32,22 @@ @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString -public class MLExecuteTaskRequest extends ActionRequest { +public class MLExecuteTaskRequest extends MLTaskRequest { FunctionName functionName; Input input; @Builder - public MLExecuteTaskRequest(@NonNull FunctionName functionName, Input input) { + public MLExecuteTaskRequest(@NonNull FunctionName functionName, Input input, boolean dispatchTask) { + super(dispatchTask); this.functionName = functionName; this.input = input; } + public MLExecuteTaskRequest(@NonNull FunctionName functionName, Input input) { + this(functionName, input, true); + } + public MLExecuteTaskRequest(StreamInput in) throws IOException { super(in); this.functionName = in.readEnum(FunctionName.class); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java index 5e6688395e..443f7720c5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java @@ -23,23 +23,29 @@ import lombok.Getter; import lombok.ToString; import lombok.experimental.FieldDefaults; +import org.opensearch.ml.common.transport.MLTaskRequest; import static org.opensearch.action.ValidateActions.addValidationError; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString -public class MLPredictionTaskRequest extends ActionRequest { +public class MLPredictionTaskRequest extends MLTaskRequest { String modelId; MLInput mlInput; @Builder - public MLPredictionTaskRequest(String modelId, MLInput mlInput) { + public MLPredictionTaskRequest(String modelId, MLInput mlInput, boolean dispatchTask) { + super(dispatchTask); this.mlInput = mlInput; this.modelId = modelId; } + public MLPredictionTaskRequest(String modelId, MLInput mlInput) { + this(modelId, mlInput, true); + } + public MLPredictionTaskRequest(StreamInput in) throws IOException { super(in); this.modelId = in.readOptionalString(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java index b5d3bcfc92..628f649877 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java @@ -17,6 +17,7 @@ import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.transport.MLTaskRequest; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -29,7 +30,7 @@ @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString -public class MLTrainingTaskRequest extends ActionRequest { +public class MLTrainingTaskRequest extends MLTaskRequest { /** * the name of algorithm @@ -38,11 +39,16 @@ public class MLTrainingTaskRequest extends ActionRequest { boolean async; @Builder - public MLTrainingTaskRequest(MLInput mlInput, boolean async) { + public MLTrainingTaskRequest(MLInput mlInput, boolean async, boolean dispatchTask) { + super(dispatchTask); this.mlInput = mlInput; this.async = async; } + public MLTrainingTaskRequest(MLInput mlInput, boolean async) { + this(mlInput, async, true); + } + public MLTrainingTaskRequest(StreamInput in) throws IOException { super(in); this.mlInput = new MLInput(in); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 33c54c6f56..9fd5acca50 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -188,7 +188,8 @@ public Collection createComponents( mlStats, mlInputDatasetHandler, mlTaskDispatcher, - mlCircuitBreakerService + mlCircuitBreakerService, + xContentRegistry ); mlTrainAndPredictTaskRunner = new MLTrainAndPredictTaskRunner( threadPool, diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java index 62d2191966..43616e0f74 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java @@ -10,19 +10,21 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.ml.common.breaker.MLCircuitBreakerService; import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.Input; import org.opensearch.ml.common.parameter.Output; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.MLStats; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; +import org.opensearch.transport.TransportResponseHandler; /** * MLExecuteTaskRunner is responsible for running execute tasks. @@ -44,26 +46,30 @@ public MLExecuteTaskRunner( MLTaskDispatcher mlTaskDispatcher, MLCircuitBreakerService mlCircuitBreakerService ) { - super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService); + super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService); this.threadPool = threadPool; this.clusterService = clusterService; this.client = client; this.mlInputDatasetHandler = mlInputDatasetHandler; } + @Override + protected String getTransportActionName() { + return MLExecuteTaskAction.NAME; + } + + @Override + protected TransportResponseHandler getResponseHandler(ActionListener listener) { + return new ActionListenerResponseHandler<>(listener, MLExecuteTaskResponse::new); + } + /** * Execute algorithm and return result. - * TODO: 1. support backend task run; 2. support dispatch task to remote node * @param request MLExecuteTaskRequest - * @param transportService transport service * @param listener Action listener */ @Override - public void executeTask( - MLExecuteTaskRequest request, - TransportService transportService, - ActionListener listener - ) { + protected void executeTask(MLExecuteTaskRequest request, ActionListener listener) { threadPool.executor(TASK_THREAD_POOL).execute(() -> { try { Input input = request.getInput(); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index b903c02f2d..99b04f3d88 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -5,6 +5,7 @@ package org.opensearch.ml.task; +import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX; import static org.opensearch.ml.permission.AccessController.checkUserPermissions; import static org.opensearch.ml.permission.AccessController.getUserContext; @@ -17,7 +18,6 @@ import java.time.Instant; import java.util.Base64; -import java.util.Map; import java.util.UUID; import lombok.extern.log4j.Log4j2; @@ -32,6 +32,10 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; import org.opensearch.ml.common.breaker.MLCircuitBreakerService; import org.opensearch.ml.common.dataframe.DataFrame; @@ -53,7 +57,7 @@ import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLStats; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; +import org.opensearch.transport.TransportResponseHandler; /** * MLPredictTaskRunner is responsible for running predict tasks. @@ -64,6 +68,7 @@ public class MLPredictTaskRunner extends MLTaskRunner listener) { - mlTaskDispatcher.dispatchTask(ActionListener.wrap(node -> { - if (clusterService.localNode().getId().equals(node.getId())) { - // Execute prediction task locally - log.info("execute ML prediction request {} locally on node {}", request.toString(), node.getId()); - startPredictionTask(request, listener); - } else { - // Execute batch task remotely - log.info("execute ML prediction request {} remotely on node {}", request.toString(), node.getId()); - transportService - .sendRequest( - node, - MLPredictionTaskAction.NAME, - request, - new ActionListenerResponseHandler<>(listener, MLTaskResponse::new) - ); - } - }, e -> listener.onFailure(e))); + protected String getTransportActionName() { + return MLPredictionTaskAction.NAME; + } + + @Override + protected TransportResponseHandler getResponseHandler(ActionListener listener) { + return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new); } /** @@ -108,7 +104,8 @@ public void executeTask(MLPredictionTaskRequest request, TransportService transp * @param request MLPredictionTaskRequest * @param listener Action listener */ - public void startPredictionTask(MLPredictionTaskRequest request, ActionListener listener) { + @Override + protected void executeTask(MLPredictionTaskRequest request, ActionListener listener) { MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType(); Instant now = Instant.now(); MLTask mlTask = MLTask @@ -166,36 +163,49 @@ private void predict( internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId.")); return; } - Map source = r.getSourceAsMap(); - User requestUser = getUserContext(client); - User resourceUser = User.parse((String) source.get(USER)); - if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) { - // The backend roles of request user and resource user doesn't have intersection - OpenSearchException e = new OpenSearchException( - "User: " + requestUser.getName() + " does not have permissions to run predict by model: " + request.getModelId() - ); - handlePredictFailure(mlTask, internalListener, e, false); - return; - } + try ( + XContentParser xContentParser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, r.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, xContentParser.nextToken(), xContentParser); + MLModel mlModel = MLModel.parse(xContentParser); + User resourceUser = mlModel.getUser(); + User requestUser = getUserContext(client); + if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) { + // The backend roles of request user and resource user doesn't have intersection + OpenSearchException e = new OpenSearchException( + "User: " + + requestUser.getName() + + " does not have permissions to run predict by model: " + + request.getModelId() + ); + handlePredictFailure(mlTask, internalListener, e, false); + return; + } + Model model = new Model(); + model.setName(mlModel.getName()); + model.setVersion(mlModel.getVersion()); + byte[] decoded = Base64.getDecoder().decode(mlModel.getContent()); + model.setContent(decoded); + + // run predict + mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync()); + MLOutput output = MLEngine + .predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model); + if (output instanceof MLPredictionOutput) { + ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name()); + } - Model model = new Model(); - model.setName((String) source.get(MLModel.MODEL_NAME)); - model.setVersion((Integer) source.get(MLModel.MODEL_VERSION)); - byte[] decoded = Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT)); - model.setContent(decoded); - - // run predict - mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync()); - MLOutput output = MLEngine - .predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model); - if (output instanceof MLPredictionOutput) { - ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name()); + // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state + handleAsyncMLTaskComplete(mlTask); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); + internalListener.onResponse(response); + } catch (Exception e) { + log.error("Failed to predict model " + request.getModelId(), e); + internalListener.onFailure(e); } - // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state - handleAsyncMLTaskComplete(mlTask); - MLTaskResponse response = MLTaskResponse.builder().output(output).build(); - internalListener.onResponse(response); }, e -> { log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e); handlePredictFailure(mlTask, internalListener, e, true); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java index 7049f330ae..962473c130 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java @@ -10,13 +10,19 @@ import java.util.HashMap; import java.util.Map; +import lombok.extern.log4j.Log4j2; + import org.opensearch.action.ActionListener; import org.opensearch.ml.common.breaker.MLCircuitBreakerService; import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.parameter.MLTask; import org.opensearch.ml.common.parameter.MLTaskState; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.ml.common.transport.MLTaskRequest; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.stats.MLStats; +import org.opensearch.transport.TransportResponse; +import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableMap; @@ -26,12 +32,14 @@ * @param ML task request * @param ML task request */ -public abstract class MLTaskRunner { +@Log4j2 +public abstract class MLTaskRunner { public static final int TIMEOUT_IN_MILLIS = 2000; protected final MLTaskManager mlTaskManager; protected final MLStats mlStats; protected final MLTaskDispatcher mlTaskDispatcher; protected final MLCircuitBreakerService mlCircuitBreakerService; + private final ClusterService clusterService; protected static final String TASK_ID = "task_id"; protected static final String ALGORITHM = "algorithm"; @@ -44,12 +52,14 @@ public MLTaskRunner( MLTaskManager mlTaskManager, MLStats mlStats, MLTaskDispatcher mlTaskDispatcher, - MLCircuitBreakerService mlCircuitBreakerService + MLCircuitBreakerService mlCircuitBreakerService, + ClusterService clusterService ) { this.mlTaskManager = mlTaskManager; this.mlStats = mlStats; this.mlTaskDispatcher = mlTaskDispatcher; this.mlCircuitBreakerService = mlCircuitBreakerService; + this.clusterService = clusterService; } protected void handleAsyncMLTaskFailure(MLTask mlTask, Exception e) { @@ -80,11 +90,12 @@ public void run(Request request, TransportService transportService, ActionListen if (mlCircuitBreakerService.isOpen()) { throw new MLLimitExceededException("Circuit breaker is open"); } - try { - executeTask(request, transportService, listener); - } catch (Exception e) { - listener.onFailure(e); + if (!request.isDispatchTask()) { + log.info("Run ML request {} locally", request.getRequestID()); + executeTask(request, listener); + return; } + dispatchTask(request, transportService, listener); } protected ActionListener wrappedCleanupListener(ActionListener listener, String taskId) { @@ -95,5 +106,24 @@ protected ActionListener wrappedCleanupListener(ActionListener listener); + protected void dispatchTask(Request request, TransportService transportService, ActionListener listener) { + mlTaskDispatcher.dispatchTask(ActionListener.wrap(node -> { + if (clusterService.localNode().getId().equals(node.getId())) { + // Execute ML task locally + log.info("Execute ML request {} locally on node {}", request.getRequestID(), node.getId()); + executeTask(request, listener); + } else { + // Execute ML task remotely + log.info("Execute ML request {} remotely on node {}", request.getRequestID(), node.getId()); + request.setDispatchTask(false); + transportService.sendRequest(node, getTransportActionName(), request, getResponseHandler(listener)); + } + }, e -> listener.onFailure(e))); + } + + protected abstract String getTransportActionName(); + + protected abstract TransportResponseHandler getResponseHandler(ActionListener listener); + + protected abstract void executeTask(Request request, ActionListener listener); } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java index 2730d5ec31..a7825bf5a0 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java @@ -40,7 +40,7 @@ import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLStats; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; +import org.opensearch.transport.TransportResponseHandler; /** * MLPredictTaskRunner is responsible for running predict tasks. @@ -62,7 +62,7 @@ public MLTrainAndPredictTaskRunner( MLTaskDispatcher mlTaskDispatcher, MLCircuitBreakerService mlCircuitBreakerService ) { - super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService); + super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService); this.threadPool = threadPool; this.clusterService = clusterService; this.client = client; @@ -70,24 +70,13 @@ public MLTrainAndPredictTaskRunner( } @Override - public void executeTask(MLTrainingTaskRequest request, TransportService transportService, ActionListener listener) { - mlTaskDispatcher.dispatchTask(ActionListener.wrap(node -> { - if (clusterService.localNode().getId().equals(node.getId())) { - // Execute prediction task locally - log.info("execute ML train and prediction request {} locally on node {}", request.toString(), node.getId()); - startTrainAndPredictionTask(request, listener); - } else { - // Execute batch task remotely - log.info("execute ML train and prediction request {} remotely on node {}", request.toString(), node.getId()); - transportService - .sendRequest( - node, - MLTrainAndPredictionTaskAction.NAME, - request, - new ActionListenerResponseHandler<>(listener, MLTaskResponse::new) - ); - } - }, e -> listener.onFailure(e))); + protected String getTransportActionName() { + return MLTrainAndPredictionTaskAction.NAME; + } + + @Override + protected TransportResponseHandler getResponseHandler(ActionListener listener) { + return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new); } /** @@ -95,7 +84,8 @@ public void executeTask(MLTrainingTaskRequest request, TransportService transpor * @param request MLPredictionTaskRequest * @param listener Action listener */ - public void startTrainAndPredictionTask(MLTrainingTaskRequest request, ActionListener listener) { + @Override + protected void executeTask(MLTrainingTaskRequest request, ActionListener listener) { MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType(); Instant now = Instant.now(); MLTask mlTask = MLTask diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java index 6b031e9eff..e2179c006d 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java @@ -52,7 +52,7 @@ import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLStats; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; +import org.opensearch.transport.TransportResponseHandler; /** * MLTrainingTaskRunner is responsible for running training tasks. @@ -76,7 +76,7 @@ public MLTrainingTaskRunner( MLTaskDispatcher mlTaskDispatcher, MLCircuitBreakerService mlCircuitBreakerService ) { - super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService); + super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService); this.threadPool = threadPool; this.clusterService = clusterService; this.client = client; @@ -85,27 +85,17 @@ public MLTrainingTaskRunner( } @Override - public void executeTask(MLTrainingTaskRequest request, TransportService transportService, ActionListener listener) { - mlTaskDispatcher.dispatchTask(ActionListener.wrap(node -> { - if (clusterService.localNode().getId().equals(node.getId())) { - // Execute training task locally - log.info("execute ML training request {} locally on node {}", request.toString(), node.getId()); - createMLTaskAndTrain(request, listener); - } else { - // Execute batch task remotely - log.info("execute ML training request {} remotely on node {}", request.toString(), node.getId()); - transportService - .sendRequest( - node, - MLTrainingTaskAction.NAME, - request, - new ActionListenerResponseHandler<>(listener, MLTaskResponse::new) - ); - } - }, e -> listener.onFailure(e))); + protected String getTransportActionName() { + return MLTrainingTaskAction.NAME; } - public void createMLTaskAndTrain(MLTrainingTaskRequest request, ActionListener listener) { + @Override + protected TransportResponseHandler getResponseHandler(ActionListener listener) { + return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new); + } + + @Override + protected void executeTask(MLTrainingTaskRequest request, ActionListener listener) { MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType(); Instant now = Instant.now(); MLTask mlTask = MLTask @@ -151,7 +141,7 @@ public void createMLTaskAndTrain(MLTrainingTaskRequest request, ActionListener listener) { + private void startTrainingTask(MLTask mlTask, MLInput mlInput, ActionListener listener) { ActionListener internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); // track ML task count and add ML task into cache mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment(); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 21bf03388d..620fdb6b07 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -116,14 +116,14 @@ public void setup() { } public void testExecuteTask_Success() { - taskRunner.executeTask(mlExecuteTaskRequest, transportService, listener); + taskRunner.executeTask(mlExecuteTaskRequest, listener); verify(listener).onResponse(any(MLExecuteTaskResponse.class)); } public void testExecuteTask_NoExecutorService() { exceptionRule.expect(IllegalArgumentException.class); when(threadPool.executor(anyString())).thenThrow(new IllegalArgumentException()); - taskRunner.executeTask(mlExecuteTaskRequest, transportService, listener); + taskRunner.executeTask(mlExecuteTaskRequest, listener); verify(listener, never()).onResponse(any(MLExecuteTaskResponse.class)); } } diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 6f209d0042..699fdda12e 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -61,6 +61,7 @@ public class MLPredictTaskRunnerTests extends OpenSearchTestCase { + public static final String USER_STRING = "myuser|role1,role2|myTenant"; @Mock ThreadPool threadPool; @@ -135,7 +136,8 @@ public void setup() throws IOException { mlStats, mlInputDatasetHandler, mlTaskDispatcher, - mlCircuitBreakerService + mlCircuitBreakerService, + xContentRegistry() ) ); @@ -165,13 +167,13 @@ public void setup() throws IOException { when(client.threadPool()).thenReturn(threadPool); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); - threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "myuser|role1,role2|myTenant"); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); MLModel mlModel = MLModel .builder() - .user(new User()) + .user(User.parse(USER_STRING)) .version(111) .name("test") .algorithm(FunctionName.BATCH_RCF) @@ -187,7 +189,7 @@ public void setup() throws IOException { public void testExecuteTask_OnLocalNode() { setupMocks(true, false, false, false); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -198,7 +200,7 @@ public void testExecuteTask_OnLocalNode() { public void testExecuteTask_OnLocalNode_QueryInput() { setupMocks(true, false, false, false); - taskRunner.executeTask(requestWithQuery, transportService, listener); + taskRunner.dispatchTask(requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -209,23 +211,36 @@ public void testExecuteTask_OnLocalNode_QueryInput() { public void testExecuteTask_OnLocalNode_QueryInput_Failure() { setupMocks(true, true, false, false); - taskRunner.executeTask(requestWithQuery, transportService, listener); + taskRunner.dispatchTask(requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager, never()).add(any(MLTask.class)); verify(client, never()).get(any(), any()); } + public void testExecuteTask_NoPermission() { + setupMocks(true, true, false, false); + threadContext.stashContext(); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "test_user|test_role|test_tenant"); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); + verify(mlTaskManager).add(any(MLTask.class)); + verify(mlTaskManager).remove(anyString()); + verify(client).get(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals("User: test_user does not have permissions to run predict by model: 111", argumentCaptor.getValue().getMessage()); + } + public void testExecuteTask_OnRemoteNode() { setupMocks(false, false, false, false); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(transportService).sendRequest(eq(remoteNode), eq(MLPredictionTaskAction.NAME), eq(requestWithDataFrame), any()); } public void testExecuteTask_OnLocalNode_GetModelFail() { setupMocks(true, false, true, false); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -239,7 +254,7 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() { setupMocks(true, false, false, false); requestWithDataFrame = MLPredictionTaskRequest.builder().mlInput(mlInputWithDataFrame).build(); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -253,7 +268,7 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() { public void testExecuteTask_OnLocalNode_NullGetResponse() { setupMocks(true, false, false, true); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java index c0e8cdf427..d1689fbada 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java @@ -152,7 +152,7 @@ public void testExecuteTask_OnLocalNode() { actionListener.onResponse(localNode); return null; }).when(mlTaskDispatcher).dispatchTask(any()); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(listener).onResponse(any()); verify(taskRunner).handleAsyncMLTaskComplete(any(MLTask.class)); } @@ -170,7 +170,7 @@ public void testExecuteTask_OnLocalNode_QueryInput() { return null; }).when(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); - taskRunner.executeTask(requestWithQuery, transportService, listener); + taskRunner.dispatchTask(requestWithQuery, transportService, listener); verify(listener).onResponse(any()); verify(taskRunner).handleAsyncMLTaskComplete(any(MLTask.class)); } @@ -188,7 +188,7 @@ public void testExecuteTask_OnLocalNode_QueryInput_Failure() { return null; }).when(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); - taskRunner.executeTask(requestWithQuery, transportService, listener); + taskRunner.dispatchTask(requestWithQuery, transportService, listener); verify(listener, never()).onResponse(any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); @@ -203,7 +203,7 @@ public void testExecuteTask_OnLocalNode_FailedToUpdateTask() { return null; }).when(mlTaskDispatcher).dispatchTask(any()); doThrow(new RuntimeException(errorMessage)).when(mlTaskManager).updateTaskState(anyString(), any(MLTaskState.class), anyBoolean()); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); assertEquals(errorMessage, argumentCaptor.getValue().getMessage()); @@ -216,7 +216,7 @@ public void testExecuteTask_OnRemoteNode() { actionListener.onResponse(remoteNode); return null; }).when(mlTaskDispatcher).dispatchTask(any()); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(transportService).sendRequest(eq(remoteNode), eq(MLTrainAndPredictionTaskAction.NAME), eq(requestWithDataFrame), any()); } @@ -226,7 +226,7 @@ public void testExecuteTask_FailedToDispatch() { actionListener.onFailure(new RuntimeException(errorMessage)); return null; }).when(mlTaskDispatcher).dispatchTask(any()); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(listener, never()).onResponse(any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java index b23866c7a3..f35a236f8b 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java @@ -167,7 +167,7 @@ public void setup() { public void testExecuteTask_OnLocalNode_SyncRequest() { setupMocks(true, false, false, false); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(listener).onResponse(any()); verify(mlTaskManager, never()).createMLTask(any(MLTask.class), any()); verify(mlTaskManager).add(any(MLTask.class)); @@ -178,7 +178,7 @@ public void testExecuteTask_OnLocalNode_SyncRequest() { public void testExecuteTask_OnLocalNode_SyncRequest_QueryInput() { setupMocks(true, false, false, false); - taskRunner.executeTask(requestWithQuery, transportService, listener); + taskRunner.dispatchTask(requestWithQuery, transportService, listener); verify(listener).onResponse(any()); verify(mlTaskManager, never()).createMLTask(any(MLTask.class), any()); verify(mlTaskManager).add(any(MLTask.class)); @@ -189,7 +189,7 @@ public void testExecuteTask_OnLocalNode_SyncRequest_QueryInput() { public void testExecuteTask_OnLocalNode_AsyncRequest_QueryInput_Failure() { setupMocks(true, false, false, true); - taskRunner.executeTask(asyncRequestWithQuery, transportService, listener); + taskRunner.dispatchTask(asyncRequestWithQuery, transportService, listener); verify(listener).onResponse(any()); verify(mlTaskManager).createMLTask(any(MLTask.class), any()); verify(mlTaskManager).add(any(MLTask.class)); @@ -201,7 +201,7 @@ public void testExecuteTask_OnLocalNode_AsyncRequest_QueryInput_Failure() { public void testExecuteTask_OnLocalNode_AsyncRequest() { setupMocks(true, false, false, false); - taskRunner.executeTask(asyncRequestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(asyncRequestWithDataFrame, transportService, listener); verify(listener).onResponse(any()); verify(mlTaskManager).createMLTask(any(MLTask.class), any()); verify(mlTaskManager).add(any(MLTask.class)); @@ -212,7 +212,7 @@ public void testExecuteTask_OnLocalNode_AsyncRequest() { public void testExecuteTask_OnLocalNode_AsyncRequest_FailToCreateTask() { setupMocks(true, true, false, false); - taskRunner.executeTask(asyncRequestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(asyncRequestWithDataFrame, transportService, listener); verify(listener, never()).onResponse(any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); @@ -228,7 +228,7 @@ public void testExecuteTask_OnLocalNode_AsyncRequest_FailToCreateTask() { public void testExecuteTask_OnLocalNode_AsyncRequest_FailToCreateTaskWithException() { setupMocks(true, true, true, false); - taskRunner.executeTask(asyncRequestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(asyncRequestWithDataFrame, transportService, listener); verify(listener, never()).onResponse(any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); @@ -244,7 +244,7 @@ public void testExecuteTask_OnLocalNode_AsyncRequest_FailToCreateTaskWithExcepti public void testExecuteTask_OnRemoteNode_SyncRequest() { setupMocks(false, false, false, false); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(transportService).sendRequest(eq(remoteNode), eq(MLTrainingTaskAction.NAME), eq(requestWithDataFrame), any()); } @@ -253,7 +253,7 @@ public void testExecuteTask_OnRemoteNode_SyncRequest_FailToSendRequest() { doThrow(new NodeNotConnectedException(remoteNode, errorMessage)) .when(transportService) .sendRequest(eq(remoteNode), eq(MLTrainingTaskAction.NAME), any(), any()); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(transportService).sendRequest(eq(remoteNode), eq(MLTrainingTaskAction.NAME), eq(requestWithDataFrame), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); @@ -266,7 +266,7 @@ public void testExecuteTask_FailedToDispatch() { actionListener.onFailure(new RuntimeException(errorMessage)); return null; }).when(mlTaskDispatcher).dispatchTask(any()); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(listener, never()).onResponse(any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); diff --git a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java index a9f2c43b62..f46a4900ba 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java @@ -29,8 +29,11 @@ import org.opensearch.ml.common.parameter.MLTask; import org.opensearch.ml.common.parameter.MLTaskState; import org.opensearch.ml.common.parameter.MLTaskType; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.ml.common.transport.MLTaskRequest; import org.opensearch.ml.stats.MLStats; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; public class TaskRunnerTests extends OpenSearchTestCase { @@ -43,6 +46,8 @@ public class TaskRunnerTests extends OpenSearchTestCase { MLTaskDispatcher mlTaskDispatcher; @Mock MLCircuitBreakerService mlCircuitBreakerService; + @Mock + ClusterService clusterService; MLTaskRunner mlTaskRunner; MLTask mlTask; @@ -53,9 +58,19 @@ public class TaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - mlTaskRunner = new MLTaskRunner(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService) { + mlTaskRunner = new MLTaskRunner(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService) { + @Override + public String getTransportActionName() { + return null; + } + + @Override + public TransportResponseHandler getResponseHandler(ActionListener listener) { + return null; + } + @Override - public void executeTask(Object o, TransportService transportService, ActionListener listener) {} + public void executeTask(MLTaskRequest request, ActionListener listener) {} }; mlTask = MLTask .builder()