From ddbaeaeccc1fe79c66095dbac0fcdc2f415e4303 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 5 Oct 2023 23:58:28 +0000 Subject: [PATCH] Read extra Spark submit parameters from cluster settings (#2219) * Add default setting for Spark execution engine Signed-off-by: Chen Dai * Pass extra parameters to Spark dispatcher Signed-off-by: Chen Dai * Wrap read default setting file with previlege action Signed-off-by: Chen Dai * Fix spotless format Signed-off-by: Chen Dai * Use input stream to read default config file Signed-off-by: Chen Dai * Add UT for dispatcher Signed-off-by: Chen Dai * Add more UT Signed-off-by: Chen Dai * Remove default config setting Signed-off-by: Chen Dai * Fix spotless check in spark module Signed-off-by: Chen Dai * Refactor test code Signed-off-by: Chen Dai * Add more UT on config class Signed-off-by: Chen Dai --------- Signed-off-by: Chen Dai (cherry picked from commit 492982c461bbeecdf5169912aa2795a71140212b) Signed-off-by: github-actions[bot] --- .../setting/OpenSearchSettingsTest.java | 52 +++++-- .../AsyncQueryExecutorServiceImpl.java | 4 +- .../model/SparkSubmitParameters.java | 17 +- .../config/SparkExecutionEngineConfig.java | 3 + .../dispatcher/SparkQueryDispatcher.java | 2 + .../model/DispatchQueryRequest.java | 7 + .../AsyncQueryExecutorServiceImplTest.java | 63 +++++--- .../model/SparkSubmitParametersTest.java | 30 ++++ .../SparkExecutionEngineConfigTest.java | 49 ++++++ .../dispatcher/SparkQueryDispatcherTest.java | 147 +++++++----------- 10 files changed, 251 insertions(+), 123 deletions(-) create mode 100644 spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigTest.java diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java index ff2c311753..e99e5b360a 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/setting/OpenSearchSettingsTest.java @@ -15,6 +15,15 @@ import static org.mockito.Mockito.when; import static org.opensearch.common.unit.TimeValue.timeValueMinutes; import static org.opensearch.sql.opensearch.setting.LegacyOpenDistroSettings.legacySettings; +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.METRICS_ROLLING_INTERVAL_SETTING; +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.METRICS_ROLLING_WINDOW_SETTING; +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.PPL_ENABLED_SETTING; +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.QUERY_MEMORY_LIMIT_SETTING; +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.QUERY_SIZE_LIMIT_SETTING; +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SPARK_EXECUTION_ENGINE_CONFIG; +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SQL_CURSOR_KEEP_ALIVE_SETTING; +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SQL_ENABLED_SETTING; +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SQL_SLOWLOG_SETTING; import java.util.List; import org.junit.jupiter.api.Test; @@ -47,14 +56,13 @@ void getSettingValue() { @Test void getSettingValueWithPresetValuesInYml() { when(clusterSettings.get(ClusterName.CLUSTER_NAME_SETTING)).thenReturn(ClusterName.DEFAULT); - when(clusterSettings.get( - (Setting) OpenSearchSettings.QUERY_MEMORY_LIMIT_SETTING)) + when(clusterSettings.get((Setting) QUERY_MEMORY_LIMIT_SETTING)) .thenReturn(new ByteSizeValue(20)); when(clusterSettings.get( not( or( eq(ClusterName.CLUSTER_NAME_SETTING), - eq((Setting) OpenSearchSettings.QUERY_MEMORY_LIMIT_SETTING))))) + eq((Setting) QUERY_MEMORY_LIMIT_SETTING))))) .thenReturn(null); OpenSearchSettings settings = new OpenSearchSettings(clusterSettings); ByteSizeValue sizeValue = settings.getSettingValue(Settings.Key.QUERY_MEMORY_LIMIT); @@ -150,21 +158,41 @@ public void updateLegacySettingsFallback() { .put(LegacySettings.Key.METRICS_ROLLING_INTERVAL.getKeyValue(), 100L) .build(); - assertEquals(OpenSearchSettings.SQL_ENABLED_SETTING.get(settings), false); - assertEquals(OpenSearchSettings.SQL_SLOWLOG_SETTING.get(settings), 10); - assertEquals( - OpenSearchSettings.SQL_CURSOR_KEEP_ALIVE_SETTING.get(settings), timeValueMinutes(1)); - assertEquals(OpenSearchSettings.PPL_ENABLED_SETTING.get(settings), true); + assertEquals(SQL_ENABLED_SETTING.get(settings), false); + assertEquals(SQL_SLOWLOG_SETTING.get(settings), 10); + assertEquals(SQL_CURSOR_KEEP_ALIVE_SETTING.get(settings), timeValueMinutes(1)); + assertEquals(PPL_ENABLED_SETTING.get(settings), true); assertEquals( - OpenSearchSettings.QUERY_MEMORY_LIMIT_SETTING.get(settings), + QUERY_MEMORY_LIMIT_SETTING.get(settings), new ByteSizeValue((int) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.2))); - assertEquals(OpenSearchSettings.QUERY_SIZE_LIMIT_SETTING.get(settings), 100); - assertEquals(OpenSearchSettings.METRICS_ROLLING_WINDOW_SETTING.get(settings), 2000L); - assertEquals(OpenSearchSettings.METRICS_ROLLING_INTERVAL_SETTING.get(settings), 100L); + assertEquals(QUERY_SIZE_LIMIT_SETTING.get(settings), 100); + assertEquals(METRICS_ROLLING_WINDOW_SETTING.get(settings), 2000L); + assertEquals(METRICS_ROLLING_INTERVAL_SETTING.get(settings), 100L); } @Test void legacySettingsShouldBeDeprecatedBeforeRemove() { assertEquals(15, legacySettings().size()); } + + @Test + void getSparkExecutionEngineConfigSetting() { + // Default is empty string + assertEquals( + "", + SPARK_EXECUTION_ENGINE_CONFIG.get( + org.opensearch.common.settings.Settings.builder().build())); + + // Configurable at runtime + String sparkConfig = + "{\n" + + " \"sparkSubmitParameters\": \"--conf spark.dynamicAllocation.enabled=false\"\n" + + "}"; + assertEquals( + sparkConfig, + SPARK_EXECUTION_ENGINE_CONFIG.get( + org.opensearch.common.settings.Settings.builder() + .put(SPARK_EXECUTION_ENGINE_CONFIG.getKey(), sparkConfig) + .build())); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index bbb5abdb28..55346bc289 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -75,8 +75,8 @@ public CreateAsyncQueryResponse createAsyncQuery( createAsyncQueryRequest.getDatasource(), createAsyncQueryRequest.getLang(), sparkExecutionEngineConfig.getExecutionRoleARN(), - clusterName.value())); - + clusterName.value(), + sparkExecutionEngineConfig.getSparkSubmitParameters())); asyncQueryJobMetadataStorageService.storeJobMetadata( new AsyncQueryJobMetadata( sparkExecutionEngineConfig.getApplicationId(), diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java index 1f7bf4b9fb..0609d8903c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java @@ -18,12 +18,14 @@ import java.util.LinkedHashMap; import java.util.Map; import java.util.function.Supplier; +import lombok.AllArgsConstructor; import lombok.RequiredArgsConstructor; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.datasources.auth.AuthenticationType; /** Define Spark Submit Parameters. */ +@AllArgsConstructor @RequiredArgsConstructor public class SparkSubmitParameters { public static final String SPACE = " "; @@ -32,10 +34,14 @@ public class SparkSubmitParameters { private final String className; private final Map config; + /** Extra parameters to append finally */ + private String extraParameters; + public static class Builder { private final String className; private final Map config; + private String extraParameters; private Builder() { className = DEFAULT_CLASS_NAME; @@ -130,8 +136,13 @@ public Builder structuredStreaming(Boolean isStructuredStreaming) { return this; } + public Builder extraParameters(String params) { + extraParameters = params; + return this; + } + public SparkSubmitParameters build() { - return new SparkSubmitParameters(className, config); + return new SparkSubmitParameters(className, config, extraParameters); } } @@ -148,6 +159,10 @@ public String toString() { stringBuilder.append(config.get(key)); stringBuilder.append(SPACE); } + + if (extraParameters != null) { + stringBuilder.append(extraParameters); + } return stringBuilder.toString(); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java index 4f928c4f1f..23e5907b5c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java +++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java @@ -16,6 +16,9 @@ public class SparkExecutionEngineConfig { private String region; private String executionRoleARN; + /** Additional Spark submit parameters to append to request. */ + private String sparkSubmitParameters; + public static SparkExecutionEngineConfig toSparkExecutionEngineConfig(String jsonString) { return new Gson().fromJson(jsonString, SparkExecutionEngineConfig.class); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index dcce11fd55..aca4c86e0e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -145,6 +145,7 @@ private DispatchQueryResponse handleIndexQuery( dataSourceService.getRawDataSourceMetadata( dispatchQueryRequest.getDatasource())) .structuredStreaming(indexDetails.getAutoRefresh()) + .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) .build() .toString(), tags, @@ -170,6 +171,7 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ .dataSource( dataSourceService.getRawDataSourceMetadata( dispatchQueryRequest.getDatasource())) + .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) .build() .toString(), tags, diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java index 09240278ee..823a4570ce 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryRequest.java @@ -5,10 +5,14 @@ package org.opensearch.sql.spark.dispatcher.model; +import lombok.AllArgsConstructor; import lombok.Data; +import lombok.RequiredArgsConstructor; import org.opensearch.sql.spark.rest.model.LangType; +@AllArgsConstructor @Data +@RequiredArgsConstructor // required explicitly public class DispatchQueryRequest { private final String applicationId; private final String query; @@ -16,4 +20,7 @@ public class DispatchQueryRequest { private final LangType langType; private final String executionRoleARN; private final String clusterName; + + /** Optional extra Spark submit parameters to include in final request */ + private String extraSparkSubmitParams; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index a053f30f3b..e16dd89639 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -5,6 +5,8 @@ package org.opensearch.sql.spark.asyncquery; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -20,6 +22,7 @@ import java.util.Optional; import org.json.JSONObject; import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -43,11 +46,17 @@ public class AsyncQueryExecutorServiceImplTest { @Mock private AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService; @Mock private Settings settings; - @Test - void testCreateAsyncQuery() { - AsyncQueryExecutorServiceImpl jobExecutorService = + private AsyncQueryExecutorService jobExecutorService; + + @BeforeEach + void setUp() { + jobExecutorService = new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); + } + + @Test + void testCreateAsyncQuery() { CreateAsyncQueryRequest createAsyncQueryRequest = new CreateAsyncQueryRequest( "select * from my_glue.default.http_logs", "my_glue", LangType.SQL); @@ -83,11 +92,36 @@ void testCreateAsyncQuery() { Assertions.assertEquals(EMR_JOB_ID, createAsyncQueryResponse.getQueryId()); } + @Test + void testCreateAsyncQueryWithExtraSparkSubmitParameter() { + when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG)) + .thenReturn( + "{" + + "\"applicationId\": \"00fd775baqpu4g0p\"," + + "\"executionRoleARN\": \"arn:aws:iam::270824043731:role/emr-job-execution-role\"," + + "\"region\": \"eu-west-1\"," + + "\"sparkSubmitParameters\": \"--conf spark.dynamicAllocation.enabled=false\"" + + "}"); + when(settings.getSettingValue(Settings.Key.CLUSTER_NAME)) + .thenReturn(new ClusterName(TEST_CLUSTER_NAME)); + when(sparkQueryDispatcher.dispatch(any())) + .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null)); + + jobExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "select * from my_glue.default.http_logs", "my_glue", LangType.SQL)); + + verify(sparkQueryDispatcher, times(1)) + .dispatch( + argThat( + actualReq -> + actualReq + .getExtraSparkSubmitParams() + .equals("--conf spark.dynamicAllocation.enabled=false"))); + } + @Test void testGetAsyncQueryResultsWithJobNotFoundException() { - AsyncQueryExecutorServiceImpl jobExecutorService = - new AsyncQueryExecutorServiceImpl( - asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) .thenReturn(Optional.empty()); AsyncQueryNotFoundException asyncQueryNotFoundException = @@ -102,9 +136,6 @@ void testGetAsyncQueryResultsWithJobNotFoundException() { @Test void testGetAsyncQueryResultsWithInProgressJob() { - AsyncQueryExecutorServiceImpl jobExecutorService = - new AsyncQueryExecutorServiceImpl( - asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))); JSONObject jobResult = new JSONObject(); @@ -131,9 +162,6 @@ void testGetAsyncQueryResultsWithSuccessJob() throws IOException { new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))) .thenReturn(jobResult); - AsyncQueryExecutorServiceImpl jobExecutorService = - new AsyncQueryExecutorServiceImpl( - asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); AsyncQueryExecutionResponse asyncQueryExecutionResponse = jobExecutorService.getAsyncQueryResults(EMR_JOB_ID); @@ -164,15 +192,11 @@ void testGetAsyncQueryResultsWithDisabledExecutionEngine() { @Test void testCancelJobWithJobNotFound() { - AsyncQueryExecutorService asyncQueryExecutorService = - new AsyncQueryExecutorServiceImpl( - asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) .thenReturn(Optional.empty()); AsyncQueryNotFoundException asyncQueryNotFoundException = Assertions.assertThrows( - AsyncQueryNotFoundException.class, - () -> asyncQueryExecutorService.cancelQuery(EMR_JOB_ID)); + AsyncQueryNotFoundException.class, () -> jobExecutorService.cancelQuery(EMR_JOB_ID)); Assertions.assertEquals( "QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); verifyNoInteractions(sparkQueryDispatcher); @@ -181,15 +205,12 @@ void testCancelJobWithJobNotFound() { @Test void testCancelJob() { - AsyncQueryExecutorService asyncQueryExecutorService = - new AsyncQueryExecutorServiceImpl( - asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))); when(sparkQueryDispatcher.cancelJob( new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null))) .thenReturn(EMR_JOB_ID); - String jobId = asyncQueryExecutorService.cancelQuery(EMR_JOB_ID); + String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID); Assertions.assertEquals(EMR_JOB_ID, jobId); verifyNoInteractions(settings); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java new file mode 100644 index 0000000000..a914a975b9 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParametersTest.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.asyncquery.model; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + +public class SparkSubmitParametersTest { + + @Test + public void testBuildWithoutExtraParameters() { + String params = SparkSubmitParameters.Builder.builder().build().toString(); + + assertNotNull(params); + } + + @Test + public void testBuildWithExtraParameters() { + String params = + SparkSubmitParameters.Builder.builder().extraParameters("--conf A=1").build().toString(); + + // Assert the conf is included with a space + assertTrue(params.endsWith(" --conf A=1")); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigTest.java b/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigTest.java new file mode 100644 index 0000000000..29b69ea830 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigTest.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.config; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +import org.junit.jupiter.api.Test; + +public class SparkExecutionEngineConfigTest { + + @Test + public void testToSparkExecutionEngineConfigWithoutAllFields() { + String json = + "{" + + "\"applicationId\": \"app-1\"," + + "\"executionRoleARN\": \"role-1\"," + + "\"region\": \"us-west-1\"" + + "}"; + SparkExecutionEngineConfig config = + SparkExecutionEngineConfig.toSparkExecutionEngineConfig(json); + + assertEquals("app-1", config.getApplicationId()); + assertEquals("role-1", config.getExecutionRoleARN()); + assertEquals("us-west-1", config.getRegion()); + assertNull(config.getSparkSubmitParameters()); + } + + @Test + public void testToSparkExecutionEngineConfigWithAllFields() { + String json = + "{" + + "\"applicationId\": \"app-1\"," + + "\"executionRoleARN\": \"role-1\"," + + "\"region\": \"us-west-1\"," + + "\"sparkSubmitParameters\": \"--conf A=1\"" + + "}"; + SparkExecutionEngineConfig config = + SparkExecutionEngineConfig.toSparkExecutionEngineConfig(json); + + assertEquals("app-1", config.getApplicationId()); + assertEquals("role-1", config.getExecutionRoleARN()); + assertEquals("us-west-1", config.getRegion()); + assertEquals("--conf A=1", config.getSparkSubmitParameters()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 925e6f1a90..7d97cc6c50 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -6,7 +6,9 @@ package org.opensearch.sql.spark.dispatcher; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -33,6 +35,7 @@ import org.apache.commons.lang3.StringUtils; import org.json.JSONObject; import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -62,15 +65,21 @@ public class SparkQueryDispatcherTest { @Mock private DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper; @Mock private FlintIndexMetadataReader flintIndexMetadataReader; - @Test - void testDispatchSelectQuery() { - SparkQueryDispatcher sparkQueryDispatcher = + private SparkQueryDispatcher sparkQueryDispatcher; + + @BeforeEach + void setUp() { + sparkQueryDispatcher = new SparkQueryDispatcher( emrServerlessClient, dataSourceService, dataSourceUserAuthorizationHelper, jobExecutionResponseReader, flintIndexMetadataReader); + } + + @Test + void testDispatchSelectQuery() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); @@ -128,13 +137,6 @@ void testDispatchSelectQuery() { @Test void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, - jobExecutionResponseReader, - flintIndexMetadataReader); HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); @@ -194,13 +196,6 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { @Test void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, - jobExecutionResponseReader, - flintIndexMetadataReader); HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); @@ -256,13 +251,6 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { @Test void testDispatchIndexQuery() { - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, - jobExecutionResponseReader, - flintIndexMetadataReader); HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("table", "http_logs"); @@ -330,13 +318,7 @@ void testDispatchWithPPLQuery() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, - jobExecutionResponseReader, - flintIndexMetadataReader); + String query = "source = my_glue.default.http_logs"; when(emrServerlessClient.startJobRun( new StartJobRequest( @@ -394,13 +376,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() { HashMap tags = new HashMap<>(); tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, - jobExecutionResponseReader, - flintIndexMetadataReader); + String query = "show tables"; when(emrServerlessClient.startJobRun( new StartJobRequest( @@ -461,13 +437,7 @@ void testDispatchIndexQueryWithoutADatasourceName() { tags.put("index", "elb_and_requestUri"); tags.put("cluster", TEST_CLUSTER_NAME); tags.put("schema", "default"); - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, - jobExecutionResponseReader, - flintIndexMetadataReader); + String query = "CREATE INDEX elb_and_requestUri ON default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; @@ -526,13 +496,6 @@ void testDispatchIndexQueryWithoutADatasourceName() { @Test void testDispatchWithWrongURI() { - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, - jobExecutionResponseReader, - flintIndexMetadataReader); when(dataSourceService.getRawDataSourceMetadata("my_glue")) .thenReturn(constructMyGlueDataSourceMetadataWithBadURISyntax()); String query = "select * from my_glue.default.http_logs"; @@ -555,13 +518,6 @@ void testDispatchWithWrongURI() { @Test void testDispatchWithUnSupportedDataSourceType() { - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, - jobExecutionResponseReader, - flintIndexMetadataReader); when(dataSourceService.getRawDataSourceMetadata("my_prometheus")) .thenReturn(constructPrometheusDataSourceType()); String query = "select * from my_prometheus.default.http_logs"; @@ -584,13 +540,6 @@ void testDispatchWithUnSupportedDataSourceType() { @Test void testCancelJob() { - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, - jobExecutionResponseReader, - flintIndexMetadataReader); when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn( new CancelJobRunResult() @@ -604,13 +553,6 @@ void testCancelJob() { @Test void testGetQueryResponse() { - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, - jobExecutionResponseReader, - flintIndexMetadataReader); when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID)) .thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING))); @@ -660,13 +602,6 @@ void testGetQueryResponseWithSuccess() { @Test void testDropIndexQuery() { - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, - jobExecutionResponseReader, - flintIndexMetadataReader); String query = "DROP INDEX size_year ON my_glue.default.http_logs"; when(flintIndexMetadataReader.getJobIdFromFlintIndexMetadata( new IndexDetails( @@ -711,13 +646,6 @@ void testDropIndexQuery() { @Test void testDropSkippingIndexQuery() { - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, - jobExecutionResponseReader, - flintIndexMetadataReader); String query = "DROP SKIPPING INDEX ON my_glue.default.http_logs"; when(flintIndexMetadataReader.getJobIdFromFlintIndexMetadata( new IndexDetails( @@ -760,6 +688,39 @@ void testDropSkippingIndexQuery() { Assertions.assertTrue(dispatchQueryResponse.isDropIndexQuery()); } + @Test + void testDispatchQueryWithExtraSparkSubmitParameters() { + DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); + when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata); + doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata); + + String extraParameters = "--conf spark.dynamicAllocation.enabled=false"; + DispatchQueryRequest[] requests = { + // SQL direct query + constructDispatchQueryRequest( + "select * from my_glue.default.http_logs", LangType.SQL, extraParameters), + // SQL index query + constructDispatchQueryRequest( + "create skipping index on my_glue.default.http_logs (status VALUE_SET)", + LangType.SQL, + extraParameters), + // PPL query + constructDispatchQueryRequest( + "source = my_glue.default.http_logs", LangType.PPL, extraParameters) + }; + + for (DispatchQueryRequest request : requests) { + when(emrServerlessClient.startJobRun(any())).thenReturn(EMR_JOB_ID); + sparkQueryDispatcher.dispatch(request); + + verify(emrServerlessClient, times(1)) + .startJobRun( + argThat( + actualReq -> actualReq.getSparkSubmitParams().endsWith(" " + extraParameters))); + reset(emrServerlessClient); + } + } + private String constructExpectedSparkSubmitParameterString( String auth, Map authParams) { StringBuilder authParamConfigBuilder = new StringBuilder(); @@ -881,4 +842,16 @@ private DataSourceMetadata constructPrometheusDataSourceType() { dataSourceMetadata.setProperties(properties); return dataSourceMetadata; } + + private DispatchQueryRequest constructDispatchQueryRequest( + String query, LangType langType, String extraParameters) { + return new DispatchQueryRequest( + EMRS_APPLICATION_ID, + query, + "my_glue", + langType, + EMRS_EXECUTION_ROLE, + TEST_CLUSTER_NAME, + extraParameters); + } }