diff --git a/common/src/main/java/org/opensearch/ml/common/transport/MLTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/MLTaskRequest.java new file mode 100644 index 0000000000..900d7728f8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/MLTaskRequest.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport; + +import lombok.Getter; +import lombok.Setter; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.util.UUID; + +@Getter +@Setter +public class MLTaskRequest extends ActionRequest { + + protected boolean dispatchTask; + protected final String requestID; + + public MLTaskRequest(boolean dispatchTask) { + this.dispatchTask = dispatchTask; + this.requestID = UUID.randomUUID().toString(); + } + + public MLTaskRequest(StreamInput in) throws IOException { + super(in); + this.requestID = in.readString(); + this.dispatchTask = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(requestID); + out.writeBoolean(dispatchTask); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java index 9c93c5c004..93ea3c99ee 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java @@ -18,6 +18,7 @@ import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.ml.common.parameter.Input; import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.transport.MLTaskRequest; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -29,12 +30,18 @@ @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString -public class MLExecuteTaskRequest extends ActionRequest { +public class MLExecuteTaskRequest extends MLTaskRequest { Input input; @Builder public MLExecuteTaskRequest(Input input) { + this(input, true); + } + + @Builder + public MLExecuteTaskRequest(Input input, boolean dispatchTask) { + super(dispatchTask); this.input = input; } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java index 5e6688395e..443f7720c5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java @@ -23,23 +23,29 @@ import lombok.Getter; import lombok.ToString; import lombok.experimental.FieldDefaults; +import org.opensearch.ml.common.transport.MLTaskRequest; import static org.opensearch.action.ValidateActions.addValidationError; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString -public class MLPredictionTaskRequest extends ActionRequest { +public class MLPredictionTaskRequest extends MLTaskRequest { String modelId; MLInput mlInput; @Builder - public MLPredictionTaskRequest(String modelId, MLInput mlInput) { + public MLPredictionTaskRequest(String modelId, MLInput mlInput, boolean dispatchTask) { + super(dispatchTask); this.mlInput = mlInput; this.modelId = modelId; } + public MLPredictionTaskRequest(String modelId, MLInput mlInput) { + this(modelId, mlInput, true); + } + public MLPredictionTaskRequest(StreamInput in) throws IOException { super(in); this.modelId = in.readOptionalString(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java index 1cfc6c1fdf..2919c22aa9 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java @@ -17,6 +17,7 @@ import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.transport.MLTaskRequest; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -29,7 +30,7 @@ @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString -public class MLTrainingTaskRequest extends ActionRequest { +public class MLTrainingTaskRequest extends MLTaskRequest { /** * the name of algorithm @@ -38,11 +39,16 @@ public class MLTrainingTaskRequest extends ActionRequest { boolean async; @Builder - public MLTrainingTaskRequest(MLInput mlInput, boolean async) { + public MLTrainingTaskRequest(MLInput mlInput, boolean async, boolean dispatchTask) { + super(dispatchTask); this.mlInput = mlInput; this.async = async; } + public MLTrainingTaskRequest(MLInput mlInput, boolean async) { + this(mlInput, async, true); + } + public MLTrainingTaskRequest(StreamInput in) throws IOException { super(in); this.mlInput = new MLInput(in); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 9ace30a13b..846f66598f 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -186,7 +186,8 @@ public Collection createComponents( mlStats, mlInputDatasetHandler, mlTaskDispatcher, - mlCircuitBreakerService + mlCircuitBreakerService, + xContentRegistry ); mlTrainAndPredictTaskRunner = new MLTrainAndPredictTaskRunner( threadPool, diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java index 0a55580dd3..9cd250d2b2 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java @@ -10,18 +10,20 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.ml.common.breaker.MLCircuitBreakerService; import org.opensearch.ml.common.parameter.Input; import org.opensearch.ml.common.parameter.Output; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.indices.MLInputDatasetHandler; import org.opensearch.ml.stats.MLStats; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; +import org.opensearch.transport.TransportResponseHandler; /** * MLExecuteTaskRunner is responsible for running execute tasks. @@ -43,26 +45,30 @@ public MLExecuteTaskRunner( MLTaskDispatcher mlTaskDispatcher, MLCircuitBreakerService mlCircuitBreakerService ) { - super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService); + super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService); this.threadPool = threadPool; this.clusterService = clusterService; this.client = client; this.mlInputDatasetHandler = mlInputDatasetHandler; } + @Override + protected String getTransportActionName() { + return MLExecuteTaskAction.NAME; + } + + @Override + protected TransportResponseHandler getResponseHandler(ActionListener listener) { + return new ActionListenerResponseHandler<>(listener, MLExecuteTaskResponse::new); + } + /** * Execute algorithm and return result. - * TODO: 1. support backend task run; 2. support dispatch task to remote node * @param request MLExecuteTaskRequest - * @param transportService transport service * @param listener Action listener */ @Override - public void executeTask( - MLExecuteTaskRequest request, - TransportService transportService, - ActionListener listener - ) { + protected void executeTask(MLExecuteTaskRequest request, ActionListener listener) { threadPool.executor(TASK_THREAD_POOL).execute(() -> { Input input = request.getInput(); Output output = MLEngine.execute(input); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index b903c02f2d..99b04f3d88 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -5,6 +5,7 @@ package org.opensearch.ml.task; +import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX; import static org.opensearch.ml.permission.AccessController.checkUserPermissions; import static org.opensearch.ml.permission.AccessController.getUserContext; @@ -17,7 +18,6 @@ import java.time.Instant; import java.util.Base64; -import java.util.Map; import java.util.UUID; import lombok.extern.log4j.Log4j2; @@ -32,6 +32,10 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; import org.opensearch.ml.common.breaker.MLCircuitBreakerService; import org.opensearch.ml.common.dataframe.DataFrame; @@ -53,7 +57,7 @@ import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLStats; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; +import org.opensearch.transport.TransportResponseHandler; /** * MLPredictTaskRunner is responsible for running predict tasks. @@ -64,6 +68,7 @@ public class MLPredictTaskRunner extends MLTaskRunner listener) { - mlTaskDispatcher.dispatchTask(ActionListener.wrap(node -> { - if (clusterService.localNode().getId().equals(node.getId())) { - // Execute prediction task locally - log.info("execute ML prediction request {} locally on node {}", request.toString(), node.getId()); - startPredictionTask(request, listener); - } else { - // Execute batch task remotely - log.info("execute ML prediction request {} remotely on node {}", request.toString(), node.getId()); - transportService - .sendRequest( - node, - MLPredictionTaskAction.NAME, - request, - new ActionListenerResponseHandler<>(listener, MLTaskResponse::new) - ); - } - }, e -> listener.onFailure(e))); + protected String getTransportActionName() { + return MLPredictionTaskAction.NAME; + } + + @Override + protected TransportResponseHandler getResponseHandler(ActionListener listener) { + return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new); } /** @@ -108,7 +104,8 @@ public void executeTask(MLPredictionTaskRequest request, TransportService transp * @param request MLPredictionTaskRequest * @param listener Action listener */ - public void startPredictionTask(MLPredictionTaskRequest request, ActionListener listener) { + @Override + protected void executeTask(MLPredictionTaskRequest request, ActionListener listener) { MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType(); Instant now = Instant.now(); MLTask mlTask = MLTask @@ -166,36 +163,49 @@ private void predict( internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId.")); return; } - Map source = r.getSourceAsMap(); - User requestUser = getUserContext(client); - User resourceUser = User.parse((String) source.get(USER)); - if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) { - // The backend roles of request user and resource user doesn't have intersection - OpenSearchException e = new OpenSearchException( - "User: " + requestUser.getName() + " does not have permissions to run predict by model: " + request.getModelId() - ); - handlePredictFailure(mlTask, internalListener, e, false); - return; - } + try ( + XContentParser xContentParser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, r.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, xContentParser.nextToken(), xContentParser); + MLModel mlModel = MLModel.parse(xContentParser); + User resourceUser = mlModel.getUser(); + User requestUser = getUserContext(client); + if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) { + // The backend roles of request user and resource user doesn't have intersection + OpenSearchException e = new OpenSearchException( + "User: " + + requestUser.getName() + + " does not have permissions to run predict by model: " + + request.getModelId() + ); + handlePredictFailure(mlTask, internalListener, e, false); + return; + } + Model model = new Model(); + model.setName(mlModel.getName()); + model.setVersion(mlModel.getVersion()); + byte[] decoded = Base64.getDecoder().decode(mlModel.getContent()); + model.setContent(decoded); + + // run predict + mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync()); + MLOutput output = MLEngine + .predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model); + if (output instanceof MLPredictionOutput) { + ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name()); + } - Model model = new Model(); - model.setName((String) source.get(MLModel.MODEL_NAME)); - model.setVersion((Integer) source.get(MLModel.MODEL_VERSION)); - byte[] decoded = Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT)); - model.setContent(decoded); - - // run predict - mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync()); - MLOutput output = MLEngine - .predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model); - if (output instanceof MLPredictionOutput) { - ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name()); + // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state + handleAsyncMLTaskComplete(mlTask); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); + internalListener.onResponse(response); + } catch (Exception e) { + log.error("Failed to predict model " + request.getModelId(), e); + internalListener.onFailure(e); } - // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state - handleAsyncMLTaskComplete(mlTask); - MLTaskResponse response = MLTaskResponse.builder().output(output).build(); - internalListener.onResponse(response); }, e -> { log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e); handlePredictFailure(mlTask, internalListener, e, true); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java index 9ef06fc6bc..0eb4ebbab1 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java @@ -10,13 +10,19 @@ import java.util.HashMap; import java.util.Map; +import lombok.extern.log4j.Log4j2; + import org.opensearch.action.ActionListener; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.ml.common.breaker.MLCircuitBreakerService; import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.parameter.MLTask; import org.opensearch.ml.common.parameter.MLTaskState; +import org.opensearch.ml.common.transport.MLTaskRequest; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.stats.MLStats; +import org.opensearch.transport.TransportResponse; +import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableMap; @@ -26,12 +32,14 @@ * @param ML task request * @param ML task request */ -public abstract class MLTaskRunner { +@Log4j2 +public abstract class MLTaskRunner { public static final int TIMEOUT_IN_MILLIS = 2000; protected final MLTaskManager mlTaskManager; protected final MLStats mlStats; protected final MLTaskDispatcher mlTaskDispatcher; protected final MLCircuitBreakerService mlCircuitBreakerService; + private final ClusterService clusterService; protected static final String TASK_ID = "task_id"; protected static final String ALGORITHM = "algorithm"; @@ -44,12 +52,14 @@ public MLTaskRunner( MLTaskManager mlTaskManager, MLStats mlStats, MLTaskDispatcher mlTaskDispatcher, - MLCircuitBreakerService mlCircuitBreakerService + MLCircuitBreakerService mlCircuitBreakerService, + ClusterService clusterService ) { this.mlTaskManager = mlTaskManager; this.mlStats = mlStats; this.mlTaskDispatcher = mlTaskDispatcher; this.mlCircuitBreakerService = mlCircuitBreakerService; + this.clusterService = clusterService; } protected void handleAsyncMLTaskFailure(MLTask mlTask, Exception e) { @@ -80,7 +90,12 @@ public void run(Request request, TransportService transportService, ActionListen if (mlCircuitBreakerService.isOpen()) { throw new MLLimitExceededException("Circuit breaker is open"); } - executeTask(request, transportService, listener); + if (!request.isDispatchTask()) { + log.info("Run ML request {} locally", request.getRequestID()); + executeTask(request, listener); + return; + } + dispatchTask(request, transportService, listener); } protected ActionListener wrappedCleanupListener(ActionListener listener, String taskId) { @@ -91,5 +106,24 @@ protected ActionListener wrappedCleanupListener(ActionListener listener); + protected void dispatchTask(Request request, TransportService transportService, ActionListener listener) { + mlTaskDispatcher.dispatchTask(ActionListener.wrap(node -> { + if (clusterService.localNode().getId().equals(node.getId())) { + // Execute ML task locally + log.info("Execute ML request {} locally on node {}", request.getRequestID(), node.getId()); + executeTask(request, listener); + } else { + // Execute ML task remotely + log.info("Execute ML request {} remotely on node {}", request.getRequestID(), node.getId()); + request.setDispatchTask(false); + transportService.sendRequest(node, getTransportActionName(), request, getResponseHandler(listener)); + } + }, e -> listener.onFailure(e))); + } + + protected abstract String getTransportActionName(); + + protected abstract TransportResponseHandler getResponseHandler(ActionListener listener); + + protected abstract void executeTask(Request request, ActionListener listener); } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java index 2730d5ec31..a7825bf5a0 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunner.java @@ -40,7 +40,7 @@ import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLStats; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; +import org.opensearch.transport.TransportResponseHandler; /** * MLPredictTaskRunner is responsible for running predict tasks. @@ -62,7 +62,7 @@ public MLTrainAndPredictTaskRunner( MLTaskDispatcher mlTaskDispatcher, MLCircuitBreakerService mlCircuitBreakerService ) { - super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService); + super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService); this.threadPool = threadPool; this.clusterService = clusterService; this.client = client; @@ -70,24 +70,13 @@ public MLTrainAndPredictTaskRunner( } @Override - public void executeTask(MLTrainingTaskRequest request, TransportService transportService, ActionListener listener) { - mlTaskDispatcher.dispatchTask(ActionListener.wrap(node -> { - if (clusterService.localNode().getId().equals(node.getId())) { - // Execute prediction task locally - log.info("execute ML train and prediction request {} locally on node {}", request.toString(), node.getId()); - startTrainAndPredictionTask(request, listener); - } else { - // Execute batch task remotely - log.info("execute ML train and prediction request {} remotely on node {}", request.toString(), node.getId()); - transportService - .sendRequest( - node, - MLTrainAndPredictionTaskAction.NAME, - request, - new ActionListenerResponseHandler<>(listener, MLTaskResponse::new) - ); - } - }, e -> listener.onFailure(e))); + protected String getTransportActionName() { + return MLTrainAndPredictionTaskAction.NAME; + } + + @Override + protected TransportResponseHandler getResponseHandler(ActionListener listener) { + return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new); } /** @@ -95,7 +84,8 @@ public void executeTask(MLTrainingTaskRequest request, TransportService transpor * @param request MLPredictionTaskRequest * @param listener Action listener */ - public void startTrainAndPredictionTask(MLTrainingTaskRequest request, ActionListener listener) { + @Override + protected void executeTask(MLTrainingTaskRequest request, ActionListener listener) { MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType(); Instant now = Instant.now(); MLTask mlTask = MLTask diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java index 6b031e9eff..e2179c006d 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java @@ -52,7 +52,7 @@ import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLStats; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; +import org.opensearch.transport.TransportResponseHandler; /** * MLTrainingTaskRunner is responsible for running training tasks. @@ -76,7 +76,7 @@ public MLTrainingTaskRunner( MLTaskDispatcher mlTaskDispatcher, MLCircuitBreakerService mlCircuitBreakerService ) { - super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService); + super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService); this.threadPool = threadPool; this.clusterService = clusterService; this.client = client; @@ -85,27 +85,17 @@ public MLTrainingTaskRunner( } @Override - public void executeTask(MLTrainingTaskRequest request, TransportService transportService, ActionListener listener) { - mlTaskDispatcher.dispatchTask(ActionListener.wrap(node -> { - if (clusterService.localNode().getId().equals(node.getId())) { - // Execute training task locally - log.info("execute ML training request {} locally on node {}", request.toString(), node.getId()); - createMLTaskAndTrain(request, listener); - } else { - // Execute batch task remotely - log.info("execute ML training request {} remotely on node {}", request.toString(), node.getId()); - transportService - .sendRequest( - node, - MLTrainingTaskAction.NAME, - request, - new ActionListenerResponseHandler<>(listener, MLTaskResponse::new) - ); - } - }, e -> listener.onFailure(e))); + protected String getTransportActionName() { + return MLTrainingTaskAction.NAME; } - public void createMLTaskAndTrain(MLTrainingTaskRequest request, ActionListener listener) { + @Override + protected TransportResponseHandler getResponseHandler(ActionListener listener) { + return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new); + } + + @Override + protected void executeTask(MLTrainingTaskRequest request, ActionListener listener) { MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType(); Instant now = Instant.now(); MLTask mlTask = MLTask @@ -151,7 +141,7 @@ public void createMLTaskAndTrain(MLTrainingTaskRequest request, ActionListener listener) { + private void startTrainingTask(MLTask mlTask, MLInput mlInput, ActionListener listener) { ActionListener internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); // track ML task count and add ML task into cache mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment(); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java new file mode 100644 index 0000000000..2a1d406283 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -0,0 +1,125 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.task; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.ml.common.breaker.MLCircuitBreakerService; +import org.opensearch.ml.common.parameter.LocalSampleCalculatorInput; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; +import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.stats.MLStat; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.ml.stats.StatNames; +import org.opensearch.ml.stats.suppliers.CounterSupplier; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class MLExecuteTaskRunnerTests extends OpenSearchTestCase { + + @Mock + ThreadPool threadPool; + + @Mock + ClusterService clusterService; + + @Mock + Client client; + + @Mock + MLTaskManager mlTaskManager; + + @Mock + ExecutorService executorService; + + @Mock + MLTaskDispatcher mlTaskDispatcher; + + @Mock + MLCircuitBreakerService mlCircuitBreakerService; + + @Mock + TransportService transportService; + + @Mock + ActionListener listener; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + MLInputDatasetHandler mlInputDatasetHandler; + MLExecuteTaskRunner taskRunner; + MLStats mlStats; + MLExecuteTaskRequest mlExecuteTaskRequest; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + when(threadPool.executor(anyString())).thenReturn(executorService); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + + Map> stats = new ConcurrentHashMap<>(); + stats.put(StatNames.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(StatNames.ML_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(StatNames.ML_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(StatNames.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + this.mlStats = new MLStats(stats); + + mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); + taskRunner = spy( + new MLExecuteTaskRunner( + threadPool, + clusterService, + client, + mlTaskManager, + mlStats, + mlInputDatasetHandler, + mlTaskDispatcher, + mlCircuitBreakerService + ) + ); + + mlExecuteTaskRequest = new MLExecuteTaskRequest(new LocalSampleCalculatorInput("sum", Arrays.asList(1.0, 2.0)), false); + } + + public void testExecuteTask_Success() { + taskRunner.executeTask(mlExecuteTaskRequest, listener); + verify(listener).onResponse(any(MLExecuteTaskResponse.class)); + } + + public void testExecuteTask_NoExecutorService() { + exceptionRule.expect(IllegalArgumentException.class); + when(threadPool.executor(anyString())).thenThrow(new IllegalArgumentException()); + taskRunner.executeTask(mlExecuteTaskRequest, listener); + verify(listener, never()).onResponse(any(MLExecuteTaskResponse.class)); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java new file mode 100644 index 0000000000..954c923bd9 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -0,0 +1,341 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.task; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.spy; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.authuser.User; +import org.opensearch.index.get.GetResult; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.ml.common.breaker.MLCircuitBreakerService; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.SearchQueryInputDataset; +import org.opensearch.ml.common.parameter.BatchRCFParams; +import org.opensearch.ml.common.parameter.FunctionName; +import org.opensearch.ml.common.parameter.MLInput; +import org.opensearch.ml.common.parameter.MLModel; +import org.opensearch.ml.common.parameter.MLTask; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.stats.MLStat; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.ml.stats.StatNames; +import org.opensearch.ml.stats.suppliers.CounterSupplier; +import org.opensearch.ml.utils.TestData; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableList; + +public class MLPredictTaskRunnerTests extends OpenSearchTestCase { + + public static final String USER_STRING = "myuser|role1,role2|myTenant"; + @Mock + ThreadPool threadPool; + + @Mock + ClusterService clusterService; + + @Mock + Client client; + + @Mock + MLTaskManager mlTaskManager; + + @Mock + ExecutorService executorService; + + @Mock + MLTaskDispatcher mlTaskDispatcher; + + @Mock + MLCircuitBreakerService mlCircuitBreakerService; + + @Mock + TransportService transportService; + + @Mock + ActionListener listener; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + MLStats mlStats; + DataFrame dataFrame; + DiscoveryNode localNode; + DiscoveryNode remoteNode; + MLInputDatasetHandler mlInputDatasetHandler; + MLPredictTaskRunner taskRunner; + MLPredictionTaskRequest requestWithDataFrame; + MLPredictionTaskRequest requestWithQuery; + ThreadContext threadContext; + String indexName = "testIndex"; + String errorMessage = "test error"; + GetResponse getResponse; + MLInput mlInputWithDataFrame; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT); + remoteNode = new DiscoveryNode("remoteNodeId", buildNewFakeTransportAddress(), Version.CURRENT); + when(clusterService.localNode()).thenReturn(localNode); + + when(threadPool.executor(anyString())).thenReturn(executorService); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + + Map> stats = new ConcurrentHashMap<>(); + stats.put(StatNames.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(StatNames.ML_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(StatNames.ML_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); + stats.put(StatNames.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + this.mlStats = new MLStats(stats); + mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); + taskRunner = spy( + new MLPredictTaskRunner( + threadPool, + clusterService, + client, + mlTaskManager, + mlStats, + mlInputDatasetHandler, + mlTaskDispatcher, + mlCircuitBreakerService, + xContentRegistry() + ) + ); + + dataFrame = TestData.constructTestDataFrame(100); + + MLInputDataset dataFrameInputDataSet = new DataFrameInputDataset(dataFrame); + BatchRCFParams batchRCFParams = BatchRCFParams.builder().build(); + mlInputWithDataFrame = MLInput + .builder() + .algorithm(FunctionName.BATCH_RCF) + .parameters(batchRCFParams) + .inputDataset(dataFrameInputDataSet) + .build(); + requestWithDataFrame = MLPredictionTaskRequest.builder().modelId("111").mlInput(mlInputWithDataFrame).build(); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(new MatchAllQueryBuilder()); + MLInputDataset queryInputDataSet = new SearchQueryInputDataset(ImmutableList.of(indexName), searchSourceBuilder); + MLInput mlInputWithQuery = MLInput + .builder() + .algorithm(FunctionName.BATCH_RCF) + .parameters(batchRCFParams) + .inputDataset(queryInputDataSet) + .build(); + requestWithQuery = MLPredictionTaskRequest.builder().modelId("111").mlInput(mlInputWithQuery).build(); + + when(client.threadPool()).thenReturn(threadPool); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + MLModel mlModel = MLModel + .builder() + .user(User.parse(USER_STRING)) + .version(111) + .name("test") + .algorithm(FunctionName.BATCH_RCF) + .content("content") + .build(); + XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + + GetResult getResult = new GetResult( + indexName, + MapperService.SINGLE_MAPPING_NAME, + "111", + 111l, + 111l, + 111l, + true, + bytesReference, + null, + null + ); + getResponse = new GetResponse(getResult); + } + + public void testExecuteTask_OnLocalNode() { + setupMocks(true, false, false, false); + + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); + verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); + verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); + verify(mlTaskManager).add(any(MLTask.class)); + verify(client).get(any(), any()); + verify(mlTaskManager).remove(anyString()); + } + + public void testExecuteTask_OnLocalNode_QueryInput() { + setupMocks(true, false, false, false); + + taskRunner.dispatchTask(requestWithQuery, transportService, listener); + verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); + verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); + verify(mlTaskManager).add(any(MLTask.class)); + verify(client).get(any(), any()); + verify(mlTaskManager).remove(anyString()); + } + + public void testExecuteTask_OnLocalNode_QueryInput_Failure() { + setupMocks(true, true, false, false); + + taskRunner.dispatchTask(requestWithQuery, transportService, listener); + verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); + verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); + verify(mlTaskManager, never()).add(any(MLTask.class)); + verify(client, never()).get(any(), any()); + } + + public void testExecuteTask_NoPermission() { + setupMocks(true, true, false, false); + threadContext.stashContext(); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "test_user|test_role|test_tenant"); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); + verify(mlTaskManager).add(any(MLTask.class)); + verify(mlTaskManager).remove(anyString()); + verify(client).get(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals("User: test_user does not have permissions to run predict by model: 111", argumentCaptor.getValue().getMessage()); + } + + public void testExecuteTask_OnRemoteNode() { + setupMocks(false, false, false, false); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); + verify(transportService).sendRequest(eq(remoteNode), eq(MLPredictionTaskAction.NAME), eq(requestWithDataFrame), any()); + } + + public void testExecuteTask_OnLocalNode_GetModelFail() { + setupMocks(true, false, true, false); + + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); + verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); + verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); + verify(mlTaskManager).add(any(MLTask.class)); + verify(client).get(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals(errorMessage, argumentCaptor.getValue().getMessage()); + } + + public void testExecuteTask_OnLocalNode_NullModelIdException() { + setupMocks(true, false, false, false); + requestWithDataFrame = MLPredictionTaskRequest.builder().mlInput(mlInputWithDataFrame).build(); + + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); + verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); + verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); + verify(mlTaskManager).add(any(MLTask.class)); + verify(client, never()).get(any(), any()); + verify(mlTaskManager).remove(anyString()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals("ModelId is invalid", argumentCaptor.getValue().getMessage()); + } + + public void testExecuteTask_OnLocalNode_NullGetResponse() { + setupMocks(true, false, false, true); + + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); + verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); + verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); + verify(mlTaskManager).add(any(MLTask.class)); + verify(client).get(any(), any()); + verify(mlTaskManager).remove(anyString()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals("No model found, please check the modelId.", argumentCaptor.getValue().getMessage()); + } + + private void setupMocks(boolean runOnLocalNode, boolean failedToParseQueryInput, boolean failedToGetModel, boolean nullGetResponse) { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(0); + if (runOnLocalNode) { + actionListener.onResponse(localNode); + } else { + actionListener.onResponse(remoteNode); + } + return null; + }).when(mlTaskDispatcher).dispatchTask(any()); + + if (failedToParseQueryInput) { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException(errorMessage)); + return null; + }).when(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); + } else { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(dataFrame); + return null; + }).when(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); + } + + if (nullGetResponse) { + getResponse = null; + } + + if (failedToGetModel) { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException(errorMessage)); + return null; + }).when(client).get(any(), any()); + } else { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + } + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java index 9becde1721..e09ccd7876 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java @@ -152,7 +152,7 @@ public void testExecuteTask_OnLocalNode() { actionListener.onResponse(localNode); return null; }).when(mlTaskDispatcher).dispatchTask(any()); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(listener).onResponse(any()); verify(taskRunner).handleAsyncMLTaskComplete(any(MLTask.class)); } @@ -170,7 +170,7 @@ public void testExecuteTask_OnLocalNode_QueryInput() { return null; }).when(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); - taskRunner.executeTask(requestWithQuery, transportService, listener); + taskRunner.dispatchTask(requestWithQuery, transportService, listener); verify(listener).onResponse(any()); verify(taskRunner).handleAsyncMLTaskComplete(any(MLTask.class)); } @@ -188,7 +188,7 @@ public void testExecuteTask_OnLocalNode_QueryInput_Failure() { return null; }).when(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); - taskRunner.executeTask(requestWithQuery, transportService, listener); + taskRunner.dispatchTask(requestWithQuery, transportService, listener); verify(listener, never()).onResponse(any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); @@ -203,7 +203,7 @@ public void testExecuteTask_OnLocalNode_FailedToUpdateTask() { return null; }).when(mlTaskDispatcher).dispatchTask(any()); doThrow(new RuntimeException(errorMessage)).when(mlTaskManager).updateTaskState(anyString(), any(MLTaskState.class), anyBoolean()); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); assertEquals(errorMessage, argumentCaptor.getValue().getMessage()); @@ -216,7 +216,7 @@ public void testExecuteTask_OnRemoteNode() { actionListener.onResponse(remoteNode); return null; }).when(mlTaskDispatcher).dispatchTask(any()); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(transportService).sendRequest(eq(remoteNode), eq(MLTrainAndPredictionTaskAction.NAME), eq(requestWithDataFrame), any()); } @@ -226,7 +226,7 @@ public void testExecuteTask_FailedToDispatch() { actionListener.onFailure(new RuntimeException(errorMessage)); return null; }).when(mlTaskDispatcher).dispatchTask(any()); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(listener, never()).onResponse(any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java index 053e4dbcd7..676bbe8f10 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java @@ -56,6 +56,7 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.NodeNotConnectedException; import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableList; @@ -166,7 +167,7 @@ public void setup() { public void testExecuteTask_OnLocalNode_SyncRequest() { setupMocks(true, false, false, false); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(listener).onResponse(any()); verify(mlTaskManager, never()).createMLTask(any(MLTask.class), any()); verify(mlTaskManager).add(any(MLTask.class)); @@ -177,7 +178,7 @@ public void testExecuteTask_OnLocalNode_SyncRequest() { public void testExecuteTask_OnLocalNode_SyncRequest_QueryInput() { setupMocks(true, false, false, false); - taskRunner.executeTask(requestWithQuery, transportService, listener); + taskRunner.dispatchTask(requestWithQuery, transportService, listener); verify(listener).onResponse(any()); verify(mlTaskManager, never()).createMLTask(any(MLTask.class), any()); verify(mlTaskManager).add(any(MLTask.class)); @@ -188,7 +189,7 @@ public void testExecuteTask_OnLocalNode_SyncRequest_QueryInput() { public void testExecuteTask_OnLocalNode_AsyncRequest_QueryInput_Failure() { setupMocks(true, false, false, true); - taskRunner.executeTask(asyncRequestWithQuery, transportService, listener); + taskRunner.dispatchTask(asyncRequestWithQuery, transportService, listener); verify(listener).onResponse(any()); verify(mlTaskManager).createMLTask(any(MLTask.class), any()); verify(mlTaskManager).add(any(MLTask.class)); @@ -200,7 +201,7 @@ public void testExecuteTask_OnLocalNode_AsyncRequest_QueryInput_Failure() { public void testExecuteTask_OnLocalNode_AsyncRequest() { setupMocks(true, false, false, false); - taskRunner.executeTask(asyncRequestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(asyncRequestWithDataFrame, transportService, listener); verify(listener).onResponse(any()); verify(mlTaskManager).createMLTask(any(MLTask.class), any()); verify(mlTaskManager).add(any(MLTask.class)); @@ -211,7 +212,7 @@ public void testExecuteTask_OnLocalNode_AsyncRequest() { public void testExecuteTask_OnLocalNode_AsyncRequest_FailToCreateTask() { setupMocks(true, true, false, false); - taskRunner.executeTask(asyncRequestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(asyncRequestWithDataFrame, transportService, listener); verify(listener, never()).onResponse(any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); @@ -227,7 +228,7 @@ public void testExecuteTask_OnLocalNode_AsyncRequest_FailToCreateTask() { public void testExecuteTask_OnLocalNode_AsyncRequest_FailToCreateTaskWithException() { setupMocks(true, true, true, false); - taskRunner.executeTask(asyncRequestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(asyncRequestWithDataFrame, transportService, listener); verify(listener, never()).onResponse(any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); @@ -243,17 +244,29 @@ public void testExecuteTask_OnLocalNode_AsyncRequest_FailToCreateTaskWithExcepti public void testExecuteTask_OnRemoteNode_SyncRequest() { setupMocks(false, false, false, false); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(transportService).sendRequest(eq(remoteNode), eq(MLTrainingTaskAction.NAME), eq(requestWithDataFrame), any()); } + public void testExecuteTask_OnRemoteNode_SyncRequest_FailToSendRequest() { + setupMocks(false, false, false, false); + doThrow(new NodeNotConnectedException(remoteNode, errorMessage)) + .when(transportService) + .sendRequest(eq(remoteNode), eq(MLTrainingTaskAction.NAME), any(), any()); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); + verify(transportService).sendRequest(eq(remoteNode), eq(MLTrainingTaskAction.NAME), eq(requestWithDataFrame), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertTrue(argumentCaptor.getValue().getMessage().contains(errorMessage)); + } + public void testExecuteTask_FailedToDispatch() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(0); actionListener.onFailure(new RuntimeException(errorMessage)); return null; }).when(mlTaskDispatcher).dispatchTask(any()); - taskRunner.executeTask(requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(requestWithDataFrame, transportService, listener); verify(listener, never()).onResponse(any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); diff --git a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java index a9f2c43b62..c71a365533 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java @@ -24,13 +24,16 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListener; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.ml.common.breaker.MLCircuitBreakerService; import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.parameter.MLTask; import org.opensearch.ml.common.parameter.MLTaskState; import org.opensearch.ml.common.parameter.MLTaskType; +import org.opensearch.ml.common.transport.MLTaskRequest; import org.opensearch.ml.stats.MLStats; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; public class TaskRunnerTests extends OpenSearchTestCase { @@ -43,6 +46,8 @@ public class TaskRunnerTests extends OpenSearchTestCase { MLTaskDispatcher mlTaskDispatcher; @Mock MLCircuitBreakerService mlCircuitBreakerService; + @Mock + ClusterService clusterService; MLTaskRunner mlTaskRunner; MLTask mlTask; @@ -53,9 +58,19 @@ public class TaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - mlTaskRunner = new MLTaskRunner(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService) { + mlTaskRunner = new MLTaskRunner(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService) { + @Override + public String getTransportActionName() { + return null; + } + + @Override + public TransportResponseHandler getResponseHandler(ActionListener listener) { + return null; + } + @Override - public void executeTask(Object o, TransportService transportService, ActionListener listener) {} + public void executeTask(MLTaskRequest request, ActionListener listener) {} }; mlTask = MLTask .builder()