diff --git a/common/src/main/java/org/opensearch/ml/common/MLTaskState.java b/common/src/main/java/org/opensearch/ml/common/MLTaskState.java index 77336be901..dfd7b835d4 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTaskState.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTaskState.java @@ -28,5 +28,7 @@ public enum MLTaskState { COMPLETED, FAILED, CANCELLED, - COMPLETED_WITH_ERROR + COMPLETED_WITH_ERROR, + CANCELLING, + EXPIRED } diff --git a/common/src/main/java/org/opensearch/ml/common/MLTaskType.java b/common/src/main/java/org/opensearch/ml/common/MLTaskType.java index 179bf152cd..aafff5b50e 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTaskType.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTaskType.java @@ -8,7 +8,6 @@ public enum MLTaskType { TRAINING, PREDICTION, - BATCH_PREDICTION, TRAINING_AND_PREDICTION, EXECUTION, @Deprecated @@ -17,5 +16,6 @@ public enum MLTaskType { LOAD_MODEL, REGISTER_MODEL, DEPLOY_MODEL, - BATCH_INGEST + BATCH_INGEST, + BATCH_PREDICTION } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index 5ba465b15a..abe56cde0e 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -10,6 +10,7 @@ import java.util.Map; import java.util.function.Function; +import org.opensearch.ml.common.connector.functions.postprocess.BedrockBatchJobArnPostProcessFunction; import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction; import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction; import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction; @@ -20,6 +21,7 @@ public class MLPostProcessFunction { public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding"; public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding"; public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding"; + public static final String BEDROCK_BATCH_JOB_ARN = "connector.post_process.bedrock.batch_job_arn"; public static final String COHERE_RERANK = "connector.post_process.cohere.rerank"; public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding"; public static final String DEFAULT_RERANK = "connector.post_process.default.rerank"; @@ -31,17 +33,20 @@ public class MLPostProcessFunction { static { EmbeddingPostProcessFunction embeddingPostProcessFunction = new EmbeddingPostProcessFunction(); BedrockEmbeddingPostProcessFunction bedrockEmbeddingPostProcessFunction = new BedrockEmbeddingPostProcessFunction(); + BedrockBatchJobArnPostProcessFunction batchJobArnPostProcessFunction = new BedrockBatchJobArnPostProcessFunction(); CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction(); JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding"); JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings"); JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]"); JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding"); + JSON_PATH_EXPRESSION.put(BEDROCK_BATCH_JOB_ARN, "$"); JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results"); JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]"); POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, embeddingPostProcessFunction); POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, bedrockEmbeddingPostProcessFunction); + POST_PROCESS_FUNCTIONS.put(BEDROCK_BATCH_JOB_ARN, batchJobArnPostProcessFunction); POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction); POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java new file mode 100644 index 0000000000..e69829855e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunction.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.output.model.ModelTensor; + +public class BedrockBatchJobArnPostProcessFunction extends ConnectorPostProcessFunction> { + public static final String JOB_ARN = "jobArn"; + public static final String PROCESSED_JOB_ARN = "processedJobArn"; + + @Override + public void validate(Object input) { + if (!(input instanceof Map)) { + throw new IllegalArgumentException("Post process function input is not a Map."); + } + Map jobInfo = (Map) input; + if (!(jobInfo.containsKey(JOB_ARN))) { + throw new IllegalArgumentException("job arn is missing."); + } + } + + @Override + public List process(Map jobInfo) { + List modelTensors = new ArrayList<>(); + Map processedResult = new HashMap<>(); + processedResult.putAll(jobInfo); + String jobArn = jobInfo.get(JOB_ARN); + processedResult.put(PROCESSED_JOB_ARN, jobArn.replace("/", "%2F")); + modelTensors.add(ModelTensor.builder().name("response").dataAsMap(processedResult).build()); + return modelTensors; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java index 6ea26c9eb3..5c75e4c8d2 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLCancelBatchJobAction.java @@ -9,7 +9,7 @@ public class MLCancelBatchJobAction extends ActionType { public static final MLCancelBatchJobAction INSTANCE = new MLCancelBatchJobAction(); - public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel_batch_job"; + public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel"; private MLCancelBatchJobAction() { super(NAME, MLCancelBatchJobResponse::new); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunctionTest.java new file mode 100644 index 0000000000..3dae9011ea --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockBatchJobArnPostProcessFunctionTest.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector.functions.postprocess; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.output.model.ModelTensor; + +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.connector.functions.postprocess.BedrockBatchJobArnPostProcessFunction.JOB_ARN; +import static org.opensearch.ml.common.connector.functions.postprocess.BedrockBatchJobArnPostProcessFunction.PROCESSED_JOB_ARN; + +public class BedrockBatchJobArnPostProcessFunctionTest { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + BedrockBatchJobArnPostProcessFunction function; + + @Before + public void setUp() { + function = new BedrockBatchJobArnPostProcessFunction(); + } + + @Test + public void process_WrongInput_NotMap() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Post process function input is not a Map."); + function.apply("abc"); + } + + @Test + public void process_WrongInput_NotContainJobArn() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("job arn is missing."); + function.apply(Map.of("test", "value")); + } + + @Test + public void process_CorrectInput() { + String jobArn = "arn:aws:bedrock:us-east-1:12345678912:model-invocation-job/w1xtlm0ik3e1"; + List result = function.apply(Map.of(JOB_ARN, jobArn)); + assertEquals(1, result.size()); + assertEquals(jobArn, result.get(0).getDataAsMap().get(JOB_ARN)); + assertEquals("arn:aws:bedrock:us-east-1:12345678912:model-invocation-job%2Fw1xtlm0ik3e1", result.get(0).getDataAsMap().get(PROCESSED_JOB_ARN)); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index 01b4724046..90c1c17a94 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -11,14 +11,25 @@ import static org.opensearch.ml.common.MLTask.REMOTE_JOB_FIELD; import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTaskState.CANCELLED; +import static org.opensearch.ml.common.MLTaskState.CANCELLING; import static org.opensearch.ml.common.MLTaskState.COMPLETED; +import static org.opensearch.ml.common.MLTaskState.EXPIRED; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT_STATUS; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FIELD; import static org.opensearch.ml.utils.MLExceptionUtils.logException; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Consumer; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; @@ -30,6 +41,8 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; @@ -80,6 +93,12 @@ public class GetTaskTransportAction extends HandledTransportAction remoteJobStatusFields; + volatile Pattern remoteJobCompletedStatusRegexPattern; + volatile Pattern remoteJobCancelledStatusRegexPattern; + volatile Pattern remoteJobCancellingStatusRegexPattern; + volatile Pattern remoteJobExpiredStatusRegexPattern; + @Inject public GetTaskTransportAction( TransportService transportService, @@ -91,7 +110,8 @@ public GetTaskTransportAction( ConnectorAccessControlHelper connectorAccessControlHelper, EncryptorImpl encryptor, MLTaskManager mlTaskManager, - MLModelManager mlModelManager + MLModelManager mlModelManager, + Settings settings ) { super(MLTaskGetAction.NAME, transportService, actionFilters, MLTaskGetRequest::new); this.client = client; @@ -102,6 +122,19 @@ public GetTaskTransportAction( this.encryptor = encryptor; this.mlTaskManager = mlTaskManager; this.mlModelManager = mlModelManager; + + remoteJobStatusFields = ML_COMMONS_REMOTE_JOB_STATUS_FIELD.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_REMOTE_JOB_STATUS_FIELD, it -> remoteJobStatusFields = it); + initializeRegexPattern(ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX, settings, clusterService, (regex) -> remoteJobCompletedStatusRegexPattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE)); + initializeRegexPattern(ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX, settings, clusterService, (regex) -> remoteJobCancelledStatusRegexPattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE)); + initializeRegexPattern(ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX, settings, clusterService, (regex) -> remoteJobCancellingStatusRegexPattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE)); + initializeRegexPattern(ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX, settings, clusterService, (regex) -> remoteJobExpiredStatusRegexPattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE)); + } + + private void initializeRegexPattern(Setting setting, Settings settings, ClusterService clusterService, Consumer patternInitializer) { + String regex = setting.get(settings); + patternInitializer.accept(regex); + clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, it -> patternInitializer.accept(it)); } @Override @@ -210,7 +243,7 @@ private void executeConnector( MLInput mlInput, String taskId, MLTask mlTask, - Map transformJob, + Map remoteJob, ActionListener actionListener ) { if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { @@ -222,7 +255,7 @@ private void executeConnector( connectorExecutor.setClient(client); connectorExecutor.setXContentRegistry(xContentRegistry); connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> { - processTaskResponse(mlTask, taskId, taskResponse, transformJob, actionListener); + processTaskResponse(mlTask, taskId, taskResponse, remoteJob, actionListener); }, e -> { actionListener.onFailure(e); })); } else { actionListener @@ -230,7 +263,7 @@ private void executeConnector( } } - private void processTaskResponse( + protected void processTaskResponse( MLTask mlTask, String taskId, MLTaskResponse taskResponse, @@ -248,15 +281,11 @@ private void processTaskResponse( Map updatedTask = new HashMap<>(); updatedTask.put(REMOTE_JOB_FIELD, remoteJob); - if ((remoteJob.containsKey("status") && remoteJob.get("status").equals("completed")) - || (remoteJob.containsKey("TransformJobStatus") && remoteJob.get("TransformJobStatus").equals("Completed"))) { - updatedTask.put(STATE_FIELD, COMPLETED); - mlTask.setState(COMPLETED); - - } else if ((remoteJob.containsKey("status") && remoteJob.get("status").equals("cancelled")) - || (remoteJob.containsKey("TransformJobStatus") && remoteJob.get("TransformJobStatus").equals("Stopped"))) { - updatedTask.put(STATE_FIELD, CANCELLED); - mlTask.setState(CANCELLED); + for (String statusField : remoteJobStatusFields) { + String statusValue = String.valueOf(remoteJob.get(statusField)); + if (remoteJob.containsKey(statusField)) { + updateTaskState(updatedTask, mlTask, statusValue); + } } mlTaskManager.updateMLTaskDirectly(taskId, updatedTask, ActionListener.wrap(response -> { actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build()); @@ -280,4 +309,25 @@ private void processTaskResponse( log.error("Unable to fetch status for ml task ", e); } } + + private void updateTaskState(Map updatedTask, MLTask mlTask, String statusValue) { + if (matchesPattern(remoteJobCancellingStatusRegexPattern, statusValue)) { + updatedTask.put(STATE_FIELD, CANCELLING); + mlTask.setState(CANCELLING); + } else if (matchesPattern(remoteJobCancelledStatusRegexPattern, statusValue)) { + updatedTask.put(STATE_FIELD, CANCELLED); + mlTask.setState(CANCELLED); + } else if (matchesPattern(remoteJobCompletedStatusRegexPattern, statusValue)) { + updatedTask.put(STATE_FIELD, COMPLETED); + mlTask.setState(COMPLETED); + } else if (matchesPattern(remoteJobExpiredStatusRegexPattern, statusValue)) { + updatedTask.put(STATE_FIELD, EXPIRED); + mlTask.setState(EXPIRED); + } + } + + private boolean matchesPattern(Pattern pattern, String input) { + Matcher matcher = pattern.matcher(input); + return matcher.find(); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index ed4d595897..39aaf05ff9 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -964,7 +964,12 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED, MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE, - MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED + MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, + MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FIELD, + MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX, + MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX, + MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX, + MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java index 33c7314be2..49d2247122 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCancelBatchJobAction.java @@ -23,8 +23,9 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +//TODO: Rename class and support cancelling more tasks. Now only support cancelling remote job public class RestMLCancelBatchJobAction extends BaseRestHandler { - private static final String ML_CANCEL_BATCH_ACTION = "ml_cancel_batch_action"; + private static final String ML_CANCEL_TASK_ACTION = "ml_cancel_task_action"; /** * Constructor @@ -33,18 +34,13 @@ public RestMLCancelBatchJobAction() {} @Override public String getName() { - return ML_CANCEL_BATCH_ACTION; + return ML_CANCEL_TASK_ACTION; } @Override public List routes() { return ImmutableList - .of( - new Route( - RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/tasks/{%s}/_cancel_batch", ML_BASE_URI, PARAMETER_TASK_ID) - ) - ); + .of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/tasks/{%s}/_cancel", ML_BASE_URI, PARAMETER_TASK_ID))); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 339116226d..b9d7f2a9fc 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -200,4 +200,47 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED = Setting .boolSetting("plugins.ml_commons.connector.private_ip_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting> ML_COMMONS_REMOTE_JOB_STATUS_FIELD = Setting + .listSetting( + "plugins.ml_commons.remote_job.status_field", + ImmutableList + .of( + "status", // openai, bedrock, cohere + "Status", + "TransformJobStatus" // sagemaker + ), + Function.identity(), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX = Setting + .simpleString( + "plugins.ml_commons.remote_job.status_regex.completed", + "(complete|completed)", + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + public static final Setting ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX = Setting + .simpleString( + "plugins.ml_commons.remote_job.status_regex.cancelled", + "(stopped|cancelled)", + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + public static final Setting ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX = Setting + .simpleString( + "plugins.ml_commons.remote_job.status_regex.cancelling", + "(stopping|cancelling)", + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + public static final Setting ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX = Setting + .simpleString( + "plugins.ml_commons.remote_job.status_regex.expired", + "(expired|timeout)", + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java index 3707c89eae..3fe4ce06f4 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java @@ -15,12 +15,18 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FIELD; import java.io.IOException; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.junit.Before; import org.junit.Ignore; @@ -36,6 +42,7 @@ import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; @@ -49,13 +56,16 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.task.MLTaskGetRequest; import org.opensearch.ml.common.transport.task.MLTaskGetResponse; import org.opensearch.ml.engine.encryptor.EncryptorImpl; @@ -118,7 +128,14 @@ public void setup() throws IOException { MockitoAnnotations.openMocks(this); mlTaskGetRequest = MLTaskGetRequest.builder().taskId("test_id").build(); - Settings settings = Settings.builder().build(); + Settings settings = Settings + .builder() + .putList(ML_COMMONS_REMOTE_JOB_STATUS_FIELD.getKey(), List.of("status", "TransformJobStatus")) + .put(ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX.getKey(), "(complete|completed)") + .put(ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX.getKey(), "(stopped|cancelled)") + .put(ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX.getKey(), "(stopping|cancelling)") + .put(ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX.getKey(), "(expired|timeout)") + .build(); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); @@ -127,6 +144,21 @@ public void setup() throws IOException { doReturn(metaData).when(clusterState).metadata(); doReturn(true).when(metaData).hasIndex(anyString()); + when(clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn( + new ClusterSettings( + settings, + Set + .of( + ML_COMMONS_REMOTE_JOB_STATUS_FIELD, + ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX, + ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX, + ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX, + ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX + ) + ) + ); getTaskTransportAction = spy( new GetTaskTransportAction( @@ -139,7 +171,8 @@ public void setup() throws IOException { connectorAccessControlHelper, encryptor, mlTaskManager, - mlModelManager + mlModelManager, + settings ) ); @@ -331,4 +364,59 @@ public GetResponse prepareMLTask(FunctionName functionName, MLTaskType mlTaskTyp GetResponse getResponse = new GetResponse(getResult); return getResponse; } + + public void test_processTaskResponse_complete() { + processTaskResponse("TransformJobStatus", "complete", MLTaskState.COMPLETED); + } + + public void test_processTaskResponse_cancelling() { + processTaskResponse("status", "cancelling", MLTaskState.CANCELLING); + } + + public void test_processTaskResponse_cancelled() { + processTaskResponse("status", "cancelled", MLTaskState.CANCELLED); + } + + public void test_processTaskResponse_expired() { + processTaskResponse("status", "expired", MLTaskState.EXPIRED); + } + + public void test_processTaskResponse_WrongStatusField() { + processTaskResponse("wrong_status_field", "expired", null); + } + + public void test_processTaskResponse_UnknownStatusField() { + processTaskResponse("status", "unkown_status", null); + } + + private void processTaskResponse(String statusField, String remoteJobResponseStatus, MLTaskState taskState) { + String taskId = "testTaskId"; + String remoteJobName = randomAlphaOfLength(5); + Map remoteJob = new HashMap(); + remoteJob.put(statusField, "running"); + remoteJob.put("name", remoteJobName); + MLTask mlTask = MLTask.builder().taskId(taskId) + .taskType(MLTaskType.BATCH_PREDICTION) + .inputType(MLInputDataType.REMOTE) + .state(MLTaskState.RUNNING) + .remoteJob(remoteJob) + .build(); + ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(Map.of(statusField, remoteJobResponseStatus)).build(); + ModelTensorOutput modelTensorOutput = ModelTensorOutput + .builder() + .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build())) + .build(); + MLTaskResponse taskResponse = MLTaskResponse.builder().output(modelTensorOutput).build(); + ActionListener actionListener = mock(ActionListener.class); + ArgumentCaptor> updatedTaskCaptor = ArgumentCaptor.forClass(Map.class); + + getTaskTransportAction.processTaskResponse(mlTask, taskId, taskResponse, mlTask.getRemoteJob(), actionListener); + + verify(mlTaskManager).updateMLTaskDirectly(any(), updatedTaskCaptor.capture(), any()); + Map updatedTask = updatedTaskCaptor.getValue(); + assertEquals(taskState, updatedTask.get("state")); + Map updatedRemoteJob = (Map)updatedTask.get("remote_job"); + assertEquals(remoteJobResponseStatus, updatedRemoteJob.get(statusField)); + assertEquals(remoteJobName, updatedRemoteJob.get("name")); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCancelBatchJobActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCancelBatchJobActionTests.java index 1498750e6a..bd1d321fef 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCancelBatchJobActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCancelBatchJobActionTests.java @@ -74,7 +74,7 @@ public void testConstructor() { public void testGetName() { String actionName = restMLCancelBatchJobAction.getName(); assertFalse(Strings.isNullOrEmpty(actionName)); - assertEquals("ml_cancel_batch_action", actionName); + assertEquals("ml_cancel_task_action", actionName); } public void testRoutes() { @@ -83,7 +83,7 @@ public void testRoutes() { assertFalse(routes.isEmpty()); RestHandler.Route route = routes.get(0); assertEquals(RestRequest.Method.POST, route.getMethod()); - assertEquals("/_plugins/_ml/tasks/{task_id}/_cancel_batch", route.getPath()); + assertEquals("/_plugins/_ml/tasks/{task_id}/_cancel", route.getPath()); } public void test_PrepareRequest() throws Exception {