diff --git a/common/build.gradle b/common/build.gradle index 507ad6c0d6..0561468d1f 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -63,4 +63,4 @@ configurations.all { resolutionStrategy.force "org.apache.httpcomponents:httpcore:4.4.13" resolutionStrategy.force "joda-time:joda-time:2.10.12" resolutionStrategy.force "org.slf4j:slf4j-api:1.7.36" -} +} \ No newline at end of file diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index 3d9740d84c..f714a8366b 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -306,9 +306,8 @@ private DataSourceServiceImpl createDataSourceService() { private AsyncQueryExecutorService createAsyncQueryExecutorService( SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier, SparkExecutionEngineConfig sparkExecutionEngineConfig) { - StateStore stateStore = new StateStore(client, clusterService); AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = - new OpensearchAsyncQueryJobMetadataStorageService(stateStore); + new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService); EMRServerlessClient emrServerlessClient = createEMRServerlessClient(sparkExecutionEngineConfig.getRegion()); JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); @@ -320,7 +319,8 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService( jobExecutionResponseReader, new FlintIndexMetadataReaderImpl(client), client, - new SessionManager(stateStore, emrServerlessClient, pluginSettings)); + new SessionManager( + new StateStore(client, clusterService), emrServerlessClient, pluginSettings)); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 18ae47c2b9..7cba2757cc 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -69,14 +69,13 @@ public CreateAsyncQueryResponse createAsyncQuery( createAsyncQueryRequest.getSessionId())); asyncQueryJobMetadataStorageService.storeJobMetadata( new AsyncQueryJobMetadata( - dispatchQueryResponse.getQueryId(), sparkExecutionEngineConfig.getApplicationId(), dispatchQueryResponse.getJobId(), dispatchQueryResponse.isDropIndexQuery(), dispatchQueryResponse.getResultIndex(), dispatchQueryResponse.getSessionId())); return new CreateAsyncQueryResponse( - dispatchQueryResponse.getQueryId().getId(), dispatchQueryResponse.getSessionId()); + dispatchQueryResponse.getJobId(), dispatchQueryResponse.getSessionId()); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java index 6de8c35f03..a95a6ffe45 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java @@ -7,31 +7,166 @@ package org.opensearch.sql.spark.asyncquery; -import static 
org.opensearch.sql.spark.execution.statestore.StateStore.createJobMetaData; - +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; -import lombok.RequiredArgsConstructor; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; +import org.apache.commons.io.IOUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.execution.statestore.StateStore; /** Opensearch implementation of {@link AsyncQueryJobMetadataStorageService} */ -@RequiredArgsConstructor public class OpensearchAsyncQueryJobMetadataStorageService implements AsyncQueryJobMetadataStorageService { - private final StateStore stateStore; + public static final String JOB_METADATA_INDEX = ".ql-job-metadata"; + private static final String JOB_METADATA_INDEX_MAPPING_FILE_NAME = + "job-metadata-index-mapping.yml"; + private static final String JOB_METADATA_INDEX_SETTINGS_FILE_NAME = + "job-metadata-index-settings.yml"; + private static final Logger LOG = LogManager.getLogger(); + private final Client client; + private final ClusterService clusterService; + + /** + * This class implements JobMetadataStorageService interface using OpenSearch as underlying + * storage. + * + * @param client opensearch NodeClient. + * @param clusterService ClusterService. 
+ */ + public OpensearchAsyncQueryJobMetadataStorageService( + Client client, ClusterService clusterService) { + this.client = client; + this.clusterService = clusterService; + } @Override public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) { - AsyncQueryId queryId = asyncQueryJobMetadata.getQueryId(); - createJobMetaData(stateStore, queryId.getDataSourceName()).apply(asyncQueryJobMetadata); + if (!this.clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) { + createJobMetadataIndex(); + } + IndexRequest indexRequest = new IndexRequest(JOB_METADATA_INDEX); + indexRequest.id(asyncQueryJobMetadata.getJobId()); + indexRequest.opType(DocWriteRequest.OpType.CREATE); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + ActionFuture indexResponseActionFuture; + IndexResponse indexResponse; + try (ThreadContext.StoredContext storedContext = + client.threadPool().getThreadContext().stashContext()) { + indexRequest.source(AsyncQueryJobMetadata.convertToXContent(asyncQueryJobMetadata)); + indexResponseActionFuture = client.index(indexRequest); + indexResponse = indexResponseActionFuture.actionGet(); + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("JobMetadata : {} successfully created", asyncQueryJobMetadata.getJobId()); + } else { + throw new RuntimeException( + "Saving job metadata information failed with result : " + + indexResponse.getResult().getLowercase()); + } } @Override - public Optional getJobMetadata(String qid) { - AsyncQueryId queryId = new AsyncQueryId(qid); - return StateStore.getJobMetaData(stateStore, queryId.getDataSourceName()) - .apply(queryId.docId()); + public Optional getJobMetadata(String jobId) { + if (!this.clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) { + createJobMetadataIndex(); + return Optional.empty(); + } + return searchInJobMetadataIndex(QueryBuilders.termQuery("jobId.keyword", jobId)).stream() + .findFirst(); + } + + private void createJobMetadataIndex() { + try { + InputStream mappingFileStream = + OpensearchAsyncQueryJobMetadataStorageService.class + .getClassLoader() + .getResourceAsStream(JOB_METADATA_INDEX_MAPPING_FILE_NAME); + InputStream settingsFileStream = + OpensearchAsyncQueryJobMetadataStorageService.class + .getClassLoader() + .getResourceAsStream(JOB_METADATA_INDEX_SETTINGS_FILE_NAME); + CreateIndexRequest createIndexRequest = new CreateIndexRequest(JOB_METADATA_INDEX); + createIndexRequest + .mapping(IOUtils.toString(mappingFileStream, StandardCharsets.UTF_8), XContentType.YAML) + .settings( + IOUtils.toString(settingsFileStream, StandardCharsets.UTF_8), XContentType.YAML); + ActionFuture createIndexResponseActionFuture; + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + createIndexResponseActionFuture = client.admin().indices().create(createIndexRequest); + } + CreateIndexResponse createIndexResponse = createIndexResponseActionFuture.actionGet(); + if (createIndexResponse.isAcknowledged()) { + LOG.info("Index: {} creation Acknowledged", JOB_METADATA_INDEX); + } else { + throw new RuntimeException("Index creation is not acknowledged."); + } + } catch (Throwable e) { + throw new RuntimeException( + "Internal server error while creating" + + JOB_METADATA_INDEX + + " index:: " + + e.getMessage()); + } + } + + private List searchInJobMetadataIndex(QueryBuilder query) { + SearchRequest searchRequest = new SearchRequest(); 
+ searchRequest.indices(JOB_METADATA_INDEX); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(query); + searchSourceBuilder.size(1); + searchRequest.source(searchSourceBuilder); + // https://github.com/opensearch-project/sql/issues/1801. + searchRequest.preference("_primary_first"); + ActionFuture searchResponseActionFuture; + try (ThreadContext.StoredContext ignored = + client.threadPool().getThreadContext().stashContext()) { + searchResponseActionFuture = client.search(searchRequest); + } + SearchResponse searchResponse = searchResponseActionFuture.actionGet(); + if (searchResponse.status().getStatus() != 200) { + throw new RuntimeException( + "Fetching job metadata information failed with status : " + searchResponse.status()); + } else { + List list = new ArrayList<>(); + for (SearchHit searchHit : searchResponse.getHits().getHits()) { + String sourceAsString = searchHit.getSourceAsString(); + AsyncQueryJobMetadata asyncQueryJobMetadata; + try { + asyncQueryJobMetadata = AsyncQueryJobMetadata.toJobMetadata(sourceAsString); + } catch (IOException e) { + throw new RuntimeException(e); + } + list.add(asyncQueryJobMetadata); + } + return list; + } } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java deleted file mode 100644 index b99ebe0e8c..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.asyncquery.model; - -import static org.opensearch.sql.spark.utils.IDUtils.decode; -import static org.opensearch.sql.spark.utils.IDUtils.encode; - -import lombok.Data; - -/** Async query id. */ -@Data -public class AsyncQueryId { - private final String id; - - public static AsyncQueryId newAsyncQueryId(String datasourceName) { - return new AsyncQueryId(encode(datasourceName)); - } - - public String getDataSourceName() { - return decode(id); - } - - /** OpenSearch DocId. 
*/ - public String docId() { - return "qid" + id; - } - - @Override - public String toString() { - return "asyncQueryId=" + id; - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java index 3c59403661..b80fefa173 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java @@ -8,83 +8,37 @@ package org.opensearch.sql.spark.asyncquery.model; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.sql.spark.execution.statement.StatementModel.QUERY_ID; import com.google.gson.Gson; import java.io.IOException; +import lombok.AllArgsConstructor; import lombok.Data; import lombok.EqualsAndHashCode; -import lombok.SneakyThrows; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.seqno.SequenceNumbers; -import org.opensearch.sql.spark.execution.statestore.StateModel; /** This class models all the metadata required for a job. */ @Data -@EqualsAndHashCode(callSuper = false) -public class AsyncQueryJobMetadata extends StateModel { - public static final String TYPE_JOBMETA = "jobmeta"; - - private final AsyncQueryId queryId; - private final String applicationId; - private final String jobId; - private final boolean isDropIndexQuery; - private final String resultIndex; +@AllArgsConstructor +@EqualsAndHashCode +public class AsyncQueryJobMetadata { + private String applicationId; + private String jobId; + private boolean isDropIndexQuery; + private String resultIndex; // optional sessionId. - private final String sessionId; - - @EqualsAndHashCode.Exclude private final long seqNo; - @EqualsAndHashCode.Exclude private final long primaryTerm; + private String sessionId; - public AsyncQueryJobMetadata( - AsyncQueryId queryId, String applicationId, String jobId, String resultIndex) { - this( - queryId, - applicationId, - jobId, - false, - resultIndex, - null, - SequenceNumbers.UNASSIGNED_SEQ_NO, - SequenceNumbers.UNASSIGNED_PRIMARY_TERM); - } - - public AsyncQueryJobMetadata( - AsyncQueryId queryId, - String applicationId, - String jobId, - boolean isDropIndexQuery, - String resultIndex, - String sessionId) { - this( - queryId, - applicationId, - jobId, - isDropIndexQuery, - resultIndex, - sessionId, - SequenceNumbers.UNASSIGNED_SEQ_NO, - SequenceNumbers.UNASSIGNED_PRIMARY_TERM); - } - - public AsyncQueryJobMetadata( - AsyncQueryId queryId, - String applicationId, - String jobId, - boolean isDropIndexQuery, - String resultIndex, - String sessionId, - long seqNo, - long primaryTerm) { - this.queryId = queryId; + public AsyncQueryJobMetadata(String applicationId, String jobId, String resultIndex) { this.applicationId = applicationId; this.jobId = jobId; - this.isDropIndexQuery = isDropIndexQuery; + this.isDropIndexQuery = false; this.resultIndex = resultIndex; - this.sessionId = sessionId; - this.seqNo = seqNo; - this.primaryTerm = primaryTerm; + this.sessionId = null; } @Override @@ -95,36 +49,39 @@ public String toString() { /** * Converts JobMetadata to XContentBuilder. 
* + * @param metadata metadata. * @return XContentBuilder {@link XContentBuilder} * @throws Exception Exception. */ - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder - .startObject() - .field(QUERY_ID, queryId.getId()) - .field("type", TYPE_JOBMETA) - .field("jobId", jobId) - .field("applicationId", applicationId) - .field("isDropIndexQuery", isDropIndexQuery) - .field("resultIndex", resultIndex) - .field("sessionId", sessionId) - .endObject(); + public static XContentBuilder convertToXContent(AsyncQueryJobMetadata metadata) throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.field("jobId", metadata.getJobId()); + builder.field("applicationId", metadata.getApplicationId()); + builder.field("isDropIndexQuery", metadata.isDropIndexQuery()); + builder.field("resultIndex", metadata.getResultIndex()); + builder.field("sessionId", metadata.getSessionId()); + builder.endObject(); return builder; } - /** copy builder. update seqNo and primaryTerm */ - public static AsyncQueryJobMetadata copy( - AsyncQueryJobMetadata copy, long seqNo, long primaryTerm) { - return new AsyncQueryJobMetadata( - copy.getQueryId(), - copy.getApplicationId(), - copy.getJobId(), - copy.isDropIndexQuery(), - copy.getResultIndex(), - copy.getSessionId(), - seqNo, - primaryTerm); + /** + * Converts json string to DataSourceMetadata. + * + * @param json jsonstring. + * @return jobmetadata {@link AsyncQueryJobMetadata} + * @throws java.io.IOException IOException. + */ + public static AsyncQueryJobMetadata toJobMetadata(String json) throws IOException { + try (XContentParser parser = + XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + json)) { + return toJobMetadata(parser); + } } /** @@ -134,23 +91,17 @@ public static AsyncQueryJobMetadata copy( * @return JobMetadata {@link AsyncQueryJobMetadata} * @throws IOException IOException. 
*/ - @SneakyThrows - public static AsyncQueryJobMetadata fromXContent( - XContentParser parser, long seqNo, long primaryTerm) { - AsyncQueryId queryId = null; + public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws IOException { String jobId = null; String applicationId = null; boolean isDropIndexQuery = false; String resultIndex = null; String sessionId = null; - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (!XContentParser.Token.END_OBJECT.equals(parser.nextToken())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); parser.nextToken(); switch (fieldName) { - case QUERY_ID: - queryId = new AsyncQueryId(parser.textOrNull()); - break; case "jobId": jobId = parser.textOrNull(); break; @@ -166,8 +117,6 @@ public static AsyncQueryJobMetadata fromXContent( case "sessionId": sessionId = parser.textOrNull(); break; - case "type": - break; default: throw new IllegalArgumentException("Unknown field: " + fieldName); } @@ -176,18 +125,6 @@ public static AsyncQueryJobMetadata fromXContent( throw new IllegalArgumentException("jobId and applicationId are required fields."); } return new AsyncQueryJobMetadata( - queryId, - applicationId, - jobId, - isDropIndexQuery, - resultIndex, - sessionId, - seqNo, - primaryTerm); - } - - @Override - public String getId() { - return queryId.docId(); + applicationId, jobId, isDropIndexQuery, resultIndex, sessionId); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java deleted file mode 100644 index 77a0e1cd09..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.dispatcher; - -import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; -import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; -import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; - -import com.amazonaws.services.emrserverless.model.JobRunState; -import org.json.JSONObject; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; - -/** Process async query request. 
*/ -public abstract class AsyncQueryHandler { - - public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) { - if (asyncQueryJobMetadata.isDropIndexQuery()) { - return SparkQueryDispatcher.DropIndexResult.fromJobId(asyncQueryJobMetadata.getJobId()) - .result(); - } - - JSONObject result = getResponseFromResultIndex(asyncQueryJobMetadata); - if (result.has(DATA_FIELD)) { - JSONObject items = result.getJSONObject(DATA_FIELD); - - // If items have STATUS_FIELD, use it; otherwise, mark failed - String status = items.optString(STATUS_FIELD, JobRunState.FAILED.toString()); - result.put(STATUS_FIELD, status); - - // If items have ERROR_FIELD, use it; otherwise, set empty string - String error = items.optString(ERROR_FIELD, ""); - result.put(ERROR_FIELD, error); - return result; - } else { - return getResponseFromExecutor(asyncQueryJobMetadata); - } - } - - protected abstract JSONObject getResponseFromResultIndex( - AsyncQueryJobMetadata asyncQueryJobMetadata); - - protected abstract JSONObject getResponseFromExecutor( - AsyncQueryJobMetadata asyncQueryJobMetadata); - - abstract String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata); -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java deleted file mode 100644 index 8a582278e1..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.dispatcher; - -import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; -import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; - -import com.amazonaws.services.emrserverless.model.GetJobRunResult; -import lombok.RequiredArgsConstructor; -import org.json.JSONObject; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.response.JobExecutionResponseReader; - -@RequiredArgsConstructor -public class BatchQueryHandler extends AsyncQueryHandler { - private final EMRServerlessClient emrServerlessClient; - private final JobExecutionResponseReader jobExecutionResponseReader; - - @Override - protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { - // either empty json when the result is not available or data with status - // Fetch from Result Index - return jobExecutionResponseReader.getResultFromOpensearchIndex( - asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); - } - - @Override - protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) { - JSONObject result = new JSONObject(); - // make call to EMR Serverless when related result index documents are not available - GetJobRunResult getJobRunResult = - emrServerlessClient.getJobRunResult( - asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); - String jobState = getJobRunResult.getJobRun().getState(); - result.put(STATUS_FIELD, jobState); - result.put(ERROR_FIELD, ""); - return result; - } - - @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { - emrServerlessClient.cancelJobRun( - asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); - return asyncQueryJobMetadata.getQueryId().getId(); - } -} diff 
--git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java deleted file mode 100644 index 24ea1528c8..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.dispatcher; - -import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; -import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; - -import java.util.Optional; -import lombok.RequiredArgsConstructor; -import org.json.JSONObject; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.execution.session.Session; -import org.opensearch.sql.spark.execution.session.SessionId; -import org.opensearch.sql.spark.execution.session.SessionManager; -import org.opensearch.sql.spark.execution.statement.Statement; -import org.opensearch.sql.spark.execution.statement.StatementId; -import org.opensearch.sql.spark.execution.statement.StatementState; -import org.opensearch.sql.spark.response.JobExecutionResponseReader; - -@RequiredArgsConstructor -public class InteractiveQueryHandler extends AsyncQueryHandler { - private final SessionManager sessionManager; - private final JobExecutionResponseReader jobExecutionResponseReader; - - @Override - protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { - String queryId = asyncQueryJobMetadata.getQueryId().getId(); - return jobExecutionResponseReader.getResultWithQueryId( - queryId, asyncQueryJobMetadata.getResultIndex()); - } - - @Override - protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) { - JSONObject result = new JSONObject(); - String queryId = asyncQueryJobMetadata.getQueryId().getId(); - Statement statement = getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId); - StatementState statementState = statement.getStatementState(); - result.put(STATUS_FIELD, statementState.getState()); - result.put(ERROR_FIELD, ""); - return result; - } - - @Override - public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { - String queryId = asyncQueryJobMetadata.getQueryId().getId(); - getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId).cancel(); - return queryId; - } - - private Statement getStatementByQueryId(String sid, String qid) { - SessionId sessionId = new SessionId(sid); - Optional session = sessionManager.getSession(sessionId); - if (session.isPresent()) { - // todo, statementId == jobId if statement running in session. - StatementId statementId = new StatementId(qid); - Optional statement = session.get().get(statementId); - if (statement.isPresent()) { - return statement.get(); - } else { - throw new IllegalArgumentException("no statement found. " + statementId); - } - } else { - throw new IllegalArgumentException("no session found. 
" + sessionId); - } - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 882f2663d9..2bd1ae67b9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -10,6 +10,8 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; +import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRunState; import java.nio.charset.StandardCharsets; import java.util.Base64; @@ -31,7 +33,6 @@ import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; @@ -45,6 +46,9 @@ import org.opensearch.sql.spark.execution.session.SessionId; import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.statement.QueryRequest; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; +import org.opensearch.sql.spark.execution.statement.StatementState; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataReader; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @@ -88,22 +92,97 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) } public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) { - if (asyncQueryJobMetadata.getSessionId() != null) { - return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader) - .getQueryResponse(asyncQueryJobMetadata); + // todo. refactor query process logic in plugin. + if (asyncQueryJobMetadata.isDropIndexQuery()) { + return DropIndexResult.fromJobId(asyncQueryJobMetadata.getJobId()).result(); + } + + JSONObject result; + if (asyncQueryJobMetadata.getSessionId() == null) { + // either empty json when the result is not available or data with status + // Fetch from Result Index + result = + jobExecutionResponseReader.getResultFromOpensearchIndex( + asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); } else { - return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader) - .getQueryResponse(asyncQueryJobMetadata); + // when session enabled, jobId in asyncQueryJobMetadata is actually queryId. + result = + jobExecutionResponseReader.getResultWithQueryId( + asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()); } + // if result index document has a status, we are gonna use the status directly; otherwise, we + // will use emr-s job status. + // That a job is successful does not mean there is no error in execution. For example, even if + // result + // index mapping is incorrect, we still write query result and let the job finish. 
+ // That a job is running does not mean the status is running. For example, index/streaming Query + // is a + // long-running job which runs forever. But we need to return success from the result index + // immediately. + if (result.has(DATA_FIELD)) { + JSONObject items = result.getJSONObject(DATA_FIELD); + + // If items have STATUS_FIELD, use it; otherwise, mark failed + String status = items.optString(STATUS_FIELD, JobRunState.FAILED.toString()); + result.put(STATUS_FIELD, status); + + // If items have ERROR_FIELD, use it; otherwise, set empty string + String error = items.optString(ERROR_FIELD, ""); + result.put(ERROR_FIELD, error); + } else { + if (asyncQueryJobMetadata.getSessionId() != null) { + SessionId sessionId = new SessionId(asyncQueryJobMetadata.getSessionId()); + Optional session = sessionManager.getSession(sessionId); + if (session.isPresent()) { + // todo, statementId == jobId if statement running in session. + StatementId statementId = new StatementId(asyncQueryJobMetadata.getJobId()); + Optional statement = session.get().get(statementId); + if (statement.isPresent()) { + StatementState statementState = statement.get().getStatementState(); + result.put(STATUS_FIELD, statementState.getState()); + result.put(ERROR_FIELD, ""); + } else { + throw new IllegalArgumentException("no statement found. " + statementId); + } + } else { + throw new IllegalArgumentException("no session found. " + sessionId); + } + } else { + // make call to EMR Serverless when related result index documents are not available + GetJobRunResult getJobRunResult = + emrServerlessClient.getJobRunResult( + asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); + String jobState = getJobRunResult.getJobRun().getState(); + result.put(STATUS_FIELD, jobState); + result.put(ERROR_FIELD, ""); + } + } + + return result; } public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { if (asyncQueryJobMetadata.getSessionId() != null) { - return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader) - .cancelJob(asyncQueryJobMetadata); + SessionId sessionId = new SessionId(asyncQueryJobMetadata.getSessionId()); + Optional session = sessionManager.getSession(sessionId); + if (session.isPresent()) { + // todo, statementId == jobId if statement running in session. + StatementId statementId = new StatementId(asyncQueryJobMetadata.getJobId()); + Optional statement = session.get().get(statementId); + if (statement.isPresent()) { + statement.get().cancel(); + return statementId.getId(); + } else { + throw new IllegalArgumentException("no statement found. " + statementId); + } + } else { + throw new IllegalArgumentException("no session found. 
" + sessionId); + } } else { - return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader) - .cancelJob(asyncQueryJobMetadata); + CancelJobRunResult cancelJobRunResult = + emrServerlessClient.cancelJobRun( + asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); + return cancelJobRunResult.getJobRunId(); } } @@ -150,18 +229,12 @@ private DispatchQueryResponse handleIndexQuery( indexDetails.getAutoRefresh(), dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse( - AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()), - jobId, - false, - dataSourceMetadata.getResultIndex(), - null); + return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null); } private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQueryRequest) { DataSourceMetadata dataSourceMetadata = this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); - AsyncQueryId queryId = AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()); dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query"; Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); @@ -194,12 +267,12 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ dataSourceMetadata.getResultIndex(), dataSourceMetadata.getName())); } - session.submit( - new QueryRequest( - queryId, dispatchQueryRequest.getLangType(), dispatchQueryRequest.getQuery())); + StatementId statementId = + session.submit( + new QueryRequest( + dispatchQueryRequest.getLangType(), dispatchQueryRequest.getQuery())); return new DispatchQueryResponse( - queryId, - session.getSessionModel().getJobId(), + statementId.getId(), false, dataSourceMetadata.getResultIndex(), session.getSessionId().getSessionId()); @@ -221,8 +294,7 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ false, dataSourceMetadata.getResultIndex()); String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse( - queryId, jobId, false, dataSourceMetadata.getResultIndex(), null); + return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null); } } @@ -253,11 +325,7 @@ private DispatchQueryResponse handleDropIndexQuery( } } return new DispatchQueryResponse( - AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()), - new DropIndexResult(status).toJobId(), - true, - dataSourceMetadata.getResultIndex(), - null); + new DropIndexResult(status).toJobId(), true, dataSourceMetadata.getResultIndex(), null); } private static Map getDefaultTagsForJobSubmission( diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java index e44379daff..893446c617 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java @@ -2,12 +2,10 @@ import lombok.AllArgsConstructor; import lombok.Data; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; @Data @AllArgsConstructor public class DispatchQueryResponse { - private AsyncQueryId queryId; private String jobId; private boolean isDropIndexQuery; private String resultIndex; diff --git 
a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index a2e7cfe6ee..4428c3b83d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -81,8 +81,7 @@ public StatementId submit(QueryRequest request) { } else { sessionModel = model.get(); if (!END_STATE.contains(sessionModel.getSessionState())) { - String qid = request.getQueryId().getId(); - StatementId statementId = newStatementId(qid); + StatementId statementId = newStatementId(); Statement st = Statement.builder() .sessionId(sessionId) @@ -93,7 +92,7 @@ public StatementId submit(QueryRequest request) { .langType(LangType.SQL) .datasourceName(sessionModel.getDatasourceName()) .query(request.getQuery()) - .queryId(qid) + .queryId(statementId.getId()) .build(); st.open(); return statementId; diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java index c85e4dd35c..b3bd716925 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java @@ -5,10 +5,10 @@ package org.opensearch.sql.spark.execution.session; -import static org.opensearch.sql.spark.utils.IDUtils.decode; -import static org.opensearch.sql.spark.utils.IDUtils.encode; - +import java.nio.charset.StandardCharsets; +import java.util.Base64; import lombok.Data; +import org.apache.commons.lang3.RandomStringUtils; @Data public class SessionId { @@ -24,6 +24,15 @@ public String getDataSourceName() { return decode(sessionId); } + private static String decode(String sessionId) { + return new String(Base64.getDecoder().decode(sessionId)).substring(PREFIX_LEN); + } + + private static String encode(String datasourceName) { + String randomId = RandomStringUtils.randomAlphanumeric(PREFIX_LEN) + datasourceName; + return Base64.getEncoder().encodeToString(randomId.getBytes(StandardCharsets.UTF_8)); + } + @Override public String toString() { return "sessionId=" + sessionId; diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java index c365265224..10061404ca 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java @@ -6,12 +6,10 @@ package org.opensearch.sql.spark.execution.statement; import lombok.Data; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.rest.model.LangType; @Data public class QueryRequest { - private final AsyncQueryId queryId; private final LangType langType; private final String query; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java index 33284c4b3d..d9381ad45f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java @@ -6,14 +6,14 @@ package org.opensearch.sql.spark.execution.statement; import lombok.Data; +import 
org.apache.commons.lang3.RandomStringUtils; @Data public class StatementId { private final String id; - // construct statementId from queryId. - public static StatementId newStatementId(String qid) { - return new StatementId(qid); + public static StatementId newStatementId() { + return new StatementId(RandomStringUtils.randomAlphanumeric(16)); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java index 6546d303fb..a36ee3ef45 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java @@ -38,7 +38,6 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.execution.session.SessionModel; import org.opensearch.sql.spark.execution.session.SessionState; import org.opensearch.sql.spark.execution.statement.StatementModel; @@ -54,6 +53,7 @@ public class StateStore { public static String MAPPING_FILE_NAME = "query_execution_request_mapping.yml"; public static Function DATASOURCE_TO_REQUEST_INDEX = datasourceName -> String.format("%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName); + public static String ALL_REQUEST_INDEX = String.format("%s_*", SPARK_REQUEST_BUFFER_INDEX_NAME); private static final Logger LOG = LogManager.getLogger(); @@ -77,6 +77,7 @@ protected T create( try (ThreadContext.StoredContext ignored = client.threadPool().getThreadContext().stashContext()) { IndexResponse indexResponse = client.index(indexRequest).actionGet(); + ; if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { LOG.debug("Successfully created doc. 
id: {}", st.getId()); return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); @@ -226,6 +227,10 @@ public static Function> getSession( docId, SessionModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); } + public static Function> searchSession(StateStore stateStore) { + return (docId) -> stateStore.get(docId, SessionModel::fromXContent, ALL_REQUEST_INDEX); + } + public static BiFunction updateSessionState( StateStore stateStore, String datasourceName) { return (old, state) -> @@ -236,21 +241,8 @@ public static BiFunction updateSession DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); } - public static Function createJobMetaData( - StateStore stateStore, String datasourceName) { - return (jobMetadata) -> - stateStore.create( - jobMetadata, - AsyncQueryJobMetadata::copy, - DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Function> getJobMetaData( - StateStore stateStore, String datasourceName) { - return (docId) -> - stateStore.get( - docId, - AsyncQueryJobMetadata::fromXContent, - DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + public static Runnable createStateStoreIndex(StateStore stateStore, String datasourceName) { + String indexName = String.format("%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName); + return () -> stateStore.createIndex(indexName); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/utils/IDUtils.java b/spark/src/main/java/org/opensearch/sql/spark/utils/IDUtils.java deleted file mode 100644 index 438d2342b4..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/utils/IDUtils.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.utils; - -import java.nio.charset.StandardCharsets; -import java.util.Base64; -import lombok.experimental.UtilityClass; -import org.apache.commons.lang3.RandomStringUtils; - -@UtilityClass -public class IDUtils { - public static final int PREFIX_LEN = 10; - - public static String decode(String id) { - return new String(Base64.getDecoder().decode(id)).substring(PREFIX_LEN); - } - - public static String encode(String datasourceName) { - String randomId = RandomStringUtils.randomAlphanumeric(PREFIX_LEN) + datasourceName; - return Base64.getEncoder().encodeToString(randomId.getBytes(StandardCharsets.UTF_8)); - } -} diff --git a/spark/src/main/resources/job-metadata-index-mapping.yml b/spark/src/main/resources/job-metadata-index-mapping.yml new file mode 100644 index 0000000000..3a39b989a2 --- /dev/null +++ b/spark/src/main/resources/job-metadata-index-mapping.yml @@ -0,0 +1,25 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Schema file for the .ql-job-metadata index +# Also "dynamic" is set to "false" so that other fields can be added. 
+dynamic: false +properties: + jobId: + type: text + fields: + keyword: + type: keyword + applicationId: + type: text + fields: + keyword: + type: keyword + resultIndex: + type: text + fields: + keyword: + type: keyword \ No newline at end of file diff --git a/spark/src/main/resources/job-metadata-index-settings.yml b/spark/src/main/resources/job-metadata-index-settings.yml new file mode 100644 index 0000000000..be93f4645c --- /dev/null +++ b/spark/src/main/resources/job-metadata-index-settings.yml @@ -0,0 +1,11 @@ +--- +## +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +## + +# Settings file for the .ql-job-metadata index +index: + number_of_shards: "1" + auto_expand_replicas: "0-2" + number_of_replicas: "0" \ No newline at end of file diff --git a/spark/src/main/resources/query_execution_request_mapping.yml b/spark/src/main/resources/query_execution_request_mapping.yml index fbe90a1cba..87bd927e6e 100644 --- a/spark/src/main/resources/query_execution_request_mapping.yml +++ b/spark/src/main/resources/query_execution_request_mapping.yml @@ -8,8 +8,6 @@ # Also "dynamic" is set to "false" so that other fields can be added. dynamic: false properties: - version: - type: keyword type: type: keyword state: diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index 1ee119df78..3eb8958eb2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -284,9 +284,8 @@ private DataSourceServiceImpl createDataSourceService() { private AsyncQueryExecutorService createAsyncQueryExecutorService( EMRServerlessClient emrServerlessClient) { - StateStore stateStore = new StateStore(client, clusterService); AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = - new OpensearchAsyncQueryJobMetadataStorageService(stateStore); + new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService); JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( @@ -296,7 +295,8 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService( jobExecutionResponseReader, new FlintIndexMetadataReaderImpl(client), client, - new SessionManager(stateStore, emrServerlessClient, pluginSettings)); + new SessionManager( + new StateStore(client, clusterService), emrServerlessClient, pluginSettings)); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index 2ed316795f..0d4e280b61 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -11,7 +11,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; -import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.DS_NAME; import static 
org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME; @@ -30,7 +29,6 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; @@ -49,7 +47,6 @@ public class AsyncQueryExecutorServiceImplTest { private AsyncQueryExecutorService jobExecutorService; @Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier; - private final AsyncQueryId QUERY_ID = AsyncQueryId.newAsyncQueryId(DS_NAME); @BeforeEach void setUp() { @@ -81,12 +78,11 @@ void testCreateAsyncQuery() { LangType.SQL, "arn:aws:iam::270824043731:role/emr-job-execution-role", TEST_CLUSTER_NAME))) - .thenReturn(new DispatchQueryResponse(QUERY_ID, EMR_JOB_ID, false, null, null)); + .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null, null)); CreateAsyncQueryResponse createAsyncQueryResponse = jobExecutorService.createAsyncQuery(createAsyncQueryRequest); verify(asyncQueryJobMetadataStorageService, times(1)) - .storeJobMetadata( - new AsyncQueryJobMetadata(QUERY_ID, "00fd775baqpu4g0p", EMR_JOB_ID, null)); + .storeJobMetadata(new AsyncQueryJobMetadata("00fd775baqpu4g0p", EMR_JOB_ID, null)); verify(sparkExecutionEngineConfigSupplier, times(1)).getSparkExecutionEngineConfig(); verify(sparkQueryDispatcher, times(1)) .dispatch( @@ -97,7 +93,7 @@ void testCreateAsyncQuery() { LangType.SQL, "arn:aws:iam::270824043731:role/emr-job-execution-role", TEST_CLUSTER_NAME)); - Assertions.assertEquals(QUERY_ID.getId(), createAsyncQueryResponse.getQueryId()); + Assertions.assertEquals(EMR_JOB_ID, createAsyncQueryResponse.getQueryId()); } @Test @@ -111,7 +107,7 @@ void testCreateAsyncQueryWithExtraSparkSubmitParameter() { "--conf spark.dynamicAllocation.enabled=false", TEST_CLUSTER_NAME)); when(sparkQueryDispatcher.dispatch(any())) - .thenReturn(new DispatchQueryResponse(QUERY_ID, EMR_JOB_ID, false, null, null)); + .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null, null)); jobExecutorService.createAsyncQuery( new CreateAsyncQueryRequest( @@ -143,13 +139,11 @@ void testGetAsyncQueryResultsWithJobNotFoundException() { @Test void testGetAsyncQueryResultsWithInProgressJob() { when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) - .thenReturn( - Optional.of( - new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null))); + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))); JSONObject jobResult = new JSONObject(); jobResult.put("status", JobRunState.PENDING.toString()); when(sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null))) + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))) .thenReturn(jobResult); AsyncQueryExecutionResponse asyncQueryExecutionResponse = jobExecutorService.getAsyncQueryResults(EMR_JOB_ID); @@ -163,13 +157,11 @@ void testGetAsyncQueryResultsWithInProgressJob() { @Test void testGetAsyncQueryResultsWithSuccessJob() throws IOException { 
when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) - .thenReturn( - Optional.of( - new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null))); + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))); JSONObject jobResult = new JSONObject(getJson("select_query_response.json")); jobResult.put("status", JobRunState.SUCCESS.toString()); when(sparkQueryDispatcher.getQueryResponse( - new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null))) + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))) .thenReturn(jobResult); AsyncQueryExecutionResponse asyncQueryExecutionResponse = @@ -216,11 +208,9 @@ void testCancelJobWithJobNotFound() { @Test void testCancelJob() { when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) - .thenReturn( - Optional.of( - new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null))); + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))); when(sparkQueryDispatcher.cancelJob( - new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null))) + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))) .thenReturn(EMR_JOB_ID); String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID); Assertions.assertEquals(EMR_JOB_ID, jobId); diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java index de0caf5589..7288fd3fc2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java @@ -5,70 +5,242 @@ package org.opensearch.sql.spark.asyncquery; +import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService.JOB_METADATA_INDEX; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; import java.util.Optional; -import org.junit.Before; -import org.junit.Test; -import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; +import org.apache.lucene.search.TotalHits; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; +import org.mockito.ArgumentMatchers; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.test.OpenSearchIntegTestCase; -public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest - extends 
OpenSearchIntegTestCase { +@ExtendWith(MockitoExtension.class) +public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest { - public static final String DS_NAME = "mys3"; - private static final String MOCK_SESSION_ID = "sessionId"; - private static final String MOCK_RESULT_INDEX = "resultIndex"; + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private Client client; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private ClusterService clusterService; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private SearchResponse searchResponse; + + @Mock private ActionFuture searchResponseActionFuture; + @Mock private ActionFuture createIndexResponseActionFuture; + @Mock private ActionFuture indexResponseActionFuture; + @Mock private IndexResponse indexResponse; + @Mock private SearchHit searchHit; + + @InjectMocks private OpensearchAsyncQueryJobMetadataStorageService opensearchJobMetadataStorageService; - @Before - public void setup() { - opensearchJobMetadataStorageService = - new OpensearchAsyncQueryJobMetadataStorageService( - new StateStore(client(), clusterService())); + @Test + public void testStoreJobMetadata() { + + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(Boolean.FALSE); + Mockito.when(client.admin().indices().create(ArgumentMatchers.any())) + .thenReturn(createIndexResponseActionFuture); + Mockito.when(createIndexResponseActionFuture.actionGet()) + .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX)); + Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture); + Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); + Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); + AsyncQueryJobMetadata asyncQueryJobMetadata = + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); + + this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata); + + Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any()); + Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any()); + Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext(); } @Test - public void testStoreJobMetadata() { - AsyncQueryJobMetadata expected = - new AsyncQueryJobMetadata( - AsyncQueryId.newAsyncQueryId(DS_NAME), - EMR_JOB_ID, - EMRS_APPLICATION_ID, - MOCK_RESULT_INDEX); - - opensearchJobMetadataStorageService.storeJobMetadata(expected); - Optional actual = - opensearchJobMetadataStorageService.getJobMetadata(expected.getQueryId().getId()); - - assertTrue(actual.isPresent()); - assertEquals(expected, actual.get()); - assertFalse(actual.get().isDropIndexQuery()); - assertNull(actual.get().getSessionId()); + public void testStoreJobMetadataWithOutCreatingIndex() { + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(Boolean.TRUE); + Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture); + Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse); + Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); + AsyncQueryJobMetadata asyncQueryJobMetadata = + new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null); + + this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata); + + Mockito.verify(client.admin().indices(), Mockito.times(0)).create(ArgumentMatchers.any()); + Mockito.verify(client, 
+    Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any());
+    Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(1)).stashContext();
+  }
+
+  @Test
+  public void testStoreJobMetadataWithException() {
+
+    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
+        .thenReturn(Boolean.FALSE);
+    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
+        .thenReturn(createIndexResponseActionFuture);
+    Mockito.when(createIndexResponseActionFuture.actionGet())
+        .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX));
+    Mockito.when(client.index(ArgumentMatchers.any()))
+        .thenThrow(new RuntimeException("error while indexing"));
+
+    AsyncQueryJobMetadata asyncQueryJobMetadata =
+        new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null);
+    RuntimeException runtimeException =
+        Assertions.assertThrows(
+            RuntimeException.class,
+            () -> this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata));
+    Assertions.assertEquals(
+        "java.lang.RuntimeException: error while indexing", runtimeException.getMessage());
+
+    Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any());
+    Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any());
+    Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext();
+  }
+
+  @Test
+  public void testStoreJobMetadataWithIndexCreationFailed() {
+
+    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
+        .thenReturn(Boolean.FALSE);
+    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
+        .thenReturn(createIndexResponseActionFuture);
+    Mockito.when(createIndexResponseActionFuture.actionGet())
+        .thenReturn(new CreateIndexResponse(false, false, JOB_METADATA_INDEX));
+
+    AsyncQueryJobMetadata asyncQueryJobMetadata =
+        new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null);
+    RuntimeException runtimeException =
+        Assertions.assertThrows(
+            RuntimeException.class,
+            () -> this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata));
+    Assertions.assertEquals(
+        "Internal server error while creating.ql-job-metadata index:: "
+            + "Index creation is not acknowledged.",
+        runtimeException.getMessage());
+
+    Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any());
+    Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(1)).stashContext();
+  }
+
+  @Test
+  public void testStoreJobMetadataFailedWithNotFoundResponse() {
+
+    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
+        .thenReturn(Boolean.FALSE);
+    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
+        .thenReturn(createIndexResponseActionFuture);
+    Mockito.when(createIndexResponseActionFuture.actionGet())
+        .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX));
+    Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture);
+    Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse);
+    Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND);
+
+    AsyncQueryJobMetadata asyncQueryJobMetadata =
+        new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null);
+    RuntimeException runtimeException =
+        Assertions.assertThrows(
+            RuntimeException.class,
+            () -> this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata));
+    Assertions.assertEquals(
not_found", + runtimeException.getMessage()); + + Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any()); + Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any()); + Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext(); + } + + @Test + public void testGetJobMetadata() { + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(true); + Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture); + Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); + Mockito.when(searchResponse.status()).thenReturn(RestStatus.OK); + Mockito.when(searchResponse.getHits()) + .thenReturn( + new SearchHits( + new SearchHit[] {searchHit}, new TotalHits(21, TotalHits.Relation.EQUAL_TO), 1.0F)); + AsyncQueryJobMetadata asyncQueryJobMetadata = + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null); + Mockito.when(searchHit.getSourceAsString()).thenReturn(asyncQueryJobMetadata.toString()); + + Optional jobMetadataOptional = + opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID); + Assertions.assertTrue(jobMetadataOptional.isPresent()); + Assertions.assertEquals(EMR_JOB_ID, jobMetadataOptional.get().getJobId()); + Assertions.assertEquals(EMRS_APPLICATION_ID, jobMetadataOptional.get().getApplicationId()); + } + + @Test + public void testGetJobMetadataWith404SearchResponse() { + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(true); + Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture); + Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); + Mockito.when(searchResponse.status()).thenReturn(RestStatus.NOT_FOUND); + + RuntimeException runtimeException = + Assertions.assertThrows( + RuntimeException.class, + () -> opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)); + Assertions.assertEquals( + "Fetching job metadata information failed with status : NOT_FOUND", + runtimeException.getMessage()); } @Test - public void testStoreJobMetadataWithResultExtraData() { - AsyncQueryJobMetadata expected = - new AsyncQueryJobMetadata( - AsyncQueryId.newAsyncQueryId(DS_NAME), - EMR_JOB_ID, - EMRS_APPLICATION_ID, - true, - MOCK_RESULT_INDEX, - MOCK_SESSION_ID); - - opensearchJobMetadataStorageService.storeJobMetadata(expected); - Optional actual = - opensearchJobMetadataStorageService.getJobMetadata(expected.getQueryId().getId()); - - assertTrue(actual.isPresent()); - assertEquals(expected, actual.get()); - assertTrue(actual.get().isDropIndexQuery()); - assertEquals("resultIndex", actual.get().getResultIndex()); - assertEquals(MOCK_SESSION_ID, actual.get().getSessionId()); + public void testGetJobMetadataWithParsingFailed() { + Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) + .thenReturn(true); + Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture); + Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); + Mockito.when(searchResponse.status()).thenReturn(RestStatus.OK); + Mockito.when(searchResponse.getHits()) + .thenReturn( + new SearchHits( + new SearchHit[] {searchHit}, new TotalHits(21, TotalHits.Relation.EQUAL_TO), 1.0F)); + Mockito.when(searchHit.getSourceAsString()).thenReturn("..tesJOBs"); + + Assertions.assertThrows( + RuntimeException.class, + () -> 
+        () -> opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID));
+  }
+
+  @Test
+  public void testGetJobMetadataWithNoIndex() {
+    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
+        .thenReturn(Boolean.FALSE);
+    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
+        .thenReturn(createIndexResponseActionFuture);
+    Mockito.when(createIndexResponseActionFuture.actionGet())
+        .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX));
+    Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture);
+
+    Optional<AsyncQueryJobMetadata> jobMetadata =
+        opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID);
+
+    Assertions.assertFalse(jobMetadata.isPresent());
   }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
