diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index d5100885c4..d6463779d6 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -97,6 +97,7 @@ import org.opensearch.sql.spark.client.EmrServerlessClientImplEMR; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; import org.opensearch.sql.spark.storage.SparkStorageFactory; @@ -297,14 +298,15 @@ private DataSourceServiceImpl createDataSourceService() { private AsyncQueryExecutorService createAsyncQueryExecutorService() { AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService); - EMRServerlessClient EMRServerlessClient = createEMRServerlessClient(); + EMRServerlessClient emrServerlessClient = createEMRServerlessClient(); JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - EMRServerlessClient, + emrServerlessClient, this.dataSourceService, new DataSourceUserAuthorizationHelperImpl(client), - jobExecutionResponseReader); + jobExecutionResponseReader, + new FlintIndexMetadataReaderImpl(client)); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, pluginSettings); } diff --git a/spark/build.gradle b/spark/build.gradle index 778c7f4e60..d506bfbeca 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -90,7 +90,9 @@ jacocoTestCoverageVerification { 'org.opensearch.sql.spark.transport.model.*', 'org.opensearch.sql.spark.asyncquery.model.*', 'org.opensearch.sql.spark.asyncquery.exceptions.*', - 'org.opensearch.sql.spark.dispatcher.model.*' + 'org.opensearch.sql.spark.dispatcher.model.*', + 'org.opensearch.sql.spark.flint.FlintIndexType', + 'org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl' ] limit { counter = 'LINE' diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 74065c2d20..486a31bf73 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -25,6 +25,7 @@ import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.functions.response.DefaultSparkSqlFunctionResponseHandle; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; @@ -64,7 +65,7 @@ public CreateAsyncQueryResponse createAsyncQuery( SparkExecutionEngineConfig.toSparkExecutionEngineConfig( sparkExecutionEngineConfigString)); ClusterName clusterName = settings.getSettingValue(CLUSTER_NAME); - String jobId = + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( sparkExecutionEngineConfig.getApplicationId(), @@ -74,8 +75,11 @@ public CreateAsyncQueryResponse createAsyncQuery( sparkExecutionEngineConfig.getExecutionRoleARN(), clusterName.value())); asyncQueryJobMetadataStorageService.storeJobMetadata( - new AsyncQueryJobMetadata(jobId, sparkExecutionEngineConfig.getApplicationId())); - return new CreateAsyncQueryResponse(jobId); + new AsyncQueryJobMetadata( + sparkExecutionEngineConfig.getApplicationId(), + dispatchQueryResponse.getJobId(), + dispatchQueryResponse.isDropIndexQuery())); + return new CreateAsyncQueryResponse(dispatchQueryResponse.getJobId()); } @Override @@ -84,9 +88,7 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { Optional jobMetadata = asyncQueryJobMetadataStorageService.getJobMetadata(queryId); if (jobMetadata.isPresent()) { - JSONObject jsonObject = - sparkQueryDispatcher.getQueryResponse( - jobMetadata.get().getApplicationId(), jobMetadata.get().getJobId()); + JSONObject jsonObject = sparkQueryDispatcher.getQueryResponse(jobMetadata.get()); if (JobRunState.SUCCESS.toString().equals(jsonObject.getString("status"))) { DefaultSparkSqlFunctionResponseHandle sparkSqlFunctionResponseHandle = new DefaultSparkSqlFunctionResponseHandle(jsonObject); @@ -108,8 +110,7 @@ public String cancelQuery(String queryId) { Optional asyncQueryJobMetadata = asyncQueryJobMetadataStorageService.getJobMetadata(queryId); if (asyncQueryJobMetadata.isPresent()) { - return sparkQueryDispatcher.cancelJob( - asyncQueryJobMetadata.get().getApplicationId(), queryId); + return sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata.get()); } throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java index 60ec53987e..64a2078066 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java @@ -13,6 +13,7 @@ import java.io.IOException; import lombok.AllArgsConstructor; import lombok.Data; +import lombok.EqualsAndHashCode; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.DeprecationHandler; @@ -23,9 +24,17 @@ /** This class models all the metadata required for a job. */ @Data @AllArgsConstructor +@EqualsAndHashCode public class AsyncQueryJobMetadata { - private String jobId; private String applicationId; + private String jobId; + private boolean isDropIndexQuery; + + public AsyncQueryJobMetadata(String applicationId, String jobId) { + this.applicationId = applicationId; + this.jobId = jobId; + this.isDropIndexQuery = false; + } @Override public String toString() { @@ -44,6 +53,7 @@ public static XContentBuilder convertToXContent(AsyncQueryJobMetadata metadata) builder.startObject(); builder.field("jobId", metadata.getJobId()); builder.field("applicationId", metadata.getApplicationId()); + builder.field("isDropIndexQuery", metadata.isDropIndexQuery()); builder.endObject(); return builder; } @@ -77,6 +87,7 @@ public static AsyncQueryJobMetadata toJobMetadata(String json) throws IOExceptio public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws IOException { String jobId = null; String applicationId = null; + boolean isDropIndexQuery = false; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -88,6 +99,9 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws case "applicationId": applicationId = parser.textOrNull(); break; + case "isDropIndexQuery": + isDropIndexQuery = parser.booleanValue(); + break; default: throw new IllegalArgumentException("Unknown field: " + fieldName); } @@ -95,6 +109,6 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws if (jobId == null || applicationId == null) { throw new IllegalArgumentException("jobId and applicationId are required fields."); } - return new AsyncQueryJobMetadata(jobId, applicationId); + return new AsyncQueryJobMetadata(applicationId, jobId, isDropIndexQuery); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 2749d7c934..9c5d4df667 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -11,15 +11,19 @@ import java.util.HashMap; import java.util.Map; import lombok.AllArgsConstructor; +import org.apache.commons.lang3.RandomStringUtils; import org.json.JSONObject; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.flint.FlintIndexMetadataReader; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.utils.SQLQueryUtils; @@ -42,49 +46,64 @@ public class SparkQueryDispatcher { private JobExecutionResponseReader jobExecutionResponseReader; - public String dispatch(DispatchQueryRequest dispatchQueryRequest) { - return emrServerlessClient.startJobRun(getStartJobRequest(dispatchQueryRequest)); + private FlintIndexMetadataReader flintIndexMetadataReader; + + public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) { + if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) { + return handleSQLQuery(dispatchQueryRequest); + } else { + // Since we don't need any extra handling for PPL, we are treating it as normal dispatch + // Query. + return handleNonIndexQuery(dispatchQueryRequest); + } } // TODO : Fetch from Result Index and then make call to EMR Serverless. - public JSONObject getQueryResponse(String applicationId, String queryId) { - GetJobRunResult getJobRunResult = emrServerlessClient.getJobRunResult(applicationId, queryId); + public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) { + GetJobRunResult getJobRunResult = + emrServerlessClient.getJobRunResult( + asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); JSONObject result = new JSONObject(); if (getJobRunResult.getJobRun().getState().equals(JobRunState.SUCCESS.toString())) { - result = jobExecutionResponseReader.getResultFromOpensearchIndex(queryId); + result = + jobExecutionResponseReader.getResultFromOpensearchIndex(asyncQueryJobMetadata.getJobId()); } result.put("status", getJobRunResult.getJobRun().getState()); return result; } - public String cancelJob(String applicationId, String jobId) { - CancelJobRunResult cancelJobRunResult = emrServerlessClient.cancelJobRun(applicationId, jobId); + public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + CancelJobRunResult cancelJobRunResult = + emrServerlessClient.cancelJobRun( + asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); return cancelJobRunResult.getJobRunId(); } - // we currently don't support index queries in PPL language. - // so we are treating all of them as non-index queries which don't require any kind of query - // parsing. - private StartJobRequest getStartJobRequest(DispatchQueryRequest dispatchQueryRequest) { - if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) { - if (SQLQueryUtils.isIndexQuery(dispatchQueryRequest.getQuery())) - return getStartJobRequestForIndexRequest(dispatchQueryRequest); - else { - return getStartJobRequestForNonIndexQueries(dispatchQueryRequest); + private DispatchQueryResponse handleSQLQuery(DispatchQueryRequest dispatchQueryRequest) { + if (SQLQueryUtils.isIndexQuery(dispatchQueryRequest.getQuery())) { + IndexDetails indexDetails = + SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery()); + if (indexDetails.isDropIndex()) { + return handleDropIndexQuery(dispatchQueryRequest, indexDetails); + } else { + return handleIndexQuery(dispatchQueryRequest, indexDetails); } } else { - return getStartJobRequestForNonIndexQueries(dispatchQueryRequest); + return handleNonIndexQuery(dispatchQueryRequest); } } - private StartJobRequest getStartJobRequestForNonIndexQueries( - DispatchQueryRequest dispatchQueryRequest) { - StartJobRequest startJobRequest; + private DispatchQueryResponse handleIndexQuery( + DispatchQueryRequest dispatchQueryRequest, IndexDetails indexDetails) { + FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); dataSourceUserAuthorizationHelper.authorizeDataSource( this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource())); - String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query"; + String jobName = dispatchQueryRequest.getClusterName() + ":" + "index-query"; Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); - startJobRequest = + tags.put(INDEX_TAG_KEY, indexDetails.getIndexName()); + tags.put(TABLE_TAG_KEY, fullyQualifiedTableName.getTableName()); + tags.put(SCHEMA_TAG_KEY, fullyQualifiedTableName.getSchemaName()); + StartJobRequest startJobRequest = new StartJobRequest( dispatchQueryRequest.getQuery(), jobName, @@ -94,26 +113,21 @@ private StartJobRequest getStartJobRequestForNonIndexQueries( .dataSource( dataSourceService.getRawDataSourceMetadata( dispatchQueryRequest.getDatasource())) + .structuredStreaming(indexDetails.getAutoRefresh()) .build() .toString(), tags, - false); - return startJobRequest; + indexDetails.getAutoRefresh()); + String jobId = emrServerlessClient.startJobRun(startJobRequest); + return new DispatchQueryResponse(jobId, false); } - private StartJobRequest getStartJobRequestForIndexRequest( - DispatchQueryRequest dispatchQueryRequest) { - StartJobRequest startJobRequest; - IndexDetails indexDetails = SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery()); - FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); + private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQueryRequest) { dataSourceUserAuthorizationHelper.authorizeDataSource( this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource())); - String jobName = dispatchQueryRequest.getClusterName() + ":" + "index-query"; + String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query"; Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); - tags.put(INDEX_TAG_KEY, indexDetails.getIndexName()); - tags.put(TABLE_TAG_KEY, fullyQualifiedTableName.getTableName()); - tags.put(SCHEMA_TAG_KEY, fullyQualifiedTableName.getSchemaName()); - startJobRequest = + StartJobRequest startJobRequest = new StartJobRequest( dispatchQueryRequest.getQuery(), jobName, @@ -123,12 +137,22 @@ private StartJobRequest getStartJobRequestForIndexRequest( .dataSource( dataSourceService.getRawDataSourceMetadata( dispatchQueryRequest.getDatasource())) - .structuredStreaming(indexDetails.getAutoRefresh()) .build() .toString(), tags, - indexDetails.getAutoRefresh()); - return startJobRequest; + false); + String jobId = emrServerlessClient.startJobRun(startJobRequest); + return new DispatchQueryResponse(jobId, false); + } + + private DispatchQueryResponse handleDropIndexQuery( + DispatchQueryRequest dispatchQueryRequest, IndexDetails indexDetails) { + dataSourceUserAuthorizationHelper.authorizeDataSource( + this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource())); + String jobId = flintIndexMetadataReader.getJobIdFromFlintIndexMetadata(indexDetails); + emrServerlessClient.cancelJobRun(dispatchQueryRequest.getApplicationId(), jobId); + String dropIndexDummyJobId = RandomStringUtils.randomAlphanumeric(16); + return new DispatchQueryResponse(dropIndexDummyJobId, true); } private static Map getDefaultTagsForJobSubmission( diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java new file mode 100644 index 0000000000..592f3db4fe --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java @@ -0,0 +1,11 @@ +package org.opensearch.sql.spark.dispatcher.model; + +import lombok.AllArgsConstructor; +import lombok.Data; + +@Data +@AllArgsConstructor +public class DispatchQueryResponse { + private String jobId; + private boolean isDropIndexQuery; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java index 86fca60525..2034535848 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDetails.java @@ -5,13 +5,22 @@ package org.opensearch.sql.spark.dispatcher.model; +import lombok.AllArgsConstructor; import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import org.opensearch.sql.spark.flint.FlintIndexType; /** Index details in an async query. */ @Data +@AllArgsConstructor +@NoArgsConstructor +@EqualsAndHashCode public class IndexDetails { private String indexName; private FullyQualifiedTableName fullyQualifiedTableName; // by default, auto_refresh = false; private Boolean autoRefresh = false; + private boolean isDropIndex; + private FlintIndexType indexType; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReader.java b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReader.java new file mode 100644 index 0000000000..7cb2e6a7c8 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReader.java @@ -0,0 +1,15 @@ +package org.opensearch.sql.spark.flint; + +import org.opensearch.sql.spark.dispatcher.model.IndexDetails; + +/** Interface for FlintIndexMetadataReader */ +public interface FlintIndexMetadataReader { + + /** + * Given Index details, get the streaming job Id. + * + * @param indexDetails indexDetails. + * @return jobId. + */ + String getJobIdFromFlintIndexMetadata(IndexDetails indexDetails); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImpl.java b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImpl.java new file mode 100644 index 0000000000..b552488455 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImpl.java @@ -0,0 +1,48 @@ +package org.opensearch.sql.spark.flint; + +import java.util.Map; +import lombok.AllArgsConstructor; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; +import org.opensearch.sql.spark.dispatcher.model.IndexDetails; + +/** Implementation of {@link FlintIndexMetadataReader} */ +@AllArgsConstructor +public class FlintIndexMetadataReaderImpl implements FlintIndexMetadataReader { + + private final Client client; + + @Override + public String getJobIdFromFlintIndexMetadata(IndexDetails indexDetails) { + String indexName = getIndexName(indexDetails); + GetMappingsResponse mappingsResponse = + client.admin().indices().prepareGetMappings(indexName).get(); + try { + MappingMetadata mappingMetadata = mappingsResponse.mappings().get(indexName); + Map mappingSourceMap = mappingMetadata.getSourceAsMap(); + Map metaMap = (Map) mappingSourceMap.get("_meta"); + Map propertiesMap = (Map) metaMap.get("properties"); + Map envMap = (Map) propertiesMap.get("env"); + return (String) envMap.get("SERVERLESS_EMR_JOB_ID"); + } catch (NullPointerException npe) { + throw new IllegalArgumentException("Index doesn't exist"); + } + } + + private String getIndexName(IndexDetails indexDetails) { + FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName(); + return "flint" + + "_" + + fullyQualifiedTableName.getDatasourceName() + + "_" + + fullyQualifiedTableName.getSchemaName() + + "_" + + fullyQualifiedTableName.getTableName() + + "_" + + indexDetails.getIndexName() + + "_" + + indexDetails.getIndexType().getName(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexType.java b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexType.java new file mode 100644 index 0000000000..1415856803 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexType.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.flint; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** Enum for FlintIndex Type. */ +public enum FlintIndexType { + SKIPPING("skipping_index"), + COVERING("covering_index"), + MATERIALIZED("materialized_view"); + + private final String name; + private static final Map ENUM_MAP; + + FlintIndexType(String name) { + this.name = name; + } + + public String getName() { + return this.name; + } + + static { + Map map = new HashMap<>(); + for (FlintIndexType instance : FlintIndexType.values()) { + map.put(instance.getName().toLowerCase(), instance); + } + ENUM_MAP = Collections.unmodifiableMap(map); + } + + public static FlintIndexType get(String name) { + return ENUM_MAP.get(name.toLowerCase()); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java index 481591a4f0..f6b75d49ef 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java +++ b/spark/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java @@ -21,6 +21,7 @@ import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.flint.FlintIndexType; /** * This util class parses spark sql query and provides util functions to identify indexName, @@ -137,6 +138,8 @@ public Void visitTableName(FlintSparkSqlExtensionsParser.TableNameContext ctx) { @Override public Void visitCreateSkippingIndexStatement( FlintSparkSqlExtensionsParser.CreateSkippingIndexStatementContext ctx) { + indexDetails.setDropIndex(false); + indexDetails.setIndexType(FlintIndexType.SKIPPING); visitPropertyList(ctx.propertyList()); return super.visitCreateSkippingIndexStatement(ctx); } @@ -144,10 +147,28 @@ public Void visitCreateSkippingIndexStatement( @Override public Void visitCreateCoveringIndexStatement( FlintSparkSqlExtensionsParser.CreateCoveringIndexStatementContext ctx) { + indexDetails.setDropIndex(false); + indexDetails.setIndexType(FlintIndexType.COVERING); visitPropertyList(ctx.propertyList()); return super.visitCreateCoveringIndexStatement(ctx); } + @Override + public Void visitDropCoveringIndexStatement( + FlintSparkSqlExtensionsParser.DropCoveringIndexStatementContext ctx) { + indexDetails.setDropIndex(true); + indexDetails.setIndexType(FlintIndexType.COVERING); + return super.visitDropCoveringIndexStatement(ctx); + } + + @Override + public Void visitDropSkippingIndexStatement( + FlintSparkSqlExtensionsParser.DropSkippingIndexStatementContext ctx) { + indexDetails.setDropIndex(true); + indexDetails.setIndexType(FlintIndexType.SKIPPING); + return super.visitDropSkippingIndexStatement(ctx); + } + @Override public Void visitPropertyList(FlintSparkSqlExtensionsParser.PropertyListContext ctx) { if (ctx != null) { diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index ff1b17473a..df897ec7dc 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -31,6 +31,7 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; import org.opensearch.sql.spark.rest.model.LangType; @@ -63,11 +64,11 @@ void testCreateAsyncQuery() { LangType.SQL, "arn:aws:iam::270824043731:role/emr-job-execution-role", TEST_CLUSTER_NAME))) - .thenReturn(EMR_JOB_ID); + .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false)); CreateAsyncQueryResponse createAsyncQueryResponse = jobExecutorService.createAsyncQuery(createAsyncQueryRequest); verify(asyncQueryJobMetadataStorageService, times(1)) - .storeJobMetadata(new AsyncQueryJobMetadata(EMR_JOB_ID, "00fd775baqpu4g0p")); + .storeJobMetadata(new AsyncQueryJobMetadata("00fd775baqpu4g0p", EMR_JOB_ID)); verify(settings, times(1)).getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG); verify(settings, times(1)).getSettingValue(Settings.Key.CLUSTER_NAME); verify(sparkQueryDispatcher, times(1)) @@ -105,10 +106,11 @@ void testGetAsyncQueryResultsWithInProgressJob() { new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) - .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID))); + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID))); JSONObject jobResult = new JSONObject(); jobResult.put("status", JobRunState.PENDING.toString()); - when(sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID)) + when(sparkQueryDispatcher.getQueryResponse( + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID))) .thenReturn(jobResult); AsyncQueryExecutionResponse asyncQueryExecutionResponse = jobExecutorService.getAsyncQueryResults(EMR_JOB_ID); @@ -122,10 +124,11 @@ void testGetAsyncQueryResultsWithInProgressJob() { @Test void testGetAsyncQueryResultsWithSuccessJob() throws IOException { when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) - .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID))); + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID))); JSONObject jobResult = new JSONObject(getJson("select_query_response.json")); jobResult.put("status", JobRunState.SUCCESS.toString()); - when(sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID)) + when(sparkQueryDispatcher.getQueryResponse( + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID))) .thenReturn(jobResult); AsyncQueryExecutorServiceImpl jobExecutorService = @@ -182,8 +185,9 @@ void testCancelJob() { new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) - .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID))); - when(sparkQueryDispatcher.cancelJob(EMRS_APPLICATION_ID, EMR_JOB_ID)).thenReturn(EMR_JOB_ID); + .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID))); + when(sparkQueryDispatcher.cancelJob(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID))) + .thenReturn(EMR_JOB_ID); String jobId = asyncQueryExecutorService.cancelQuery(EMR_JOB_ID); Assertions.assertEquals(EMR_JOB_ID, jobId); verifyNoInteractions(settings); diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java index fe9da12ef0..7097daf13e 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java @@ -183,7 +183,7 @@ public void testGetJobMetadata() { new SearchHits( new SearchHit[] {searchHit}, new TotalHits(21, TotalHits.Relation.EQUAL_TO), 1.0F)); AsyncQueryJobMetadata asyncQueryJobMetadata = - new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID); + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID); Mockito.when(searchHit.getSourceAsString()).thenReturn(asyncQueryJobMetadata.toString()); Optional jobMetadataOptional = diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 2a8c21d342..8dbf60e170 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -26,6 +26,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; +import org.apache.commons.lang3.StringUtils; import org.json.JSONObject; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -36,9 +37,15 @@ import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; +import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; +import org.opensearch.sql.spark.dispatcher.model.IndexDetails; +import org.opensearch.sql.spark.flint.FlintIndexMetadataReader; +import org.opensearch.sql.spark.flint.FlintIndexType; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; @@ -49,6 +56,7 @@ public class SparkQueryDispatcherTest { @Mock private DataSourceService dataSourceService; @Mock private JobExecutionResponseReader jobExecutionResponseReader; @Mock private DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper; + @Mock private FlintIndexMetadataReader flintIndexMetadataReader; @Test void testDispatchSelectQuery() { @@ -57,7 +65,8 @@ void testDispatchSelectQuery() { emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, - jobExecutionResponseReader); + jobExecutionResponseReader, + flintIndexMetadataReader); HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); @@ -81,7 +90,7 @@ void testDispatchSelectQuery() { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); - String jobId = + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( EMRS_APPLICATION_ID, @@ -106,7 +115,9 @@ void testDispatchSelectQuery() { }), tags, false)); - Assertions.assertEquals(EMR_JOB_ID, jobId); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); + verifyNoInteractions(flintIndexMetadataReader); } @Test @@ -116,7 +127,8 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, - jobExecutionResponseReader); + jobExecutionResponseReader, + flintIndexMetadataReader); HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); @@ -141,7 +153,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithBasicAuth(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); - String jobId = + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( EMRS_APPLICATION_ID, @@ -167,7 +179,9 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { }), tags, false)); - Assertions.assertEquals(EMR_JOB_ID, jobId); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); + verifyNoInteractions(flintIndexMetadataReader); } @Test @@ -177,7 +191,8 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, - jobExecutionResponseReader); + jobExecutionResponseReader, + flintIndexMetadataReader); HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); @@ -200,7 +215,7 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithNoAuth(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); - String jobId = + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( EMRS_APPLICATION_ID, @@ -224,7 +239,9 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { }), tags, false)); - Assertions.assertEquals(EMR_JOB_ID, jobId); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); + verifyNoInteractions(flintIndexMetadataReader); } @Test @@ -234,7 +251,8 @@ void testDispatchIndexQuery() { emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, - jobExecutionResponseReader); + jobExecutionResponseReader, + flintIndexMetadataReader); HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("table", "http_logs"); @@ -264,7 +282,7 @@ void testDispatchIndexQuery() { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); - String jobId = + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( EMRS_APPLICATION_ID, @@ -290,7 +308,9 @@ void testDispatchIndexQuery() { })), tags, true)); - Assertions.assertEquals(EMR_JOB_ID, jobId); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); + verifyNoInteractions(flintIndexMetadataReader); } @Test @@ -303,7 +323,8 @@ void testDispatchWithPPLQuery() { emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, - jobExecutionResponseReader); + jobExecutionResponseReader, + flintIndexMetadataReader); String query = "source = my_glue.default.http_logs"; when(emrServerlessClient.startJobRun( new StartJobRequest( @@ -324,7 +345,7 @@ void testDispatchWithPPLQuery() { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); - String jobId = + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( EMRS_APPLICATION_ID, @@ -349,7 +370,9 @@ void testDispatchWithPPLQuery() { }), tags, false)); - Assertions.assertEquals(EMR_JOB_ID, jobId); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); + verifyNoInteractions(flintIndexMetadataReader); } @Test @@ -362,7 +385,8 @@ void testDispatchQueryWithoutATableAndDataSourceName() { emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, - jobExecutionResponseReader); + jobExecutionResponseReader, + flintIndexMetadataReader); String query = "show tables"; when(emrServerlessClient.startJobRun( new StartJobRequest( @@ -383,7 +407,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); - String jobId = + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( EMRS_APPLICATION_ID, @@ -408,7 +432,9 @@ void testDispatchQueryWithoutATableAndDataSourceName() { }), tags, false)); - Assertions.assertEquals(EMR_JOB_ID, jobId); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); + verifyNoInteractions(flintIndexMetadataReader); } @Test @@ -424,7 +450,8 @@ void testDispatchIndexQueryWithoutADatasourceName() { emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, - jobExecutionResponseReader); + jobExecutionResponseReader, + flintIndexMetadataReader); String query = "CREATE INDEX elb_and_requestUri ON default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; @@ -448,7 +475,7 @@ void testDispatchIndexQueryWithoutADatasourceName() { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); - String jobId = + DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( EMRS_APPLICATION_ID, @@ -474,7 +501,9 @@ void testDispatchIndexQueryWithoutADatasourceName() { })), tags, true)); - Assertions.assertEquals(EMR_JOB_ID, jobId); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); + verifyNoInteractions(flintIndexMetadataReader); } @Test @@ -484,7 +513,8 @@ void testDispatchWithWrongURI() { emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, - jobExecutionResponseReader); + jobExecutionResponseReader, + flintIndexMetadataReader); when(dataSourceService.getRawDataSourceMetadata("my_glue")) .thenReturn(constructMyGlueDataSourceMetadataWithBadURISyntax()); String query = "select * from my_glue.default.http_logs"; @@ -512,7 +542,8 @@ void testDispatchWithUnSupportedDataSourceType() { emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, - jobExecutionResponseReader); + jobExecutionResponseReader, + flintIndexMetadataReader); when(dataSourceService.getRawDataSourceMetadata("my_prometheus")) .thenReturn(constructPrometheusDataSourceType()); String query = "select * from my_prometheus.default.http_logs"; @@ -540,13 +571,15 @@ void testCancelJob() { emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, - jobExecutionResponseReader); + jobExecutionResponseReader, + flintIndexMetadataReader); when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn( new CancelJobRunResult() .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); - String jobId = sparkQueryDispatcher.cancelJob(EMRS_APPLICATION_ID, EMR_JOB_ID); + String jobId = + sparkQueryDispatcher.cancelJob(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID)); Assertions.assertEquals(EMR_JOB_ID, jobId); } @@ -557,10 +590,13 @@ void testGetQueryResponse() { emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, - jobExecutionResponseReader); + jobExecutionResponseReader, + flintIndexMetadataReader); when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING))); - JSONObject result = sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID); + JSONObject result = + sparkQueryDispatcher.getQueryResponse( + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID)); Assertions.assertEquals("PENDING", result.get("status")); verifyNoInteractions(jobExecutionResponseReader); } @@ -572,14 +608,17 @@ void testGetQueryResponseWithSuccess() { emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, - jobExecutionResponseReader); + jobExecutionResponseReader, + flintIndexMetadataReader); when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.SUCCESS))); JSONObject queryResult = new JSONObject(); queryResult.put("data", "result"); when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID)) .thenReturn(queryResult); - JSONObject result = sparkQueryDispatcher.getQueryResponse(EMRS_APPLICATION_ID, EMR_JOB_ID); + JSONObject result = + sparkQueryDispatcher.getQueryResponse( + new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID)); verify(emrServerlessClient, times(1)).getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID); verify(jobExecutionResponseReader, times(1)).getResultFromOpensearchIndex(EMR_JOB_ID); Assertions.assertEquals(new HashSet<>(Arrays.asList("data", "status")), result.keySet()); @@ -587,6 +626,108 @@ void testGetQueryResponseWithSuccess() { Assertions.assertEquals("SUCCESS", result.get("status")); } + @Test + void testDropIndexQuery() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + emrServerlessClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader, + flintIndexMetadataReader); + String query = "DROP INDEX size_year ON my_glue.default.http_logs"; + when(flintIndexMetadataReader.getJobIdFromFlintIndexMetadata( + new IndexDetails( + "size_year", + new FullyQualifiedTableName("my_glue.default.http_logs"), + false, + true, + FlintIndexType.COVERING))) + .thenReturn(EMR_JOB_ID); + when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn( + new CancelJobRunResult() + .withJobRunId(EMR_JOB_ID) + .withApplicationId(EMRS_APPLICATION_ID)); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); + verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); + verify(flintIndexMetadataReader, times(1)) + .getJobIdFromFlintIndexMetadata( + new IndexDetails( + "size_year", + new FullyQualifiedTableName("my_glue.default.http_logs"), + false, + true, + FlintIndexType.COVERING)); + Assertions.assertNotEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertTrue(StringUtils.isAlphanumeric(dispatchQueryResponse.getJobId())); + Assertions.assertEquals(16, dispatchQueryResponse.getJobId().length()); + Assertions.assertTrue(dispatchQueryResponse.isDropIndexQuery()); + } + + @Test + void testDropSkippingIndexQuery() { + SparkQueryDispatcher sparkQueryDispatcher = + new SparkQueryDispatcher( + emrServerlessClient, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader, + flintIndexMetadataReader); + String query = "DROP SKIPPING INDEX ON my_glue.default.http_logs"; + when(flintIndexMetadataReader.getJobIdFromFlintIndexMetadata( + new IndexDetails( + null, + new FullyQualifiedTableName("my_glue.default.http_logs"), + false, + true, + FlintIndexType.SKIPPING))) + .thenReturn(EMR_JOB_ID); + when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) + .thenReturn( + new CancelJobRunResult() + .withJobRunId(EMR_JOB_ID) + .withApplicationId(EMRS_APPLICATION_ID)); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch( + new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + LangType.SQL, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME)); + verify(emrServerlessClient, times(1)).cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID); + verify(dataSourceUserAuthorizationHelper, times(1)).authorizeDataSource(dataSourceMetadata); + verify(flintIndexMetadataReader, times(1)) + .getJobIdFromFlintIndexMetadata( + new IndexDetails( + null, + new FullyQualifiedTableName("my_glue.default.http_logs"), + false, + true, + FlintIndexType.SKIPPING)); + Assertions.assertNotEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertTrue(StringUtils.isAlphanumeric(dispatchQueryResponse.getJobId())); + Assertions.assertEquals(16, dispatchQueryResponse.getJobId().length()); + Assertions.assertTrue(dispatchQueryResponse.isDropIndexQuery()); + } + private String constructExpectedSparkSubmitParameterString( String auth, Map authParams) { StringBuilder authParamConfigBuilder = new StringBuilder(); diff --git a/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java new file mode 100644 index 0000000000..5d89a4e4c6 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/flint/FlintIndexMetadataReaderImplTest.java @@ -0,0 +1,71 @@ +package org.opensearch.sql.spark.flint; + +import static org.mockito.Answers.RETURNS_DEEP_STUBS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.common.base.Charsets; +import com.google.common.io.Resources; +import java.io.IOException; +import java.net.URL; +import java.util.Map; +import lombok.SneakyThrows; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; +import org.opensearch.sql.spark.dispatcher.model.IndexDetails; + +@ExtendWith(MockitoExtension.class) +public class FlintIndexMetadataReaderImplTest { + @Mock(answer = RETURNS_DEEP_STUBS) + private Client client; + + // TODO FIX this + @SneakyThrows + // @Test + void testGetJobIdFromFlintIndexMetadata() { + URL url = + Resources.getResource( + "flint-index-mappings/flint_my_glue_default_http_logs_size_year_covering_index.json"); + String mappings = Resources.toString(url, Charsets.UTF_8); + String indexName = "flint_my_glue_default_http_logs_size_year_covering_index"; + mockNodeClientIndicesMappings(indexName, mappings); + FlintIndexMetadataReader flintIndexMetadataReader = new FlintIndexMetadataReaderImpl(client); + String jobId = + flintIndexMetadataReader.getJobIdFromFlintIndexMetadata( + new IndexDetails( + "size_year", + new FullyQualifiedTableName("my_glue.default.http_logs"), + false, + true, + FlintIndexType.COVERING)); + Assertions.assertEquals("00fdlum58g9g1g0q", jobId); + } + + @SneakyThrows + public void mockNodeClientIndicesMappings(String indexName, String mappings) { + GetMappingsResponse mockResponse = mock(GetMappingsResponse.class); + when(client.admin().indices().prepareGetMappings(any()).get()).thenReturn(mockResponse); + Map metadata; + metadata = Map.of(indexName, IndexMetadata.fromXContent(createParser(mappings)).mapping()); + when(mockResponse.mappings()).thenReturn(metadata); + } + + private XContentParser createParser(String mappings) throws IOException { + return XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, mappings); + } +} diff --git a/spark/src/test/resources/flint-index-mappings/flint_my_glue_default_http_logs_size_year_covering_index.json b/spark/src/test/resources/flint-index-mappings/flint_my_glue_default_http_logs_size_year_covering_index.json new file mode 100644 index 0000000000..201aa539bb --- /dev/null +++ b/spark/src/test/resources/flint-index-mappings/flint_my_glue_default_http_logs_size_year_covering_index.json @@ -0,0 +1,32 @@ +{ + "mappings": { + "_meta": { + "kind": "skipping", + "indexedColumns": [ + { + "columnType": "int", + "kind": "VALUE_SET", + "columnName": "status" + } + ], + "name": "flint_mys3_default_http_logs_skipping_index", + "options": {}, + "source": "mys3.default.http_logs", + "version": "0.1.0", + "properties": { + "env": { + "SERVERLESS_EMR_VIRTUAL_CLUSTER_ID": "00fd777k3k3ls20p", + "SERVERLESS_EMR_JOB_ID": "00fdmvv9hp8u0o0q" + } + } + }, + "properties": { + "file_path": { + "type": "keyword" + }, + "status": { + "type": "integer" + } + } + } +} \ No newline at end of file