diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java index d6c1907f84..25e8006d66 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java @@ -34,6 +34,8 @@ public class DataSourceServiceImpl implements DataSourceService { private static String DATASOURCE_NAME_REGEX = "[@*A-Za-z]+?[*a-zA-Z_\\-0-9]*"; + public static final Set CONFIDENTIAL_AUTH_KEYS = + Set.of("auth.username", "auth.password", "auth.access_key", "auth.secret_key"); private final DataSourceLoaderCache dataSourceLoaderCache; @@ -159,7 +161,12 @@ private void removeAuthInfo(Set dataSourceMetadataSet) { private void removeAuthInfo(DataSourceMetadata dataSourceMetadata) { HashMap safeProperties = new HashMap<>(dataSourceMetadata.getProperties()); - safeProperties.entrySet().removeIf(entry -> entry.getKey().contains("auth")); + safeProperties + .entrySet() + .removeIf( + entry -> + CONFIDENTIAL_AUTH_KEYS.stream() + .anyMatch(confidentialKey -> entry.getKey().endsWith(confidentialKey))); dataSourceMetadata.setProperties(safeProperties); } } diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java index c8312e6013..6164d8b73f 100644 --- a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java @@ -233,7 +233,7 @@ void testGetDataSourceMetadataSet() { assertEquals(1, dataSourceMetadataSet.size()); DataSourceMetadata dataSourceMetadata = dataSourceMetadataSet.iterator().next(); assertTrue(dataSourceMetadata.getProperties().containsKey("prometheus.uri")); - assertFalse(dataSourceMetadata.getProperties().containsKey("prometheus.auth.type")); + assertTrue(dataSourceMetadata.getProperties().containsKey("prometheus.auth.type")); assertFalse(dataSourceMetadata.getProperties().containsKey("prometheus.auth.username")); assertFalse(dataSourceMetadata.getProperties().containsKey("prometheus.auth.password")); assertFalse( @@ -352,11 +352,72 @@ void testRemovalOfAuthorizationInfo() { DataSourceMetadata dataSourceMetadata1 = dataSourceService.getDataSourceMetadata("testDS"); assertEquals("testDS", dataSourceMetadata1.getName()); assertEquals(DataSourceType.PROMETHEUS, dataSourceMetadata1.getConnector()); - assertFalse(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.type")); + assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.type")); assertFalse(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.username")); assertFalse(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.password")); } + @Test + void testRemovalOfAuthorizationInfoForAccessKeyAndSecretKye() { + HashMap properties = new HashMap<>(); + properties.put("prometheus.uri", "https://localhost:9090"); + properties.put("prometheus.auth.type", "awssigv4"); + properties.put("prometheus.auth.access_key", "access_key"); + properties.put("prometheus.auth.secret_key", "secret_key"); + DataSourceMetadata dataSourceMetadata = + new DataSourceMetadata( + "testDS", + DataSourceType.PROMETHEUS, + Collections.singletonList("prometheus_access"), + properties, + null); + when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")) + .thenReturn(Optional.of(dataSourceMetadata)); + + DataSourceMetadata dataSourceMetadata1 = dataSourceService.getDataSourceMetadata("testDS"); + assertEquals("testDS", dataSourceMetadata1.getName()); + assertEquals(DataSourceType.PROMETHEUS, dataSourceMetadata1.getConnector()); + assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.type")); + assertFalse(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.access_key")); + assertFalse(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.secret_key")); + } + + @Test + void testRemovalOfAuthorizationInfoForGlueWithRoleARN() { + HashMap properties = new HashMap<>(); + properties.put("glue.auth.type", "iam_role"); + properties.put("glue.auth.role_arn", "role_arn"); + properties.put("glue.indexstore.opensearch.uri", "http://localhost:9200"); + properties.put("glue.indexstore.opensearch.auth", "basicauth"); + properties.put("glue.indexstore.opensearch.auth.username", "username"); + properties.put("glue.indexstore.opensearch.auth.password", "password"); + DataSourceMetadata dataSourceMetadata = + new DataSourceMetadata( + "testGlue", + DataSourceType.S3GLUE, + Collections.singletonList("glue_access"), + properties, + null); + when(dataSourceMetadataStorage.getDataSourceMetadata("testGlue")) + .thenReturn(Optional.of(dataSourceMetadata)); + + DataSourceMetadata dataSourceMetadata1 = dataSourceService.getDataSourceMetadata("testGlue"); + assertEquals("testGlue", dataSourceMetadata1.getName()); + assertEquals(DataSourceType.S3GLUE, dataSourceMetadata1.getConnector()); + assertTrue(dataSourceMetadata1.getProperties().containsKey("glue.auth.type")); + assertTrue(dataSourceMetadata1.getProperties().containsKey("glue.auth.role_arn")); + assertTrue(dataSourceMetadata1.getProperties().containsKey("glue.indexstore.opensearch.uri")); + assertTrue(dataSourceMetadata1.getProperties().containsKey("glue.indexstore.opensearch.auth")); + assertFalse( + dataSourceMetadata1 + .getProperties() + .containsKey("glue.indexstore.opensearch.auth.username")); + assertFalse( + dataSourceMetadata1 + .getProperties() + .containsKey("glue.indexstore.opensearch.auth.password")); + } + @Test void testGetDataSourceMetadataForNonExistingDataSource() { when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")).thenReturn(Optional.empty()); @@ -381,7 +442,7 @@ void testGetDataSourceMetadataForSpecificDataSourceName() { "testDS", DataSourceType.PROMETHEUS, Collections.emptyList(), properties))); DataSourceMetadata dataSourceMetadata = this.dataSourceService.getDataSourceMetadata("testDS"); assertTrue(dataSourceMetadata.getProperties().containsKey("prometheus.uri")); - assertFalse(dataSourceMetadata.getProperties().containsKey("prometheus.auth.type")); + assertTrue(dataSourceMetadata.getProperties().containsKey("prometheus.auth.type")); assertFalse(dataSourceMetadata.getProperties().containsKey("prometheus.auth.username")); assertFalse(dataSourceMetadata.getProperties().containsKey("prometheus.auth.password")); verify(dataSourceMetadataStorage, times(1)).getDataSourceMetadata("testDS"); diff --git a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java index 087629a1f1..8623b9fa6f 100644 --- a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java @@ -85,6 +85,10 @@ public void createDataSourceAPITest() { new Gson().fromJson(getResponseString, DataSourceMetadata.class); Assert.assertEquals( "https://localhost:9090", dataSourceMetadata.getProperties().get("prometheus.uri")); + Assert.assertEquals( + "basicauth", dataSourceMetadata.getProperties().get("prometheus.auth.type")); + Assert.assertNull(dataSourceMetadata.getProperties().get("prometheus.auth.username")); + Assert.assertNull(dataSourceMetadata.getProperties().get("prometheus.auth.password")); Assert.assertEquals("Prometheus Creation for Integ test", dataSourceMetadata.getDescription()); } @@ -239,6 +243,10 @@ public void issue2196() { new Gson().fromJson(getResponseString, DataSourceMetadata.class); Assert.assertEquals( "https://localhost:9090", dataSourceMetadata.getProperties().get("prometheus.uri")); + Assert.assertEquals( + "basicauth", dataSourceMetadata.getProperties().get("prometheus.auth.type")); + Assert.assertNull(dataSourceMetadata.getProperties().get("prometheus.auth.username")); + Assert.assertNull(dataSourceMetadata.getProperties().get("prometheus.auth.password")); Assert.assertEquals("Prometheus Creation for Integ test", dataSourceMetadata.getDescription()); } } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index d2c8c6ebb7..f3fd043b63 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -96,6 +96,8 @@ import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.EmrServerlessClientImpl; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @@ -216,15 +218,21 @@ public Collection createComponents( dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata()); LocalClusterState.state().setClusterService(clusterService); LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings); - if (StringUtils.isEmpty(this.pluginSettings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG))) { + SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier = + new SparkExecutionEngineConfigSupplierImpl(pluginSettings); + SparkExecutionEngineConfig sparkExecutionEngineConfig = + sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(); + if (StringUtils.isEmpty(sparkExecutionEngineConfig.getRegion())) { LOGGER.warn( String.format( - "Async Query APIs are disabled as %s is not configured in cluster settings. " + "Async Query APIs are disabled as %s is not configured properly in cluster settings. " + "Please configure and restart the domain to enable Async Query APIs", SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue())); this.asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl(); } else { - this.asyncQueryExecutorService = createAsyncQueryExecutorService(); + this.asyncQueryExecutorService = + createAsyncQueryExecutorService( + sparkExecutionEngineConfigSupplier, sparkExecutionEngineConfig); } ModulesBuilder modules = new ModulesBuilder(); @@ -295,10 +303,13 @@ private DataSourceServiceImpl createDataSourceService() { dataSourceUserAuthorizationHelper); } - private AsyncQueryExecutorService createAsyncQueryExecutorService() { + private AsyncQueryExecutorService createAsyncQueryExecutorService( + SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier, + SparkExecutionEngineConfig sparkExecutionEngineConfig) { AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService); - EMRServerlessClient emrServerlessClient = createEMRServerlessClient(); + EMRServerlessClient emrServerlessClient = + createEMRServerlessClient(sparkExecutionEngineConfig.getRegion()); JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( @@ -309,21 +320,18 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService() { new FlintIndexMetadataReaderImpl(client), client); return new AsyncQueryExecutorServiceImpl( - asyncQueryJobMetadataStorageService, sparkQueryDispatcher, pluginSettings); + asyncQueryJobMetadataStorageService, + sparkQueryDispatcher, + sparkExecutionEngineConfigSupplier); } - private EMRServerlessClient createEMRServerlessClient() { - String sparkExecutionEngineConfigString = - this.pluginSettings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG); + private EMRServerlessClient createEMRServerlessClient(String region) { return AccessController.doPrivileged( (PrivilegedAction) () -> { - SparkExecutionEngineConfig sparkExecutionEngineConfig = - SparkExecutionEngineConfig.toSparkExecutionEngineConfig( - sparkExecutionEngineConfigString); AWSEMRServerless awsemrServerless = AWSEMRServerlessClientBuilder.standard() - .withRegion(sparkExecutionEngineConfig.getRegion()) + .withRegion(region) .withCredentials(new DefaultAWSCredentialsProviderChain()) .build(); return new EmrServerlessClientImpl(awsemrServerless); diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 55346bc289..13db103f4b 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -5,26 +5,22 @@ package org.opensearch.sql.spark.asyncquery; -import static org.opensearch.sql.common.setting.Settings.Key.CLUSTER_NAME; import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; import com.amazonaws.services.emrserverless.model.JobRunState; -import java.security.AccessController; -import java.security.PrivilegedAction; import java.util.ArrayList; import java.util.List; import java.util.Optional; import lombok.AllArgsConstructor; import org.json.JSONObject; -import org.opensearch.cluster.ClusterName; -import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -37,7 +33,7 @@ public class AsyncQueryExecutorServiceImpl implements AsyncQueryExecutorService { private AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService; private SparkQueryDispatcher sparkQueryDispatcher; - private Settings settings; + private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier; private Boolean isSparkJobExecutionEnabled; public AsyncQueryExecutorServiceImpl() { @@ -47,26 +43,19 @@ public AsyncQueryExecutorServiceImpl() { public AsyncQueryExecutorServiceImpl( AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService, SparkQueryDispatcher sparkQueryDispatcher, - Settings settings) { + SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier) { this.isSparkJobExecutionEnabled = Boolean.TRUE; this.asyncQueryJobMetadataStorageService = asyncQueryJobMetadataStorageService; this.sparkQueryDispatcher = sparkQueryDispatcher; - this.settings = settings; + this.sparkExecutionEngineConfigSupplier = sparkExecutionEngineConfigSupplier; } @Override public CreateAsyncQueryResponse createAsyncQuery( CreateAsyncQueryRequest createAsyncQueryRequest) { validateSparkExecutionEngineSettings(); - String sparkExecutionEngineConfigString = - settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG); SparkExecutionEngineConfig sparkExecutionEngineConfig = - AccessController.doPrivileged( - (PrivilegedAction) - () -> - SparkExecutionEngineConfig.toSparkExecutionEngineConfig( - sparkExecutionEngineConfigString)); - ClusterName clusterName = settings.getSettingValue(CLUSTER_NAME); + sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(); DispatchQueryResponse dispatchQueryResponse = sparkQueryDispatcher.dispatch( new DispatchQueryRequest( @@ -75,7 +64,7 @@ public CreateAsyncQueryResponse createAsyncQuery( createAsyncQueryRequest.getDatasource(), createAsyncQueryRequest.getLang(), sparkExecutionEngineConfig.getExecutionRoleARN(), - clusterName.value(), + sparkExecutionEngineConfig.getClusterName(), sparkExecutionEngineConfig.getSparkSubmitParameters())); asyncQueryJobMetadataStorageService.storeJobMetadata( new AsyncQueryJobMetadata( diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java index 23e5907b5c..537a635150 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java +++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfig.java @@ -1,25 +1,21 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - package org.opensearch.sql.spark.config; -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.google.gson.Gson; +import lombok.AllArgsConstructor; import lombok.Data; +import lombok.NoArgsConstructor; +/** + * POJO for spark Execution Engine Config. Interface between {@link + * org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService} and {@link + * SparkExecutionEngineConfigSupplier} + */ @Data -@JsonIgnoreProperties(ignoreUnknown = true) +@NoArgsConstructor +@AllArgsConstructor public class SparkExecutionEngineConfig { private String applicationId; private String region; private String executionRoleARN; - - /** Additional Spark submit parameters to append to request. */ private String sparkSubmitParameters; - - public static SparkExecutionEngineConfig toSparkExecutionEngineConfig(String jsonString) { - return new Gson().fromJson(jsonString, SparkExecutionEngineConfig.class); - } + private String clusterName; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java new file mode 100644 index 0000000000..b3f1295faa --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSetting.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.config; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.google.gson.Gson; +import lombok.Data; + +/** + * This POJO is just for reading stringified json in `plugins.query.executionengine.spark.config` + * setting. + */ +@Data +@JsonIgnoreProperties(ignoreUnknown = true) +public class SparkExecutionEngineConfigClusterSetting { + private String applicationId; + private String region; + private String executionRoleARN; + + /** Additional Spark submit parameters to append to request. */ + private String sparkSubmitParameters; + + public static SparkExecutionEngineConfigClusterSetting toSparkExecutionEngineConfig( + String jsonString) { + return new Gson().fromJson(jsonString, SparkExecutionEngineConfigClusterSetting.class); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplier.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplier.java new file mode 100644 index 0000000000..108cb07daf --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplier.java @@ -0,0 +1,12 @@ +package org.opensearch.sql.spark.config; + +/** Interface for extracting and providing SparkExecutionEngineConfig */ +public interface SparkExecutionEngineConfigSupplier { + + /** + * Get SparkExecutionEngineConfig + * + * @return {@link SparkExecutionEngineConfig}. + */ + SparkExecutionEngineConfig getSparkExecutionEngineConfig(); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java new file mode 100644 index 0000000000..f4c32f24eb --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImpl.java @@ -0,0 +1,42 @@ +package org.opensearch.sql.spark.config; + +import static org.opensearch.sql.common.setting.Settings.Key.CLUSTER_NAME; +import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; + +import java.security.AccessController; +import java.security.PrivilegedAction; +import lombok.AllArgsConstructor; +import org.apache.commons.lang3.StringUtils; +import org.opensearch.cluster.ClusterName; +import org.opensearch.sql.common.setting.Settings; + +@AllArgsConstructor +public class SparkExecutionEngineConfigSupplierImpl implements SparkExecutionEngineConfigSupplier { + + private Settings settings; + + @Override + public SparkExecutionEngineConfig getSparkExecutionEngineConfig() { + String sparkExecutionEngineConfigSettingString = + this.settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG); + SparkExecutionEngineConfig sparkExecutionEngineConfig = new SparkExecutionEngineConfig(); + if (!StringUtils.isBlank(sparkExecutionEngineConfigSettingString)) { + SparkExecutionEngineConfigClusterSetting sparkExecutionEngineConfigClusterSetting = + AccessController.doPrivileged( + (PrivilegedAction) + () -> + SparkExecutionEngineConfigClusterSetting.toSparkExecutionEngineConfig( + sparkExecutionEngineConfigSettingString)); + sparkExecutionEngineConfig.setApplicationId( + sparkExecutionEngineConfigClusterSetting.getApplicationId()); + sparkExecutionEngineConfig.setExecutionRoleARN( + sparkExecutionEngineConfigClusterSetting.getExecutionRoleARN()); + sparkExecutionEngineConfig.setSparkSubmitParameters( + sparkExecutionEngineConfigClusterSetting.getSparkSubmitParameters()); + sparkExecutionEngineConfig.setRegion(sparkExecutionEngineConfigClusterSetting.getRegion()); + } + ClusterName clusterName = settings.getSettingValue(CLUSTER_NAME); + sparkExecutionEngineConfig.setClusterName(clusterName.value()); + return sparkExecutionEngineConfig; + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index e16dd89639..01bccd9030 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -27,11 +27,11 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.cluster.ClusterName; -import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -44,15 +44,17 @@ public class AsyncQueryExecutorServiceImplTest { @Mock private SparkQueryDispatcher sparkQueryDispatcher; @Mock private AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService; - @Mock private Settings settings; - private AsyncQueryExecutorService jobExecutorService; + @Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier; + @BeforeEach void setUp() { jobExecutorService = new AsyncQueryExecutorServiceImpl( - asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings); + asyncQueryJobMetadataStorageService, + sparkQueryDispatcher, + sparkExecutionEngineConfigSupplier); } @Test @@ -60,11 +62,14 @@ void testCreateAsyncQuery() { CreateAsyncQueryRequest createAsyncQueryRequest = new CreateAsyncQueryRequest( "select * from my_glue.default.http_logs", "my_glue", LangType.SQL); - when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG)) + when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig()) .thenReturn( - "{\"applicationId\":\"00fd775baqpu4g0p\",\"executionRoleARN\":\"arn:aws:iam::270824043731:role/emr-job-execution-role\",\"region\":\"eu-west-1\"}"); - when(settings.getSettingValue(Settings.Key.CLUSTER_NAME)) - .thenReturn(new ClusterName(TEST_CLUSTER_NAME)); + new SparkExecutionEngineConfig( + "00fd775baqpu4g0p", + "eu-west-1", + "arn:aws:iam::270824043731:role/emr-job-execution-role", + null, + TEST_CLUSTER_NAME)); when(sparkQueryDispatcher.dispatch( new DispatchQueryRequest( "00fd775baqpu4g0p", @@ -78,8 +83,7 @@ void testCreateAsyncQuery() { jobExecutorService.createAsyncQuery(createAsyncQueryRequest); verify(asyncQueryJobMetadataStorageService, times(1)) .storeJobMetadata(new AsyncQueryJobMetadata("00fd775baqpu4g0p", EMR_JOB_ID, null)); - verify(settings, times(1)).getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG); - verify(settings, times(1)).getSettingValue(Settings.Key.CLUSTER_NAME); + verify(sparkExecutionEngineConfigSupplier, times(1)).getSparkExecutionEngineConfig(); verify(sparkQueryDispatcher, times(1)) .dispatch( new DispatchQueryRequest( @@ -94,16 +98,14 @@ void testCreateAsyncQuery() { @Test void testCreateAsyncQueryWithExtraSparkSubmitParameter() { - when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG)) + when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig()) .thenReturn( - "{" - + "\"applicationId\": \"00fd775baqpu4g0p\"," - + "\"executionRoleARN\": \"arn:aws:iam::270824043731:role/emr-job-execution-role\"," - + "\"region\": \"eu-west-1\"," - + "\"sparkSubmitParameters\": \"--conf spark.dynamicAllocation.enabled=false\"" - + "}"); - when(settings.getSettingValue(Settings.Key.CLUSTER_NAME)) - .thenReturn(new ClusterName(TEST_CLUSTER_NAME)); + new SparkExecutionEngineConfig( + "00fd775baqpu4g0p", + "eu-west-1", + "arn:aws:iam::270824043731:role/emr-job-execution-role", + "--conf spark.dynamicAllocation.enabled=false", + TEST_CLUSTER_NAME)); when(sparkQueryDispatcher.dispatch(any())) .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null)); @@ -131,7 +133,7 @@ void testGetAsyncQueryResultsWithJobNotFoundException() { Assertions.assertEquals( "QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); verifyNoInteractions(sparkQueryDispatcher); - verifyNoInteractions(settings); + verifyNoInteractions(sparkExecutionEngineConfigSupplier); } @Test @@ -149,7 +151,7 @@ void testGetAsyncQueryResultsWithInProgressJob() { Assertions.assertNull(asyncQueryExecutionResponse.getResults()); Assertions.assertNull(asyncQueryExecutionResponse.getSchema()); Assertions.assertEquals("PENDING", asyncQueryExecutionResponse.getStatus()); - verifyNoInteractions(settings); + verifyNoInteractions(sparkExecutionEngineConfigSupplier); } @Test @@ -173,7 +175,7 @@ void testGetAsyncQueryResultsWithSuccessJob() throws IOException { 1, ((HashMap) asyncQueryExecutionResponse.getResults().get(0).value()) .get("1")); - verifyNoInteractions(settings); + verifyNoInteractions(sparkExecutionEngineConfigSupplier); } @Test @@ -200,7 +202,7 @@ void testCancelJobWithJobNotFound() { Assertions.assertEquals( "QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage()); verifyNoInteractions(sparkQueryDispatcher); - verifyNoInteractions(settings); + verifyNoInteractions(sparkExecutionEngineConfigSupplier); } @Test @@ -212,6 +214,6 @@ void testCancelJob() { .thenReturn(EMR_JOB_ID); String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID); Assertions.assertEquals(EMR_JOB_ID, jobId); - verifyNoInteractions(settings); + verifyNoInteractions(sparkExecutionEngineConfigSupplier); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigTest.java b/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingTest.java similarity index 79% rename from spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigTest.java rename to spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingTest.java index 29b69ea830..c6be37567d 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigClusterSettingTest.java @@ -10,7 +10,7 @@ import org.junit.jupiter.api.Test; -public class SparkExecutionEngineConfigTest { +public class SparkExecutionEngineConfigClusterSettingTest { @Test public void testToSparkExecutionEngineConfigWithoutAllFields() { @@ -20,8 +20,8 @@ public void testToSparkExecutionEngineConfigWithoutAllFields() { + "\"executionRoleARN\": \"role-1\"," + "\"region\": \"us-west-1\"" + "}"; - SparkExecutionEngineConfig config = - SparkExecutionEngineConfig.toSparkExecutionEngineConfig(json); + SparkExecutionEngineConfigClusterSetting config = + SparkExecutionEngineConfigClusterSetting.toSparkExecutionEngineConfig(json); assertEquals("app-1", config.getApplicationId()); assertEquals("role-1", config.getExecutionRoleARN()); @@ -38,8 +38,8 @@ public void testToSparkExecutionEngineConfigWithAllFields() { + "\"region\": \"us-west-1\"," + "\"sparkSubmitParameters\": \"--conf A=1\"" + "}"; - SparkExecutionEngineConfig config = - SparkExecutionEngineConfig.toSparkExecutionEngineConfig(json); + SparkExecutionEngineConfigClusterSetting config = + SparkExecutionEngineConfigClusterSetting.toSparkExecutionEngineConfig(json); assertEquals("app-1", config.getApplicationId()); assertEquals("role-1", config.getExecutionRoleARN()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java new file mode 100644 index 0000000000..298a56b17a --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/config/SparkExecutionEngineConfigSupplierImplTest.java @@ -0,0 +1,61 @@ +package org.opensearch.sql.spark.config; + +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.cluster.ClusterName; +import org.opensearch.sql.common.setting.Settings; + +@ExtendWith(MockitoExtension.class) +public class SparkExecutionEngineConfigSupplierImplTest { + + @Mock private Settings settings; + + @Test + void testGetSparkExecutionEngineConfig() { + SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier = + new SparkExecutionEngineConfigSupplierImpl(settings); + when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG)) + .thenReturn( + "{" + + "\"applicationId\": \"00fd775baqpu4g0p\"," + + "\"executionRoleARN\": \"arn:aws:iam::270824043731:role/emr-job-execution-role\"," + + "\"region\": \"eu-west-1\"," + + "\"sparkSubmitParameters\": \"--conf spark.dynamicAllocation.enabled=false\"" + + "}"); + when(settings.getSettingValue(Settings.Key.CLUSTER_NAME)) + .thenReturn(new ClusterName(TEST_CLUSTER_NAME)); + SparkExecutionEngineConfig sparkExecutionEngineConfig = + sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(); + Assertions.assertEquals("00fd775baqpu4g0p", sparkExecutionEngineConfig.getApplicationId()); + Assertions.assertEquals( + "arn:aws:iam::270824043731:role/emr-job-execution-role", + sparkExecutionEngineConfig.getExecutionRoleARN()); + Assertions.assertEquals("eu-west-1", sparkExecutionEngineConfig.getRegion()); + Assertions.assertEquals( + "--conf spark.dynamicAllocation.enabled=false", + sparkExecutionEngineConfig.getSparkSubmitParameters()); + Assertions.assertEquals(TEST_CLUSTER_NAME, sparkExecutionEngineConfig.getClusterName()); + } + + @Test + void testGetSparkExecutionEngineConfigWithNullSetting() { + SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier = + new SparkExecutionEngineConfigSupplierImpl(settings); + when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG)).thenReturn(null); + when(settings.getSettingValue(Settings.Key.CLUSTER_NAME)) + .thenReturn(new ClusterName(TEST_CLUSTER_NAME)); + SparkExecutionEngineConfig sparkExecutionEngineConfig = + sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(); + Assertions.assertNull(sparkExecutionEngineConfig.getApplicationId()); + Assertions.assertNull(sparkExecutionEngineConfig.getExecutionRoleARN()); + Assertions.assertNull(sparkExecutionEngineConfig.getRegion()); + Assertions.assertNull(sparkExecutionEngineConfig.getSparkSubmitParameters()); + Assertions.assertEquals(TEST_CLUSTER_NAME, sparkExecutionEngineConfig.getClusterName()); + } +}