index 4acccae0e2..15211dec01 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
@@ -5,7 +5,6 @@
 
 package org.opensearch.sql.spark.dispatcher;
 
-import static org.mockito.Answers.RETURNS_DEEP_STUBS;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.argThat;
@@ -20,7 +19,6 @@
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
-import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.DS_NAME;
 import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID;
 import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE;
 import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID;
@@ -49,6 +47,7 @@
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Answers;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Captor;
 import org.mockito.Mock;
@@ -59,7 +58,6 @@
 import org.opensearch.sql.datasource.model.DataSourceMetadata;
 import org.opensearch.sql.datasource.model.DataSourceType;
 import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl;
-import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.client.EMRServerlessClient;
 import org.opensearch.sql.spark.client.StartJobRequest;
@@ -88,22 +86,19 @@ public class SparkQueryDispatcherTest {
   @Mock private DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper;
   @Mock private FlintIndexMetadataReader flintIndexMetadataReader;
 
-  @Mock(answer = RETURNS_DEEP_STUBS)
+  @Mock(answer = Answers.RETURNS_DEEP_STUBS)
   private Client openSearchClient;
 
   @Mock private FlintIndexMetadata flintIndexMetadata;
 
   @Mock private SessionManager sessionManager;
 
-  @Mock(answer = RETURNS_DEEP_STUBS)
-  private Session session;
+  @Mock private Session session;
 
   @Mock private Statement statement;
 
   private SparkQueryDispatcher sparkQueryDispatcher;
 
-  private final AsyncQueryId QUERY_ID = AsyncQueryId.newAsyncQueryId(DS_NAME);
-
   @Captor ArgumentCaptor<StartJobRequest> startJobRequestArgumentCaptor;
 
   @BeforeEach
@@ -290,7 +285,6 @@ void testDispatchSelectQueryCreateNewSession() {
     doReturn(session).when(sessionManager).createSession(any());
     doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId();
     doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any());
-    when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID);
     DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
     when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata);
     doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata);
