Skip to content

Commit

Permalink
Cancel Job API
Browse files Browse the repository at this point in the history
Signed-off-by: Vamsi Manohar <[email protected]>
  • Loading branch information
vamsimanohar committed Sep 21, 2023
1 parent f96a1a5 commit 0e9a5f7
Show file tree
Hide file tree
Showing 14 changed files with 212 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ public List<Route> routes() {
new Route(GET, BASE_DATASOURCE_ACTION_URL),

/*
* GET datasources
* Request URL: GET
* PUT datasources
* Request body:
* Ref
* [org.opensearch.sql.plugin.transport.datasource.model.UpdateDataSourceActionRequest]
Expand All @@ -100,8 +99,7 @@ public List<Route> routes() {
new Route(PUT, BASE_DATASOURCE_ACTION_URL),

/*
* GET datasources
* Request URL: GET
* DELETE datasources
* Request body: Ref
* [org.opensearch.sql.plugin.transport.datasource.model.DeleteDataSourceActionRequest]
* Response body: Ref
Expand Down
16 changes: 16 additions & 0 deletions docs/user/interfaces/asyncqueryinterface.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,19 @@ Sample Response If the Query is successful ::
"total": 1,
"size": 1
}
Job Cancellation API
======================================
If the security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/async_query/delete``.


HTTP URI: _plugins/_query/_jobs/{jobId}
HTTP VERB: DELETE


Sample Request::

curl --location --request DELETE 'http://localhost:9200/_plugins/_query/_jobs/00fdalrvgkbh2g0q' \
--header 'Content-Type: application/json' \
--data '{
"query" : "select * from my_glue.default.http_logs limit 10"
}'
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,12 @@ public interface AsyncQueryExecutorService {
* @return {@link AsyncQueryExecutionResponse}
*/
AsyncQueryExecutionResponse getAsyncQueryResults(String queryId);

/**
 * Cancels a running async query and returns the id of the cancelled query.
 *
 * <p>NOTE(review): AsyncQueryExecutorServiceImpl throws AsyncQueryNotFoundException when no job
 * metadata is stored for the given id; confirm whether every implementation must do the same.
 *
 * @param queryId id of the async query to cancel.
 * @return {@link String} id of the cancelled query.
 */
String cancelQuery(String queryId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) {
throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId));
}

/**
 * Looks up the stored job metadata for {@code queryId} and asks the dispatcher to cancel the
 * corresponding job.
 *
 * @param queryId id of the async query to cancel.
 * @return the id returned by the dispatcher for the cancelled job.
 * @throws AsyncQueryNotFoundException if no job metadata is stored for {@code queryId}.
 */
@Override
public String cancelQuery(String queryId) {
  Optional<AsyncQueryJobMetadata> storedMetadata =
      asyncQueryJobMetadataStorageService.getJobMetadata(queryId);
  // Guard clause: nothing to cancel when the id was never recorded.
  if (!storedMetadata.isPresent()) {
    throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId));
  }
  String applicationId = storedMetadata.get().getApplicationId();
  return sparkQueryDispatcher.cancelJob(applicationId, queryId);
}

