diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java index b5929d0f20..2947afc5b9 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java @@ -88,8 +88,7 @@ public List routes() { new Route(GET, BASE_DATASOURCE_ACTION_URL), /* - * GET datasources - * Request URL: GET + * PUT datasources * Request body: * Ref * [org.opensearch.sql.plugin.transport.datasource.model.UpdateDataSourceActionRequest] @@ -100,8 +99,7 @@ public List routes() { new Route(PUT, BASE_DATASOURCE_ACTION_URL), /* - * GET datasources - * Request URL: GET + * DELETE datasources * Request body: Ref * [org.opensearch.sql.plugin.transport.datasource.model.DeleteDataSourceActionRequest] * Response body: Ref diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst index 98990b795b..85e65aa1a3 100644 --- a/docs/user/interfaces/asyncqueryinterface.rst +++ b/docs/user/interfaces/asyncqueryinterface.rst @@ -106,3 +106,19 @@ Sample Response If the Query is successful :: "total": 1, "size": 1 } +Job Cancellation API +====================================== +If security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/jobs/delete``. + + +HTTP URI: _plugins/_query/_jobs/{jobId} +HTTP VERB: DELETE + + +Sample Request BODY:: + + curl --location --request DELETE 'http://localhost:9200/_plugins/_query/_jobs/00fdalrvgkbh2g0q' \ + --header 'Content-Type: application/json' \ + --data '{ + "query" : "select * from my_glue.default.http_logs limit 10" + }' diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java index df13daa2a2..7caa69293a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorService.java @@ -29,4 +29,12 @@ public interface AsyncQueryExecutorService { * @return {@link AsyncQueryExecutionResponse} */ AsyncQueryExecutionResponse getAsyncQueryResults(String queryId); + + /** + * Cancels running async query and returns the cancelled queryId. + * + * @param queryId queryId. + * @return {@link String} cancelledQueryId. + */ + String cancelQuery(String queryId); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index e5ed65920e..efc23e08b5 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -95,6 +95,17 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); } + @Override + public String cancelQuery(String queryId) { + Optional asyncQueryJobMetadata = + asyncQueryJobMetadataStorageService.getJobMetadata(queryId); + if (asyncQueryJobMetadata.isPresent()) { + return sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadata.get().getApplicationId(), queryId); + } + throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); + } + private void validateSparkExecutionEngineSettings() { if (!isSparkJobExecutionEnabled) { throw new IllegalArgumentException( diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java index b554c4cd23..294d173fb0 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrServerlessClientImpl.java @@ -9,12 +9,15 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_APPLICATION_JAR; import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.model.CancelJobRunRequest; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunRequest; import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobDriver; import com.amazonaws.services.emrserverless.model.SparkSubmit; import com.amazonaws.services.emrserverless.model.StartJobRunRequest; import com.amazonaws.services.emrserverless.model.StartJobRunResult; +import com.amazonaws.services.emrserverless.model.ValidationException; import java.security.AccessController; import java.security.PrivilegedAction; import org.apache.logging.log4j.LogManager; @@ -65,4 +68,21 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { logger.info("Job Run state: " + getJobRunResult.getJobRun().getState()); return getJobRunResult; } + + @Override + public CancelJobRunResult closeJobRunResult(String applicationId, String jobId) { + CancelJobRunRequest cancelJobRunRequest = + new CancelJobRunRequest().withJobRunId(jobId).withApplicationId(applicationId); + try { + CancelJobRunResult cancelJobRunResult = + AccessController.doPrivileged( + (PrivilegedAction) + () -> emrServerless.cancelJobRun(cancelJobRunRequest)); + logger.info(String.format("Job : %s cancelled", cancelJobRunResult.getJobRunId())); + return cancelJobRunResult; + } catch (ValidationException e) { + throw new IllegalArgumentException( + String.format("Couldn't cancel the queryId: %s due to %s", jobId, e.getMessage())); + } + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java b/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java index ff9f4acedd..35d7174452 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/SparkJobClient.java @@ -7,6 +7,7 @@ package org.opensearch.sql.spark.client; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; public interface SparkJobClient { @@ -19,4 +20,6 @@ String startJobRun( String sparkSubmitParams); GetJobRunResult getJobRunResult(String applicationId, String jobId); + + CancelJobRunResult closeJobRunResult(String applicationId, String jobId); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index f632ceaf6a..f5b4dc074f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -15,6 +15,7 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY; import static org.opensearch.sql.spark.data.constants.SparkConstants.HIVE_METASTORE_GLUE_ARN_KEY; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRunState; import java.net.URI; @@ -64,6 +65,11 @@ public JSONObject getQueryResponse(String applicationId, String queryId) { return result; } + public String cancelJob(String applicationId, String jobId) { + CancelJobRunResult cancelJobRunResult = sparkJobClient.closeJobRunResult(applicationId, jobId); + return cancelJobRunResult.getJobRunId(); + } + // TODO: Analyze given query // Extract datasourceName // Apply Authorizaiton. diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java index 56484688dc..741501cd18 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java @@ -194,7 +194,7 @@ public void onResponse( CancelAsyncQueryActionResponse cancelAsyncQueryActionResponse) { restChannel.sendResponse( new BytesRestResponse( - RestStatus.OK, + RestStatus.NO_CONTENT, "application/json; charset=UTF-8", cancelAsyncQueryActionResponse.getResult())); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java index 990dbccd0b..232a280db5 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestAction.java @@ -12,6 +12,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -21,13 +22,17 @@ public class TransportCancelAsyncQueryRequestAction extends HandledTransportAction { public static final String NAME = "cluster:admin/opensearch/ql/async_query/delete"; + private final AsyncQueryExecutorServiceImpl asyncQueryExecutorService; public static final ActionType ACTION_TYPE = new ActionType<>(NAME, CancelAsyncQueryActionResponse::new); @Inject public TransportCancelAsyncQueryRequestAction( - TransportService transportService, ActionFilters actionFilters) { + TransportService transportService, + ActionFilters actionFilters, + AsyncQueryExecutorServiceImpl asyncQueryExecutorService) { super(NAME, transportService, actionFilters, CancelAsyncQueryActionRequest::new); + this.asyncQueryExecutorService = asyncQueryExecutorService; } @Override @@ -35,7 +40,13 @@ protected void doExecute( Task task, CancelAsyncQueryActionRequest request, ActionListener listener) { - String responseContent = "deleted_job"; - listener.onResponse(new CancelAsyncQueryActionResponse(responseContent)); + try { + String jobId = asyncQueryExecutorService.cancelQuery(request.getQueryId()); + listener.onResponse( + new CancelAsyncQueryActionResponse( + String.format("Deleted async query with id: %s", jobId))); + } catch (Exception e) { + listener.onFailure(e); + } } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java index e12f184efe..0065b575ed 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/model/CancelAsyncQueryActionRequest.java @@ -9,11 +9,13 @@ import java.io.IOException; import lombok.AllArgsConstructor; +import lombok.Getter; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.StreamInput; @AllArgsConstructor +@Getter public class CancelAsyncQueryActionRequest extends ActionRequest { private String queryId; diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index cf04278892..5e832777fc 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -142,4 +142,34 @@ void testGetAsyncQueryResultsWithDisabledExecutionEngine() { + " to enable Async Query APIs", illegalArgumentException.getMessage()); } + + @Test + void testCancelJobWithJobNotFound() { + AsyncQueryExecutorService asyncQueryExecutorService = + new AsyncQueryExecutorServiceImpl( + asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); + when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) + .thenReturn(Optional.empty()); + AsyncQueryNotFoundException asyncQueryNotFoundException = + Assertions.assertThrows( + AsyncQueryNotFoundException.class, + () -> asyncQueryExecutorService.cancelQuery(EMR_JOB_ID)); + Assertions.assertEquals( + "QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); + verifyNoInteractions(sparkQueryDispatcher); + verifyNoInteractions(settings); + } + + @Test + void testCancelJob() { + AsyncQueryExecutorService asyncQueryExecutorService = + new AsyncQueryExecutorServiceImpl( + asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); + when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID))); + when(sparkQueryDispatcher.cancelJob(EMRS_APPLICATION_ID, EMR_JOB_ID)).thenReturn(EMR_JOB_ID); + String jobId = asyncQueryExecutorService.cancelQuery(EMR_JOB_ID); + Assertions.assertEquals(EMR_JOB_ID, jobId); + verifyNoInteractions(settings); + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java index 36f10cd08b..960f40186d 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrServerlessClientImplTest.java @@ -5,17 +5,22 @@ package org.opensearch.sql.spark.client; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.when; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_JOB_NAME; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; import static org.opensearch.sql.spark.constants.TestConstants.QUERY; import static org.opensearch.sql.spark.constants.TestConstants.SPARK_SUBMIT_PARAMETERS; import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRun; import com.amazonaws.services.emrserverless.model.StartJobRunResult; +import com.amazonaws.services.emrserverless.model.ValidationException; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -45,4 +50,28 @@ void testGetJobRunState() { EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, "123"); } + + @Test + void testCancelJobRun() { + when(emrServerless.cancelJobRun(any())) + .thenReturn(new CancelJobRunResult().withJobRunId(EMR_JOB_ID)); + EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + CancelJobRunResult cancelJobRunResult = + emrServerlessClient.closeJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID); + Assertions.assertEquals(EMR_JOB_ID, cancelJobRunResult.getJobRunId()); + } + + @Test + void testCancelJobRunWithValidationException() { + doThrow(new ValidationException("Error")).when(emrServerless).cancelJobRun(any()); + EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless); + IllegalArgumentException illegalArgumentException = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> emrServerlessClient.closeJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)); + Assertions.assertEquals( + "Couldn't cancel the queryId: job-123xxx due to Error (Service: null; Status Code: 0; Error" + + " Code: null; Request ID: null; Proxy: null)", + illegalArgumentException.getMessage()); + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 800bd59b72..82ad94bab6 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -14,6 +14,7 @@ import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRun; import com.amazonaws.services.emrserverless.model.JobRunState; @@ -79,36 +80,17 @@ void testDispatchWithWrongURI() { illegalArgumentException.getMessage()); } - private DataSourceMetadata constructMyGlueDataSourceMetadata() { - DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); - dataSourceMetadata.setName("my_glue"); - dataSourceMetadata.setConnector(DataSourceType.S3GLUE); - Map properties = new HashMap<>(); - properties.put("glue.auth.type", "iam_role"); - properties.put( - "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); - properties.put( - "glue.indexstore.opensearch.uri", - "https://search-flint-dp-benchmark-cf5crj5mj2kfzvgwdeynkxnefy.eu-west-1.es.amazonaws.com"); - properties.put("glue.indexstore.opensearch.auth", "sigv4"); - properties.put("glue.indexstore.opensearch.region", "eu-west-1"); - dataSourceMetadata.setProperties(properties); - return dataSourceMetadata; - } - - private DataSourceMetadata constructMyGlueDataSourceMetadataWithBadURISyntax() { - DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); - dataSourceMetadata.setName("my_glue"); - dataSourceMetadata.setConnector(DataSourceType.S3GLUE); - Map properties = new HashMap<>(); - properties.put("glue.auth.type", "iam_role"); - properties.put( - "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); - properties.put("glue.indexstore.opensearch.uri", "http://localhost:9090? param"); - properties.put("glue.indexstore.opensearch.auth", "sigv4"); - properties.put("glue.indexstore.opensearch.region", "eu-west-1"); - dataSourceMetadata.setProperties(properties); - return dataSourceMetadata; + @Test + void testCancelJob() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader); + when(sparkJobClient.closeJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn( + new CancelJobRunResult() + .withJobRunId(EMR_JOB_ID) + .withApplicationId(EMRS_APPLICATION_ID)); + String jobId = sparkQueryDispatcher.cancelJob(EMRS_APPLICATION_ID, EMR_JOB_ID); + Assertions.assertEquals(EMR_JOB_ID, jobId); } @Test @@ -140,7 +122,7 @@ void testGetQueryResponseWithSuccess() { Assertions.assertEquals("SUCCESS", result.get("status")); } - String constructExpectedSparkSubmitParameterString() { + private String constructExpectedSparkSubmitParameterString() { return " --class org.opensearch.sql.FlintJob --conf" + " spark.hadoop.fs.s3.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider" + " --conf" @@ -171,4 +153,36 @@ String constructExpectedSparkSubmitParameterString() { + " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegateCatalog "; } + + private DataSourceMetadata constructMyGlueDataSourceMetadata() { + DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); + dataSourceMetadata.setName("my_glue"); + dataSourceMetadata.setConnector(DataSourceType.S3GLUE); + Map properties = new HashMap<>(); + properties.put("glue.auth.type", "iam_role"); + properties.put( + "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); + properties.put( + "glue.indexstore.opensearch.uri", + "https://search-flint-dp-benchmark-cf5crj5mj2kfzvgwdeynkxnefy.eu-west-1.es.amazonaws.com"); + properties.put("glue.indexstore.opensearch.auth", "sigv4"); + properties.put("glue.indexstore.opensearch.region", "eu-west-1"); + dataSourceMetadata.setProperties(properties); + return dataSourceMetadata; + } + + private DataSourceMetadata constructMyGlueDataSourceMetadataWithBadURISyntax() { + DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); + dataSourceMetadata.setName("my_glue"); + dataSourceMetadata.setConnector(DataSourceType.S3GLUE); + Map properties = new HashMap<>(); + properties.put("glue.auth.type", "iam_role"); + properties.put( + "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); + properties.put("glue.indexstore.opensearch.uri", "http://localhost:9090? param"); + properties.put("glue.indexstore.opensearch.auth", "sigv4"); + properties.put("glue.indexstore.opensearch.region", "eu-west-1"); + dataSourceMetadata.setProperties(properties); + return dataSourceMetadata; + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java index c560c882c0..2ff76b9b57 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/TransportCancelAsyncQueryRequestActionTest.java @@ -7,6 +7,10 @@ package org.opensearch.sql.spark.transport; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; + import java.util.HashSet; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; @@ -19,6 +23,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.ActionFilters; import org.opensearch.core.action.ActionListener; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.tasks.Task; @@ -32,24 +37,40 @@ public class TransportCancelAsyncQueryRequestActionTest { @Mock private Task task; @Mock private ActionListener actionListener; + @Mock private AsyncQueryExecutorServiceImpl asyncQueryExecutorService; + @Captor private ArgumentCaptor deleteJobActionResponseArgumentCaptor; + @Captor private ArgumentCaptor exceptionArgumentCaptor; + @BeforeEach public void setUp() { action = new TransportCancelAsyncQueryRequestAction( - transportService, new ActionFilters(new HashSet<>())); + transportService, new ActionFilters(new HashSet<>()), asyncQueryExecutorService); } @Test public void testDoExecute() { - CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest("jobId"); - + CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID); + when(asyncQueryExecutorService.cancelQuery(EMR_JOB_ID)).thenReturn(EMR_JOB_ID); action.doExecute(task, request, actionListener); Mockito.verify(actionListener).onResponse(deleteJobActionResponseArgumentCaptor.capture()); CancelAsyncQueryActionResponse cancelAsyncQueryActionResponse = deleteJobActionResponseArgumentCaptor.getValue(); - Assertions.assertEquals("deleted_job", cancelAsyncQueryActionResponse.getResult()); + Assertions.assertEquals( + "Deleted async query with id: " + EMR_JOB_ID, cancelAsyncQueryActionResponse.getResult()); + } + + @Test + public void testDoExecuteWithException() { + CancelAsyncQueryActionRequest request = new CancelAsyncQueryActionRequest(EMR_JOB_ID); + doThrow(new RuntimeException("Error")).when(asyncQueryExecutorService).cancelQuery(EMR_JOB_ID); + action.doExecute(task, request, actionListener); + Mockito.verify(actionListener).onFailure(exceptionArgumentCaptor.capture()); + Exception exception = exceptionArgumentCaptor.getValue(); + Assertions.assertTrue(exception instanceof RuntimeException); + Assertions.assertEquals("Error", exception.getMessage()); } }