@@ -298,7 +292,7 @@ void testDispatchSelectQueryCreateNewSession() {
 
     verifyNoInteractions(emrServerlessClient);
     verify(sessionManager, never()).getSession(any());
-    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
+    Assertions.assertEquals(MOCK_STATEMENT_ID, dispatchQueryResponse.getJobId());
     Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
   }
 
@@ -313,7 +307,6 @@ void testDispatchSelectQueryReuseSession() {
         .getSession(eq(new SessionId(MOCK_SESSION_ID)));
     doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId();
     doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any());
-    when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID);
     DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
     when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata);
     doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata);
@@ -321,7 +314,7 @@ void testDispatchSelectQueryReuseSession() {
 
     verifyNoInteractions(emrServerlessClient);
     verify(sessionManager, never()).createSession(any());
-    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
+    Assertions.assertEquals(MOCK_STATEMENT_ID, dispatchQueryResponse.getJobId());
     Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
   }
 
@@ -643,8 +636,10 @@ void testCancelJob() {
             new CancelJobRunResult()
                 .withJobRunId(EMR_JOB_ID)
                 .withApplicationId(EMRS_APPLICATION_ID));
-    String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata());
-    Assertions.assertEquals(QUERY_ID.getId(), queryId);
+    String jobId =
+        sparkQueryDispatcher.cancelJob(
+            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null));
+    Assertions.assertEquals(EMR_JOB_ID, jobId);
   }
 
   @Test