private void validateSparkExecutionEngineSettings() {
if (!isSparkJobExecutionEnabled) {
throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_APPLICATION_JAR;

import com.amazonaws.services.emrserverless.AWSEMRServerless;
import com.amazonaws.services.emrserverless.model.CancelJobRunRequest;
import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
import com.amazonaws.services.emrserverless.model.GetJobRunRequest;
import com.amazonaws.services.emrserverless.model.GetJobRunResult;
import com.amazonaws.services.emrserverless.model.JobDriver;
import com.amazonaws.services.emrserverless.model.SparkSubmit;
import com.amazonaws.services.emrserverless.model.StartJobRunRequest;
import com.amazonaws.services.emrserverless.model.StartJobRunResult;
import com.amazonaws.services.emrserverless.model.ValidationException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import org.apache.logging.log4j.LogManager;
Expand Down Expand Up @@ -65,4 +68,21 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) {
logger.info("Job Run state: " + getJobRunResult.getJobRun().getState());
return getJobRunResult;
}

/**
 * Sends a cancel request for a job run to EMR Serverless.
 *
 * @param applicationId id of the application the job run belongs to.
 * @param jobId id of the job run to cancel.
 * @return the {@link CancelJobRunResult} returned by the EMR Serverless client.
 * @throws IllegalArgumentException if the client rejects the request with a {@link
 *     ValidationException}; the original exception is attached as the cause.
 */
@Override
public CancelJobRunResult closeJobRunResult(String applicationId, String jobId) {
  CancelJobRunRequest cancelJobRunRequest =
      new CancelJobRunRequest().withJobRunId(jobId).withApplicationId(applicationId);
  try {
    // The SDK call is wrapped in doPrivileged; presumably required by the plugin's security
    // policy (TODO confirm).
    CancelJobRunResult cancelJobRunResult =
        AccessController.doPrivileged(
            (PrivilegedAction<CancelJobRunResult>)
                () -> emrServerless.cancelJobRun(cancelJobRunRequest));
    logger.info(String.format("Job : %s cancelled", cancelJobRunResult.getJobRunId()));
    return cancelJobRunResult;
  } catch (ValidationException e) {
    // Message is unchanged for callers; the cause is kept so the SDK stack trace is not lost.
    throw new IllegalArgumentException(
        String.format("Couldn't cancel the queryId: %s due to %s", jobId, e.getMessage()), e);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.opensearch.sql.spark.client;

import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
import com.amazonaws.services.emrserverless.model.GetJobRunResult;

public interface SparkJobClient {
Expand All @@ -19,4 +20,6 @@ String startJobRun(
String sparkSubmitParams);

GetJobRunResult getJobRunResult(String applicationId, String jobId);

/**
 * Requests cancellation of a job run.
 *
 * @param applicationId id of the application the job run belongs to.
 * @param jobId id of the job run to cancel.
 * @return the {@link CancelJobRunResult} describing the cancelled job run.
 */
CancelJobRunResult closeJobRunResult(String applicationId, String jobId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INDEX_STORE_SCHEME_KEY;
import static org.opensearch.sql.spark.data.constants.SparkConstants.HIVE_METASTORE_GLUE_ARN_KEY;

import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
import com.amazonaws.services.emrserverless.model.GetJobRunResult;
import com.amazonaws.services.emrserverless.model.JobRunState;
import java.net.URI;
Expand Down Expand Up @@ -64,6 +65,11 @@ public JSONObject getQueryResponse(String applicationId, String queryId) {
return result;
}

/**
 * Cancels a job run through the {@code sparkJobClient}.
 *
 * @param applicationId id of the application the job run belongs to.
 * @param jobId id of the job run to cancel.
 * @return the job run id reported in the cancellation result.
 */
public String cancelJob(String applicationId, String jobId) {
  return sparkJobClient.closeJobRunResult(applicationId, jobId).getJobRunId();
}

// TODO: Analyze given query
// Extract datasourceName
// Apply Authorization.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public void onResponse(
CancelAsyncQueryActionResponse cancelAsyncQueryActionResponse) {
restChannel.sendResponse(
new BytesRestResponse(
RestStatus.OK,
RestStatus.NO_CONTENT,
"application/json; charset=UTF-8",
cancelAsyncQueryActionResponse.getResult()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.common.inject.Inject;
import org.opensearch.core.action.ActionListener;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionRequest;
import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse;
import org.opensearch.tasks.Task;
Expand All @@ -21,21 +22,31 @@ public class TransportCancelAsyncQueryRequestAction
extends HandledTransportAction<CancelAsyncQueryActionRequest, CancelAsyncQueryActionResponse> {

public static final String NAME = "cluster:admin/opensearch/ql/async_query/delete";
private final AsyncQueryExecutorServiceImpl asyncQueryExecutorService;
public static final ActionType<CancelAsyncQueryActionResponse> ACTION_TYPE =
new ActionType<>(NAME, CancelAsyncQueryActionResponse::new);

/**
 * Creates the transport action, registering it under {@link #NAME}.
 *
 * @param transportService transport service passed to the parent action.
 * @param actionFilters action filters passed to the parent action.
 * @param asyncQueryExecutorService service used by {@code doExecute} to cancel the query.
 */
@Inject
public TransportCancelAsyncQueryRequestAction(
    TransportService transportService,
    ActionFilters actionFilters,
    AsyncQueryExecutorServiceImpl asyncQueryExecutorService) {
  super(NAME, transportService, actionFilters, CancelAsyncQueryActionRequest::new);
  this.asyncQueryExecutorService = asyncQueryExecutorService;
}

/**
 * Cancels the async query named in the request and reports the outcome to the listener.
 *
 * <p>The listener is completed exactly once per path: with a confirmation message when the
 * cancel call returns, or with the thrown exception when it fails.
 *
 * @param task the task this action runs under (not used here).
 * @param request request carrying the id of the query to cancel.
 * @param listener listener notified with the response or the failure.
 */
@Override
protected void doExecute(
    Task task,
    CancelAsyncQueryActionRequest request,
    ActionListener<CancelAsyncQueryActionResponse> listener) {
  try {
    String jobId = asyncQueryExecutorService.cancelQuery(request.getQueryId());
    listener.onResponse(
        new CancelAsyncQueryActionResponse(
            String.format("Deleted async query with id: %s", jobId)));
  } catch (Exception e) {
    // Transport-action boundary: hand every failure to the listener instead of throwing.
    listener.onFailure(e);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

import java.io.IOException;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.StreamInput;

@AllArgsConstructor
@Getter
public class CancelAsyncQueryActionRequest extends ActionRequest {

private String queryId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,34 @@ void testGetAsyncQueryResultsWithDisabledExecutionEngine() {
+ " to enable Async Query APIs",
illegalArgumentException.getMessage());
}

/** Cancelling an id with no stored metadata fails and never reaches the dispatcher. */
@Test
void testCancelJobWithJobNotFound() {
  AsyncQueryExecutorService service =
      new AsyncQueryExecutorServiceImpl(
          asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
  when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
      .thenReturn(Optional.empty());

  AsyncQueryNotFoundException thrown =
      Assertions.assertThrows(
          AsyncQueryNotFoundException.class, () -> service.cancelQuery(EMR_JOB_ID));

  String expectedMessage = "QueryId: " + EMR_JOB_ID + " not found";
  Assertions.assertEquals(expectedMessage, thrown.getMessage());
  verifyNoInteractions(sparkQueryDispatcher);
  verifyNoInteractions(settings);
}

/** Cancelling a known id delegates to the dispatcher and returns the id it reports. */
@Test
void testCancelJob() {
  AsyncQueryExecutorService service =
      new AsyncQueryExecutorServiceImpl(
          asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
  AsyncQueryJobMetadata storedMetadata =
      new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID);
  when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
      .thenReturn(Optional.of(storedMetadata));
  when(sparkQueryDispatcher.cancelJob(EMRS_APPLICATION_ID, EMR_JOB_ID)).thenReturn(EMR_JOB_ID);

  String cancelledId = service.cancelQuery(EMR_JOB_ID);

  Assertions.assertEquals(EMR_JOB_ID, cancelledId);
  verifyNoInteractions(settings);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,22 @@
package org.opensearch.sql.spark.client;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.when;
import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID;
import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE;
import static org.opensearch.sql.spark.constants.TestConstants.EMRS_JOB_NAME;
import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID;
import static org.opensearch.sql.spark.constants.TestConstants.QUERY;
import static org.opensearch.sql.spark.constants.TestConstants.SPARK_SUBMIT_PARAMETERS;

import com.amazonaws.services.emrserverless.AWSEMRServerless;
import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
import com.amazonaws.services.emrserverless.model.GetJobRunResult;
import com.amazonaws.services.emrserverless.model.JobRun;
import com.amazonaws.services.emrserverless.model.StartJobRunResult;
import com.amazonaws.services.emrserverless.model.ValidationException;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
Expand Down Expand Up @@ -45,4 +50,28 @@ void testGetJobRunState() {
EmrServerlessClientImpl emrServerlessClient = new EmrServerlessClientImpl(emrServerless);
emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, "123");
}

/** A successful cancel call returns the result produced by the EMR Serverless client. */
@Test
void testCancelJobRun() {
  CancelJobRunResult stubbedResult = new CancelJobRunResult().withJobRunId(EMR_JOB_ID);
  when(emrServerless.cancelJobRun(any())).thenReturn(stubbedResult);
  EmrServerlessClientImpl client = new EmrServerlessClientImpl(emrServerless);

  CancelJobRunResult actual = client.closeJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID);

  Assertions.assertEquals(EMR_JOB_ID, actual.getJobRunId());
}

/** A ValidationException from the client surfaces as an IllegalArgumentException. */
@Test
void testCancelJobRunWithValidationException() {
  doThrow(new ValidationException("Error")).when(emrServerless).cancelJobRun(any());
  EmrServerlessClientImpl client = new EmrServerlessClientImpl(emrServerless);

  IllegalArgumentException thrown =
      Assertions.assertThrows(
          IllegalArgumentException.class,
          () -> client.closeJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID));

  String expectedMessage =
      "Couldn't cancel the queryId: job-123xxx due to Error (Service: null; Status Code: 0; Error"
          + " Code: null; Request ID: null; Proxy: null)";
  Assertions.assertEquals(expectedMessage, thrown.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID;
import static org.opensearch.sql.spark.constants.TestConstants.QUERY;

import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
import com.amazonaws.services.emrserverless.model.GetJobRunResult;
import com.amazonaws.services.emrserverless.model.JobRun;
import com.amazonaws.services.emrserverless.model.JobRunState;
Expand Down Expand Up @@ -79,36 +80,17 @@ void testDispatchWithWrongURI() {
illegalArgumentException.getMessage());
}

/** Builds the "my_glue" S3GLUE data source metadata fixture with a valid index store URI. */
private DataSourceMetadata constructMyGlueDataSourceMetadata() {
  Map<String, String> glueProperties = new HashMap<>();
  glueProperties.put("glue.auth.type", "iam_role");
  glueProperties.put(
      "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole");
  glueProperties.put(
      "glue.indexstore.opensearch.uri",
      "https://search-flint-dp-benchmark-cf5crj5mj2kfzvgwdeynkxnefy.eu-west-1.es.amazonaws.com");
  glueProperties.put("glue.indexstore.opensearch.auth", "sigv4");
  glueProperties.put("glue.indexstore.opensearch.region", "eu-west-1");

  DataSourceMetadata metadata = new DataSourceMetadata();
  metadata.setName("my_glue");
  metadata.setConnector(DataSourceType.S3GLUE);
  metadata.setProperties(glueProperties);
  return metadata;
}

private DataSourceMetadata constructMyGlueDataSourceMetadataWithBadURISyntax() {
DataSourceMetadata dataSourceMetadata = new DataSourceMetadata();
dataSourceMetadata.setName("my_glue");
dataSourceMetadata.setConnector(DataSourceType.S3GLUE);
Map<String, String> properties = new HashMap<>();
properties.put("glue.auth.type", "iam_role");
properties.put(
"glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole");
properties.put("glue.indexstore.opensearch.uri", "http://localhost:9090? param");
properties.put("glue.indexstore.opensearch.auth", "sigv4");
properties.put("glue.indexstore.opensearch.region", "eu-west-1");
dataSourceMetadata.setProperties(properties);
return dataSourceMetadata;
/** The dispatcher returns the job run id carried by the client's cancellation result. */
@Test
void testCancelJob() {
  SparkQueryDispatcher dispatcher =
      new SparkQueryDispatcher(sparkJobClient, dataSourceService, jobExecutionResponseReader);
  CancelJobRunResult stubbedResult =
      new CancelJobRunResult().withJobRunId(EMR_JOB_ID).withApplicationId(EMRS_APPLICATION_ID);
  when(sparkJobClient.closeJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID))
      .thenReturn(stubbedResult);

  String cancelledId = dispatcher.cancelJob(EMRS_APPLICATION_ID, EMR_JOB_ID);

  Assertions.assertEquals(EMR_JOB_ID, cancelledId);
}

@Test
Expand Down Expand Up @@ -140,7 +122,7 @@ void testGetQueryResponseWithSuccess() {
Assertions.assertEquals("SUCCESS", result.get("status"));
}

String constructExpectedSparkSubmitParameterString() {
private String constructExpectedSparkSubmitParameterString() {
return " --class org.opensearch.sql.FlintJob --conf"
+ " spark.hadoop.fs.s3.customAWSCredentialsProvider=com.amazonaws.emr.AssumeRoleAWSCredentialsProvider"
+ " --conf"
Expand Down Expand Up @@ -171,4 +153,36 @@ String constructExpectedSparkSubmitParameterString() {
+ " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"
+ " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegateCatalog ";
}

/** Builds the "my_glue" S3GLUE data source metadata fixture with a valid index store URI. */
private DataSourceMetadata constructMyGlueDataSourceMetadata() {
  Map<String, String> glueProperties = new HashMap<>();
  glueProperties.put("glue.auth.type", "iam_role");
  glueProperties.put(
      "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole");
  glueProperties.put(
      "glue.indexstore.opensearch.uri",
      "https://search-flint-dp-benchmark-cf5crj5mj2kfzvgwdeynkxnefy.eu-west-1.es.amazonaws.com");
  glueProperties.put("glue.indexstore.opensearch.auth", "sigv4");
  glueProperties.put("glue.indexstore.opensearch.region", "eu-west-1");

  DataSourceMetadata metadata = new DataSourceMetadata();
  metadata.setName("my_glue");
  metadata.setConnector(DataSourceType.S3GLUE);
  metadata.setProperties(glueProperties);
  return metadata;
}

/**
 * Builds the "my_glue" S3GLUE data source metadata fixture whose index store URI contains a
 * space, for tests that exercise the bad-URI path.
 */
private DataSourceMetadata constructMyGlueDataSourceMetadataWithBadURISyntax() {
  Map<String, String> glueProperties = new HashMap<>();
  glueProperties.put("glue.auth.type", "iam_role");
  glueProperties.put(
      "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole");
  glueProperties.put("glue.indexstore.opensearch.uri", "http://localhost:9090? param");
  glueProperties.put("glue.indexstore.opensearch.auth", "sigv4");
  glueProperties.put("glue.indexstore.opensearch.region", "eu-west-1");

  DataSourceMetadata metadata = new DataSourceMetadata();
  metadata.setName("my_glue");
  metadata.setConnector(DataSourceType.S3GLUE);
  metadata.setProperties(glueProperties);
  return metadata;
}
}
Loading

0 comments on commit 0e9a5f7

Please sign in to comment.