diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 3366e21894..0e871f9ddc 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.dispatcher; import java.util.HashMap; +import java.util.List; import java.util.Map; import lombok.AllArgsConstructor; import org.jetbrains.annotations.NotNull; @@ -45,25 +46,50 @@ public DispatchQueryResponse dispatch( this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata( dispatchQueryRequest.getDatasource()); - if (LangType.SQL.equals(dispatchQueryRequest.getLangType()) - && SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) { - IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest); - DispatchQueryContext context = - getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) - .indexQueryDetails(indexQueryDetails) - .asyncQueryRequestContext(asyncQueryRequestContext) - .build(); - - return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails) - .submit(dispatchQueryRequest, context); - } else { - DispatchQueryContext context = - getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) - .asyncQueryRequestContext(asyncQueryRequestContext) - .build(); - return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId()) - .submit(dispatchQueryRequest, context); + if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) { + String query = dispatchQueryRequest.getQuery(); + + if (SQLQueryUtils.isFlintExtensionQuery(query)) { + return handleFlintExtensionQuery( + dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata); + } + + List validationErrors = SQLQueryUtils.validateSparkSqlQuery(query); + if (!validationErrors.isEmpty()) { + throw new IllegalArgumentException( + "Query is not allowed: " + String.join(", ", validationErrors)); + } } + return handleDefaultQuery(dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata); + } + + private DispatchQueryResponse handleFlintExtensionQuery( + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext asyncQueryRequestContext, + DataSourceMetadata dataSourceMetadata) { + IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest); + DispatchQueryContext context = + getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) + .indexQueryDetails(indexQueryDetails) + .asyncQueryRequestContext(asyncQueryRequestContext) + .build(); + + return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails) + .submit(dispatchQueryRequest, context); + } + + private DispatchQueryResponse handleDefaultQuery( + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext asyncQueryRequestContext, + DataSourceMetadata dataSourceMetadata) { + + DispatchQueryContext context = + getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) + .asyncQueryRequestContext(asyncQueryRequestContext) + .build(); + + return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId()) + .submit(dispatchQueryRequest, context); } private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder( diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java index a96e203cea..0bb9cb4b85 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java @@ -5,6 +5,8 @@ package org.opensearch.sql.spark.utils; +import java.util.ArrayList; +import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.Locale; @@ -73,6 +75,32 @@ public static boolean isFlintExtensionQuery(String sqlQuery) { } } + public static List validateSparkSqlQuery(String sqlQuery) { + SparkSqlValidatorVisitor sparkSqlValidatorVisitor = new SparkSqlValidatorVisitor(); + SqlBaseParser sqlBaseParser = + new SqlBaseParser( + new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery)))); + sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener()); + try { + SqlBaseParser.StatementContext statement = sqlBaseParser.statement(); + sparkSqlValidatorVisitor.visit(statement); + return sparkSqlValidatorVisitor.getValidationErrors(); + } catch (SyntaxCheckException syntaxCheckException) { + return Collections.emptyList(); + } + } + + private static class SparkSqlValidatorVisitor extends SqlBaseParserBaseVisitor { + + @Getter private final List validationErrors = new ArrayList<>(); + + @Override + public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) { + validationErrors.add("Creating user-defined functions is not allowed"); + return super.visitCreateFunction(ctx); + } + } + public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor { @Getter private List fullyQualifiedTableNames = new LinkedList<>(); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 5582de332c..ebec92f484 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -41,9 +41,11 @@ import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRun; import com.amazonaws.services.emrserverless.model.JobRunState; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; @@ -438,6 +440,85 @@ void testDispatchWithPPLQuery() { verifyNoInteractions(flintIndexMetadataService); } + @Test + void testDispatchWithSparkUDFQuery() { + List udfQueries = new ArrayList<>(); + udfQueries.add( + "CREATE FUNCTION celsius_to_fahrenheit AS 'org.apache.spark.sql.functions.expr(\"(celsius *" + + " 9/5) + 32\")'"); + udfQueries.add( + "CREATE TEMPORARY FUNCTION square AS 'org.apache.spark.sql.functions.expr(\"num * num\")'"); + for (String query : udfQueries) { + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + .thenReturn(dataSourceMetadata); + + IllegalArgumentException illegalArgumentException = + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + sparkQueryDispatcher.dispatch( + getBaseDispatchQueryRequestBuilder(query).langType(LangType.SQL).build(), + asyncQueryRequestContext)); + Assertions.assertEquals( + "Query is not allowed: Creating user-defined functions is not allowed", + illegalArgumentException.getMessage()); + verifyNoInteractions(emrServerlessClient); + verifyNoInteractions(flintIndexMetadataService); + } + } + + @Test + void testInvalidSQLQueryDispatchToSpark() { + when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); + HashMap tags = new HashMap<>(); + tags.put(DATASOURCE_TAG_KEY, MY_GLUE); + tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); + tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); + String query = "myselect 1"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }, + query); + StartJobRequest expected = + new StartJobRequest( + "TEST_CLUSTER:batch", + null, + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + "query_execution_result_my_glue"); + when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + .thenReturn(dataSourceMetadata); + + DispatchQueryResponse dispatchQueryResponse = + sparkQueryDispatcher.dispatch( + DispatchQueryRequest.builder() + .applicationId(EMRS_APPLICATION_ID) + .query(query) + .datasource(MY_GLUE) + .langType(LangType.SQL) + .executionRoleARN(EMRS_EXECUTION_ROLE) + .clusterName(TEST_CLUSTER_NAME) + .sparkSubmitParameterModifier(sparkSubmitParameterModifier) + .build(), + asyncQueryRequestContext); + + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); + Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + verifyNoInteractions(flintIndexMetadataService); + } + @Test void testDispatchQueryWithoutATableAndDataSourceName() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst index af49a59838..9b889f7f97 100644 --- a/docs/user/interfaces/asyncqueryinterface.rst +++ b/docs/user/interfaces/asyncqueryinterface.rst @@ -68,6 +68,8 @@ Async Query Creation API ====================================== If security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/async_query/create``. +Limitation: Spark SQL queries that create User-Defined Functions (UDFs) are not allowed. + HTTP URI: ``_plugins/_async_query`` HTTP VERB: ``POST``