@@ -703,8 +698,10 @@ void testCancelQueryWithNoSessionId() {
             new CancelJobRunResult()
                 .withJobRunId(EMR_JOB_ID)
                 .withApplicationId(EMRS_APPLICATION_ID));
-    String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata());
-    Assertions.assertEquals(QUERY_ID.getId(), queryId);
+    String jobId =
+        sparkQueryDispatcher.cancelJob(
+            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null));
+    Assertions.assertEquals(EMR_JOB_ID, jobId);
   }
 
   @Test
@@ -715,7 +712,9 @@ void testGetQueryResponse() {
     // simulate result index is not created yet
     when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null))
         .thenReturn(new JSONObject());
-    JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata());
+    JSONObject result =
+        sparkQueryDispatcher.getQueryResponse(
+            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null));
     Assertions.assertEquals("PENDING", result.get("status"));
   }
 
@@ -791,7 +790,9 @@ void testGetQueryResponseWithSuccess() {
     queryResult.put(DATA_FIELD, resultMap);
     when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null))
         .thenReturn(queryResult);
-    JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata());
+    JSONObject result =
+        sparkQueryDispatcher.getQueryResponse(
+            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null));
     verify(jobExecutionResponseReader, times(1)).getResultFromOpensearchIndex(EMR_JOB_ID, null);
     Assertions.assertEquals(
         new HashSet<>(Arrays.asList(DATA_FIELD, STATUS_FIELD, ERROR_FIELD)), result.keySet());
@@ -826,13 +827,7 @@ void testGetQueryResponseOfDropIndex() {
 
     JSONObject result =
         sparkQueryDispatcher.getQueryResponse(
-            new AsyncQueryJobMetadata(
-                AsyncQueryId.newAsyncQueryId(DS_NAME),
-                EMRS_APPLICATION_ID,
-                jobId,
-                true,
-                null,
-                null));
+            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, jobId, true, null, null));
     verify(jobExecutionResponseReader, times(0))
         .getResultFromOpensearchIndex(anyString(), anyString());
     Assertions.assertEquals("SUCCESS", result.get(STATUS_FIELD));
@@ -1215,13 +1210,8 @@ private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, Str
         sessionId);
   }
 
-  private AsyncQueryJobMetadata asyncQueryJobMetadata() {
-    return new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null);
-  }
-
   private AsyncQueryJobMetadata asyncQueryJobMetadataWithSessionId(
-      String statementId, String sessionId) {
-    return new AsyncQueryJobMetadata(
-        new AsyncQueryId(statementId), EMRS_APPLICATION_ID, EMR_JOB_ID, false, null, sessionId);
+      String queryId, String sessionId) {
+    return new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, queryId, false, null, sessionId);
   }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
index 1e33c8a6b9..ff3ddd1bef 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
@@ -22,7 +22,6 @@
 import org.junit.Test;
 import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
 import org.opensearch.action.delete.DeleteRequest;
-import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
 import org.opensearch.sql.spark.execution.session.InteractiveSessionTest;
 import org.opensearch.sql.spark.execution.session.Session;
 import org.opensearch.sql.spark.execution.session.SessionId;
@@ -209,7 +208,7 @@ public void submitStatementInRunningSession() {
     // App change state to running
     updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING);
 
-    StatementId statementId = session.submit(queryRequest());
+    StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
     assertFalse(statementId.getId().isEmpty());
   }
 
@@ -219,7 +218,7 @@ public void submitStatementInNotStartedState() {
         new SessionManager(stateStore, emrsClient, sessionSetting(false))
             .createSession(createSessionRequest());
 
-    StatementId statementId = session.submit(queryRequest());
+    StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
     assertFalse(statementId.getId().isEmpty());
   }
 
@@ -232,7 +231,9 @@ public void failToSubmitStatementInDeadState() {
     updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.DEAD);
 
     IllegalStateException exception =
-        assertThrows(IllegalStateException.class, () -> session.submit(queryRequest()));
+        assertThrows(
+            IllegalStateException.class,
+            () -> session.submit(new QueryRequest(LangType.SQL, "select 1")));
     assertEquals(
         "can't submit statement, session should not be in end state, current session state is:"
             + " dead",
@@ -248,7 +249,9 @@ public void failToSubmitStatementInFailState() {
     updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.FAIL);
 
     IllegalStateException exception =
-        assertThrows(IllegalStateException.class, () -> session.submit(queryRequest()));
+        assertThrows(
+            IllegalStateException.class,
+            () -> session.submit(new QueryRequest(LangType.SQL, "select 1")));
     assertEquals(
         "can't submit statement, session should not be in end state, current session state is:"
             + " fail",
@@ -260,7 +263,7 @@ public void newStatementFieldAssert() {
     Session session =
        new SessionManager(stateStore, emrsClient, sessionSetting(false))
            .createSession(createSessionRequest());
-    StatementId statementId = session.submit(queryRequest());
+    StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
     Optional<Statement> statement = session.get(statementId);
 
     assertTrue(statement.isPresent());
@@ -285,7 +288,9 @@ public void failToSubmitStatementInDeletedSession() {
         .actionGet();
 
     IllegalStateException exception =
-        assertThrows(IllegalStateException.class, () -> session.submit(queryRequest()));
+        assertThrows(
+            IllegalStateException.class,
+            () -> session.submit(new QueryRequest(LangType.SQL, "select 1")));
     assertEquals("session does not exist. " + session.getSessionId(), exception.getMessage());
   }
 
@@ -296,7 +301,7 @@ public void getStatementSuccess() {
             .createSession(createSessionRequest());
     // App change state to running
     updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING);
-    StatementId statementId = session.submit(queryRequest());
+    StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
 
     Optional<Statement> statement = session.get(statementId);
     assertTrue(statement.isPresent());
@@ -312,7 +317,7 @@ public void getStatementNotExist() {
     // App change state to running
     updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING);
 
-    Optional<Statement> statement = session.get(StatementId.newStatementId("not-exist-id"));
+    Optional<Statement> statement = session.get(StatementId.newStatementId());
     assertFalse(statement.isPresent());
   }
 
@@ -356,8 +361,4 @@ public TestStatement cancel() {
       return this;
     }
   }
-
-  private QueryRequest queryRequest() {
-    return new QueryRequest(AsyncQueryId.newAsyncQueryId(DS_NAME), LangType.SQL, "select 1");
-  }
 }