From 94bd664b6ff174374621864fe0afcbc7202c4186 Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Fri, 2 Feb 2024 11:00:28 -0800 Subject: [PATCH] Refactor async executor service dependencies using guice framework (#2488) Signed-off-by: Vamsi Manohar --- .../sql/legacy/plugin/RestSqlStatsAction.java | 24 ++- .../org/opensearch/sql/plugin/SQLPlugin.java | 106 ++----------- .../sql/plugin/rest/RestPPLStatsAction.java | 8 +- .../AsyncQueryExecutorServiceImpl.java | 28 ---- .../client/EMRServerlessClientFactory.java | 17 +++ .../EMRServerlessClientFactoryImpl.java | 71 +++++++++ .../dispatcher/SparkQueryDispatcher.java | 18 +-- .../execution/session/SessionManager.java | 14 +- .../config/AsyncExecutorServiceModule.java | 143 ++++++++++++++++++ ...AsyncQueryExecutorServiceImplSpecTest.java | 49 ++++-- .../AsyncQueryExecutorServiceImplTest.java | 14 -- .../AsyncQueryExecutorServiceSpec.java | 19 ++- .../AsyncQueryGetResultSpecTest.java | 4 +- .../spark/asyncquery/IndexQuerySpecTest.java | 125 ++++++++++++--- .../EMRServerlessClientFactoryImplTest.java | 96 ++++++++++++ .../sql/spark/constants/TestConstants.java | 2 + .../dispatcher/SparkQueryDispatcherTest.java | 5 +- .../session/InteractiveSessionTest.java | 12 +- .../execution/session/SessionManagerTest.java | 8 +- .../execution/statement/StatementTest.java | 25 ++- .../AsyncExecutorServiceModuleTest.java | 50 ++++++ 21 files changed, 624 insertions(+), 214 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModuleTest.java diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java index bc0f3c73b8..383363b1e3 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java @@ -11,11 +11,14 @@ import java.util.Arrays; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.ThreadContext; import org.opensearch.client.node.NodeClient; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.rest.RestStatus; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; @@ -24,6 +27,7 @@ import org.opensearch.sql.common.utils.QueryContext; import org.opensearch.sql.legacy.executor.format.ErrorMessageFactory; import org.opensearch.sql.legacy.metrics.Metrics; +import org.opensearch.threadpool.ThreadPool; /** * Currently this interface is for node level. Cluster level is coming up soon. @@ -69,8 +73,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli try { return channel -> - channel.sendResponse( - new BytesRestResponse(RestStatus.OK, Metrics.getInstance().collectToJSON())); + schedule( + client, + () -> + channel.sendResponse( + new BytesRestResponse(RestStatus.OK, Metrics.getInstance().collectToJSON()))); } catch (Exception e) { LOG.error("Failed during Query SQL STATS Action.", e); @@ -91,4 +98,17 @@ protected Set responseParams() { "sql", "flat", "separator", "_score", "_type", "_id", "newLine", "format", "sanitize")); return responseParams; } + + private void schedule(NodeClient client, Runnable task) { + ThreadPool threadPool = client.threadPool(); + threadPool.schedule(withCurrentContext(task), new TimeValue(0), "sql-worker"); + } + + private Runnable withCurrentContext(final Runnable task) { + final Map currentContext = ThreadContext.getImmutableContext(); + return () -> { + ThreadContext.putAll(currentContext); + task.run(); + }; + } } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index f0689a0966..2b75a8b2c9 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -5,21 +5,14 @@ package org.opensearch.sql.plugin; -import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; +import static java.util.Collections.singletonList; import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata; -import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.services.emrserverless.AWSEMRServerless; -import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import java.security.AccessController; -import java.security.PrivilegedAction; import java.time.Clock; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.function.Supplier; @@ -68,7 +61,6 @@ import org.opensearch.sql.datasources.transport.*; import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.executor.AsyncRestExecutor; -import org.opensearch.sql.legacy.metrics.GaugeMetric; import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.legacy.plugin.RestSqlAction; import org.opensearch.sql.legacy.plugin.RestSqlStatsAction; @@ -87,26 +79,13 @@ import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; -import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; -import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService; -import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService; -import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.client.EmrServerlessClientImpl; import org.opensearch.sql.spark.cluster.ClusterManagerEventListener; -import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; -import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; -import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; -import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; -import org.opensearch.sql.spark.execution.session.SessionManager; -import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; -import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; -import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; import org.opensearch.sql.spark.storage.SparkStorageFactory; import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction; import org.opensearch.sql.spark.transport.TransportGetAsyncQueryResultAction; +import org.opensearch.sql.spark.transport.config.AsyncExecutorServiceModule; import org.opensearch.sql.spark.transport.model.CancelAsyncQueryActionResponse; import org.opensearch.sql.spark.transport.model.CreateAsyncQueryActionResponse; import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse; @@ -127,7 +106,6 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin { private NodeClient client; private DataSourceServiceImpl dataSourceService; - private AsyncQueryExecutorService asyncQueryExecutorService; private Injector injector; public String name() { @@ -223,23 +201,6 @@ public Collection createComponents( dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata()); LocalClusterState.state().setClusterService(clusterService); LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings); - SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier = - new SparkExecutionEngineConfigSupplierImpl(pluginSettings); - SparkExecutionEngineConfig sparkExecutionEngineConfig = - sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(); - if (StringUtils.isEmpty(sparkExecutionEngineConfig.getRegion())) { - LOGGER.warn( - String.format( - "Async Query APIs are disabled as %s is not configured properly in cluster settings. " - + "Please configure and restart the domain to enable Async Query APIs", - SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue())); - this.asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl(); - } else { - this.asyncQueryExecutorService = - createAsyncQueryExecutorService( - sparkExecutionEngineConfigSupplier, sparkExecutionEngineConfig); - } - ModulesBuilder modules = new ModulesBuilder(); modules.add(new OpenSearchPluginModule()); modules.add( @@ -247,8 +208,9 @@ public Collection createComponents( b.bind(NodeClient.class).toInstance((NodeClient) client); b.bind(org.opensearch.sql.common.setting.Settings.class).toInstance(pluginSettings); b.bind(DataSourceService.class).toInstance(dataSourceService); + b.bind(ClusterService.class).toInstance(clusterService); }); - + modules.add(new AsyncExecutorServiceModule()); injector = modules.createInjector(); ClusterManagerEventListener clusterManagerEventListener = new ClusterManagerEventListener( @@ -261,12 +223,15 @@ public Collection createComponents( OpenSearchSettings.AUTO_INDEX_MANAGEMENT_ENABLED_SETTING, environment.settings()); return ImmutableList.of( - dataSourceService, asyncQueryExecutorService, clusterManagerEventListener, pluginSettings); + dataSourceService, + injector.getInstance(AsyncQueryExecutorService.class), + clusterManagerEventListener, + pluginSettings); } @Override public List> getExecutorBuilders(Settings settings) { - return Collections.singletonList( + return singletonList( new FixedExecutorBuilder( settings, AsyncRestExecutor.SQL_WORKER_THREAD_POOL_NAME, @@ -318,57 +283,4 @@ private DataSourceServiceImpl createDataSourceService() { dataSourceMetadataStorage, dataSourceUserAuthorizationHelper); } - - private AsyncQueryExecutorService createAsyncQueryExecutorService( - SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier, - SparkExecutionEngineConfig sparkExecutionEngineConfig) { - StateStore stateStore = new StateStore(client, clusterService); - registerStateStoreMetrics(stateStore); - AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = - new OpensearchAsyncQueryJobMetadataStorageService(stateStore); - EMRServerlessClient emrServerlessClient = - createEMRServerlessClient(sparkExecutionEngineConfig.getRegion()); - JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - this.dataSourceService, - new DataSourceUserAuthorizationHelperImpl(client), - jobExecutionResponseReader, - new FlintIndexMetadataReaderImpl(client), - client, - new SessionManager(stateStore, emrServerlessClient, pluginSettings), - new DefaultLeaseManager(pluginSettings, stateStore), - stateStore); - return new AsyncQueryExecutorServiceImpl( - asyncQueryJobMetadataStorageService, - sparkQueryDispatcher, - sparkExecutionEngineConfigSupplier); - } - - private void registerStateStoreMetrics(StateStore stateStore) { - GaugeMetric activeSessionMetric = - new GaugeMetric<>( - "active_async_query_sessions_count", - StateStore.activeSessionsCount(stateStore, ALL_DATASOURCE)); - GaugeMetric activeStatementMetric = - new GaugeMetric<>( - "active_async_query_statements_count", - StateStore.activeStatementsCount(stateStore, ALL_DATASOURCE)); - Metrics.getInstance().registerMetric(activeSessionMetric); - Metrics.getInstance().registerMetric(activeStatementMetric); - } - - private EMRServerlessClient createEMRServerlessClient(String region) { - return AccessController.doPrivileged( - (PrivilegedAction) - () -> { - AWSEMRServerless awsemrServerless = - AWSEMRServerlessClientBuilder.standard() - .withRegion(region) - .withCredentials(new DefaultAWSCredentialsProviderChain()) - .build(); - return new EmrServerlessClientImpl(awsemrServerless); - }); - } } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java index 7a51fc282b..39a3d20abb 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java @@ -22,6 +22,7 @@ import org.opensearch.rest.RestController; import org.opensearch.rest.RestRequest; import org.opensearch.sql.common.utils.QueryContext; +import org.opensearch.sql.datasources.utils.Scheduler; import org.opensearch.sql.legacy.executor.format.ErrorMessageFactory; import org.opensearch.sql.legacy.metrics.Metrics; @@ -67,8 +68,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli try { return channel -> - channel.sendResponse( - new BytesRestResponse(RestStatus.OK, Metrics.getInstance().collectToJSON())); + Scheduler.schedule( + client, + () -> + channel.sendResponse( + new BytesRestResponse(RestStatus.OK, Metrics.getInstance().collectToJSON()))); } catch (Exception e) { LOG.error("Failed during Query PPL STATS Action.", e); diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java index 1c0979dffb..eb77725052 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java @@ -5,7 +5,6 @@ package org.opensearch.sql.spark.asyncquery; -import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; @@ -34,26 +33,10 @@ public class AsyncQueryExecutorServiceImpl implements AsyncQueryExecutorService private AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService; private SparkQueryDispatcher sparkQueryDispatcher; private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier; - private Boolean isSparkJobExecutionEnabled; - - public AsyncQueryExecutorServiceImpl() { - this.isSparkJobExecutionEnabled = Boolean.FALSE; - } - - public AsyncQueryExecutorServiceImpl( - AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService, - SparkQueryDispatcher sparkQueryDispatcher, - SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier) { - this.isSparkJobExecutionEnabled = Boolean.TRUE; - this.asyncQueryJobMetadataStorageService = asyncQueryJobMetadataStorageService; - this.sparkQueryDispatcher = sparkQueryDispatcher; - this.sparkExecutionEngineConfigSupplier = sparkExecutionEngineConfigSupplier; - } @Override public CreateAsyncQueryResponse createAsyncQuery( CreateAsyncQueryRequest createAsyncQueryRequest) { - validateSparkExecutionEngineSettings(); SparkExecutionEngineConfig sparkExecutionEngineConfig = sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(); DispatchQueryResponse dispatchQueryResponse = @@ -80,7 +63,6 @@ public CreateAsyncQueryResponse createAsyncQuery( @Override public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) { - validateSparkExecutionEngineSettings(); Optional jobMetadata = asyncQueryJobMetadataStorageService.getJobMetadata(queryId); if (jobMetadata.isPresent()) { @@ -120,14 +102,4 @@ public String cancelQuery(String queryId) { } throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId)); } - - private void validateSparkExecutionEngineSettings() { - if (!isSparkJobExecutionEnabled) { - throw new IllegalArgumentException( - String.format( - "Async Query APIs are disabled as %s is not configured in cluster settings. Please" - + " configure the setting and restart the domain to enable Async Query APIs", - SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue())); - } - } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java new file mode 100644 index 0000000000..2c05dc865d --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactory.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +/** Factory interface for creating instances of {@link EMRServerlessClient}. */ +public interface EMRServerlessClientFactory { + + /** + * Gets an instance of {@link EMRServerlessClient}. + * + * @return An {@link EMRServerlessClient} instance. + */ + EMRServerlessClient getClient(); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java new file mode 100644 index 0000000000..e0cc5ea397 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImpl.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG; + +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.services.emrserverless.AWSEMRServerless; +import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder; +import java.security.AccessController; +import java.security.PrivilegedAction; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; + +/** Implementation of {@link EMRServerlessClientFactory}. */ +@RequiredArgsConstructor +public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactory { + + private final SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier; + private EMRServerlessClient emrServerlessClient; + private String region; + + /** + * Gets an instance of {@link EMRServerlessClient}. + * + * @return An {@link EMRServerlessClient} instance. + */ + @Override + public EMRServerlessClient getClient() { + SparkExecutionEngineConfig sparkExecutionEngineConfig = + this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(); + validateSparkExecutionEngineConfig(sparkExecutionEngineConfig); + if (isNewClientCreationRequired(sparkExecutionEngineConfig.getRegion())) { + region = sparkExecutionEngineConfig.getRegion(); + this.emrServerlessClient = createEMRServerlessClient(this.region); + } + return this.emrServerlessClient; + } + + private boolean isNewClientCreationRequired(String region) { + return !region.equals(this.region); + } + + private void validateSparkExecutionEngineConfig( + SparkExecutionEngineConfig sparkExecutionEngineConfig) { + if (sparkExecutionEngineConfig == null || sparkExecutionEngineConfig.getRegion() == null) { + throw new IllegalArgumentException( + String.format( + "Async Query APIs are disabled. Please configure %s in cluster settings to enable" + + " them.", + SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue())); + } + } + + private EMRServerlessClient createEMRServerlessClient(String awsRegion) { + return AccessController.doPrivileged( + (PrivilegedAction) + () -> { + AWSEMRServerless awsemrServerless = + AWSEMRServerlessClientBuilder.standard() + .withRegion(awsRegion) + .withCredentials(new DefaultAWSCredentialsProviderChain()) + .build(); + return new EmrServerlessClientImpl(awsemrServerless); + }); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 0aa183335e..498a3b9af5 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -8,8 +8,6 @@ import java.util.HashMap; import java.util.Map; import lombok.AllArgsConstructor; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.json.JSONObject; import org.opensearch.client.Client; import org.opensearch.sql.datasource.DataSourceService; @@ -18,6 +16,7 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -35,13 +34,12 @@ @AllArgsConstructor public class SparkQueryDispatcher { - private static final Logger LOG = LogManager.getLogger(); public static final String INDEX_TAG_KEY = "index"; public static final String DATASOURCE_TAG_KEY = "datasource"; public static final String CLUSTER_NAME_TAG_KEY = "domain_ident"; public static final String JOB_TYPE_TAG_KEY = "type"; - private EMRServerlessClient emrServerlessClient; + private EMRServerlessClientFactory emrServerlessClientFactory; private DataSourceService dataSourceService; @@ -60,10 +58,10 @@ public class SparkQueryDispatcher { private StateStore stateStore; public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) { + EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); DataSourceMetadata dataSourceMetadata = this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); - AsyncQueryHandler asyncQueryHandler = sessionManager.isEnabled() ? new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager) @@ -83,7 +81,7 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) contextBuilder.indexQueryDetails(indexQueryDetails); if (IndexQueryActionType.DROP.equals(indexQueryDetails.getIndexQueryActionType())) { - asyncQueryHandler = createIndexDMLHandler(); + asyncQueryHandler = createIndexDMLHandler(emrServerlessClient); } else if (IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType()) && indexQueryDetails.isAutoRefresh()) { asyncQueryHandler = @@ -99,11 +97,12 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) } public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) { + EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); if (asyncQueryJobMetadata.getSessionId() != null) { return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager) .getQueryResponse(asyncQueryJobMetadata); } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) { - return createIndexDMLHandler().getQueryResponse(asyncQueryJobMetadata); + return createIndexDMLHandler(emrServerlessClient).getQueryResponse(asyncQueryJobMetadata); } else { return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager) .getQueryResponse(asyncQueryJobMetadata); @@ -111,12 +110,13 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) } public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { + EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); AsyncQueryHandler queryHandler; if (asyncQueryJobMetadata.getSessionId() != null) { queryHandler = new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager); } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) { - queryHandler = createIndexDMLHandler(); + queryHandler = createIndexDMLHandler(emrServerlessClient); } else { queryHandler = new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager); @@ -124,7 +124,7 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { return queryHandler.cancelJob(asyncQueryJobMetadata); } - private IndexDMLHandler createIndexDMLHandler() { + private IndexDMLHandler createIndexDMLHandler(EMRServerlessClient emrServerlessClient) { return new IndexDMLHandler( emrServerlessClient, dataSourceService, diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index c3d5807305..e441492c20 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -10,7 +10,7 @@ import java.util.Optional; import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.utils.RealTimeProvider; @@ -21,13 +21,15 @@ */ public class SessionManager { private final StateStore stateStore; - private final EMRServerlessClient emrServerlessClient; + private final EMRServerlessClientFactory emrServerlessClientFactory; private Settings settings; public SessionManager( - StateStore stateStore, EMRServerlessClient emrServerlessClient, Settings settings) { + StateStore stateStore, + EMRServerlessClientFactory emrServerlessClientFactory, + Settings settings) { this.stateStore = stateStore; - this.emrServerlessClient = emrServerlessClient; + this.emrServerlessClientFactory = emrServerlessClientFactory; this.settings = settings; } @@ -36,7 +38,7 @@ public Session createSession(CreateSessionRequest request) { InteractiveSession.builder() .sessionId(newSessionId(request.getDatasourceName())) .stateStore(stateStore) - .serverlessClient(emrServerlessClient) + .serverlessClient(emrServerlessClientFactory.getClient()) .build(); session.open(request); return session; @@ -68,7 +70,7 @@ public Optional getSession(SessionId sid, String dataSourceName) { InteractiveSession.builder() .sessionId(sid) .stateStore(stateStore) - .serverlessClient(emrServerlessClient) + .serverlessClient(emrServerlessClientFactory.getClient()) .sessionModel(model.get()) .sessionInactivityTimeoutMilli( settings.getSettingValue(SESSION_INACTIVITY_TIMEOUT_MILLIS)) diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java new file mode 100644 index 0000000000..d88c1dd9df --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.transport.config; + +import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE; + +import lombok.RequiredArgsConstructor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.AbstractModule; +import org.opensearch.common.inject.Provides; +import org.opensearch.common.inject.Singleton; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; +import org.opensearch.sql.legacy.metrics.GaugeMetric; +import org.opensearch.sql.legacy.metrics.Metrics; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl; +import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService; +import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; +import org.opensearch.sql.spark.client.EMRServerlessClientFactoryImpl; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; +import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; +import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; + +@RequiredArgsConstructor +public class AsyncExecutorServiceModule extends AbstractModule { + + @Override + protected void configure() {} + + @Provides + public AsyncQueryExecutorService asyncQueryExecutorService( + AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService, + SparkQueryDispatcher sparkQueryDispatcher, + SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier) { + return new AsyncQueryExecutorServiceImpl( + asyncQueryJobMetadataStorageService, + sparkQueryDispatcher, + sparkExecutionEngineConfigSupplier); + } + + @Provides + public AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService( + StateStore stateStore) { + return new OpensearchAsyncQueryJobMetadataStorageService(stateStore); + } + + @Provides + @Singleton + public StateStore stateStore(NodeClient client, ClusterService clusterService) { + StateStore stateStore = new StateStore(client, clusterService); + registerStateStoreMetrics(stateStore); + return stateStore; + } + + @Provides + public SparkQueryDispatcher sparkQueryDispatcher( + EMRServerlessClientFactory emrServerlessClientFactory, + DataSourceService dataSourceService, + DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper, + JobExecutionResponseReader jobExecutionResponseReader, + FlintIndexMetadataReaderImpl flintIndexMetadataReader, + NodeClient client, + SessionManager sessionManager, + DefaultLeaseManager defaultLeaseManager, + StateStore stateStore) { + return new SparkQueryDispatcher( + emrServerlessClientFactory, + dataSourceService, + dataSourceUserAuthorizationHelper, + jobExecutionResponseReader, + flintIndexMetadataReader, + client, + sessionManager, + defaultLeaseManager, + stateStore); + } + + @Provides + public SessionManager sessionManager( + StateStore stateStore, + EMRServerlessClientFactory emrServerlessClientFactory, + Settings settings) { + return new SessionManager(stateStore, emrServerlessClientFactory, settings); + } + + @Provides + public DefaultLeaseManager defaultLeaseManager(Settings settings, StateStore stateStore) { + return new DefaultLeaseManager(settings, stateStore); + } + + @Provides + public EMRServerlessClientFactory createEMRServerlessClientFactory( + SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier) { + return new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier); + } + + @Provides + public SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier(Settings settings) { + return new SparkExecutionEngineConfigSupplierImpl(settings); + } + + @Provides + @Singleton + public FlintIndexMetadataReaderImpl flintIndexMetadataReader(NodeClient client) { + return new FlintIndexMetadataReaderImpl(client); + } + + @Provides + public JobExecutionResponseReader jobExecutionResponseReader(NodeClient client) { + return new JobExecutionResponseReader(client); + } + + @Provides + public DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper( + NodeClient client) { + return new DataSourceUserAuthorizationHelperImpl(client); + } + + private void registerStateStoreMetrics(StateStore stateStore) { + GaugeMetric activeSessionMetric = + new GaugeMetric<>( + "active_async_query_sessions_count", + StateStore.activeSessionsCount(stateStore, ALL_DATASOURCE)); + GaugeMetric activeStatementMetric = + new GaugeMetric<>( + "active_async_query_statements_count", + StateStore.activeStatementsCount(stateStore, ALL_DATASOURCE)); + Metrics.getInstance().registerMetric(activeSessionMetric); + Metrics.getInstance().registerMetric(activeStatementMetric); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index 011d97dcdf..33fec89e26 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -32,6 +32,7 @@ import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.session.SessionId; import org.opensearch.sql.spark.execution.session.SessionState; import org.opensearch.sql.spark.execution.statement.StatementModel; @@ -46,8 +47,9 @@ public class AsyncQueryExecutorServiceImplSpecTest extends AsyncQueryExecutorSer @Disabled("batch query is unsupported") public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // disable session enableSession(false); @@ -74,8 +76,9 @@ public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { @Disabled("batch query is unsupported") public void sessionLimitNotImpactBatchQuery() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // disable session enableSession(false); @@ -96,8 +99,9 @@ public void sessionLimitNotImpactBatchQuery() { @Disabled("batch query is unsupported") public void createAsyncQueryCreateJobWithCorrectParameters() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); enableSession(false); CreateAsyncQueryResponse response = @@ -129,8 +133,9 @@ public void createAsyncQueryCreateJobWithCorrectParameters() { @Test public void withSessionCreateAsyncQueryThenGetResultThenCancel() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // 1. create async query. CreateAsyncQueryResponse response = @@ -156,8 +161,9 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() { @Test public void reuseSessionWhenCreateAsyncQuery() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -207,8 +213,9 @@ public void reuseSessionWhenCreateAsyncQuery() { @Disabled("batch query is unsupported") public void batchQueryHasTimeout() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); enableSession(false); CreateAsyncQueryResponse response = @@ -221,8 +228,9 @@ public void batchQueryHasTimeout() { @Test public void interactiveQueryNoTimeout() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -255,8 +263,9 @@ public void datasourceWithBasicAuth() { properties, null)); LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -274,8 +283,9 @@ public void datasourceWithBasicAuth() { @Test public void withSessionCreateAsyncQueryFailed() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -322,8 +332,9 @@ public void withSessionCreateAsyncQueryFailed() { @Test public void createSessionMoreThanLimitFailed() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -351,8 +362,9 @@ public void createSessionMoreThanLimitFailed() { @Test public void recreateSessionIfNotReady() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -388,8 +400,9 @@ public void recreateSessionIfNotReady() { @Test public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -426,8 +439,9 @@ public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { @Test public void recreateSessionIfStale() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -480,8 +494,9 @@ public void recreateSessionIfStale() { @Test public void submitQueryInInvalidSessionWillCreateNewSession() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -516,8 +531,9 @@ public void datasourceNameIncludeUppercase() { null)); LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // enable session enableSession(true); @@ -536,8 +552,9 @@ public void datasourceNameIncludeUppercase() { @Test public void concurrentSessionLimitIsDomainLevel() { LocalEMRSClient emrsClient = new LocalEMRSClient(); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // only allow one session in domain. setSessionLimit(1); diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java index efb965e9f3..634df6670d 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java @@ -186,20 +186,6 @@ void testGetAsyncQueryResultsWithSuccessJob() throws IOException { verifyNoInteractions(sparkExecutionEngineConfigSupplier); } - @Test - void testGetAsyncQueryResultsWithDisabledExecutionEngine() { - AsyncQueryExecutorService asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl(); - IllegalArgumentException illegalArgumentException = - Assertions.assertThrows( - IllegalArgumentException.class, - () -> asyncQueryExecutorService.getAsyncQueryResults(EMR_JOB_ID)); - Assertions.assertEquals( - "Async Query APIs are disabled as plugins.query.executionengine.spark.config is not" - + " configured in cluster settings. Please configure the setting and restart the domain" - + " to enable Async Query APIs", - illegalArgumentException.getMessage()); - } - @Test void testCancelJobWithJobNotFound() { when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID)) diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index c7054dd200..c9b4b6fc88 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -59,6 +59,7 @@ import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; @@ -195,27 +196,27 @@ private DataSourceServiceImpl createDataSourceService() { } protected AsyncQueryExecutorService createAsyncQueryExecutorService( - EMRServerlessClient emrServerlessClient) { + EMRServerlessClientFactory emrServerlessClientFactory) { return createAsyncQueryExecutorService( - emrServerlessClient, new JobExecutionResponseReader(client)); + emrServerlessClientFactory, new JobExecutionResponseReader(client)); } /** Pass a custom response reader which can mock interaction between PPL plugin and EMR-S job. */ protected AsyncQueryExecutorService createAsyncQueryExecutorService( - EMRServerlessClient emrServerlessClient, + EMRServerlessClientFactory emrServerlessClientFactory, JobExecutionResponseReader jobExecutionResponseReader) { StateStore stateStore = new StateStore(client, clusterService); AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = new OpensearchAsyncQueryJobMetadataStorageService(stateStore); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - emrServerlessClient, + emrServerlessClientFactory, this.dataSourceService, new DataSourceUserAuthorizationHelperImpl(client), jobExecutionResponseReader, new FlintIndexMetadataReaderImpl(client), client, - new SessionManager(stateStore, emrServerlessClient, pluginSettings), + new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings), new DefaultLeaseManager(pluginSettings, stateStore), stateStore); return new AsyncQueryExecutorServiceImpl( @@ -271,6 +272,14 @@ public void setJobState(JobRunState jobState) { } } + public static class LocalEMRServerlessClientFactory implements EMRServerlessClientFactory { + + @Override + public EMRServerlessClient getClient() { + return new LocalEMRSClient(); + } + } + public SparkExecutionEngineConfig sparkExecutionEngineConfig() { return new SparkExecutionEngineConfig("appId", "us-west-2", "roleArn", "", "myCluster"); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java index 2ddfe77868..ab6439492a 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java @@ -26,6 +26,7 @@ import org.opensearch.sql.protocol.response.format.ResponseFormatter; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statement.StatementModel; import org.opensearch.sql.spark.execution.statement.StatementState; import org.opensearch.sql.spark.execution.statestore.StateStore; @@ -411,9 +412,10 @@ private class AssertionHelper { private Interaction interaction; AssertionHelper(String query, LocalEMRSClient emrClient) { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrClient; this.queryService = createAsyncQueryExecutorService( - emrClient, + emrServerlessClientFactory, /* * Custom reader that intercepts get results call and inject extra steps defined in * current interaction. Intercept both get methods for different query handler which diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java index 49ac538e65..844567f4f5 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/IndexQuerySpecTest.java @@ -12,6 +12,8 @@ import org.junit.Assert; import org.junit.Test; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexType; import org.opensearch.sql.spark.leasemanager.ConcurrencyLimitExceededException; @@ -72,9 +74,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -120,8 +128,15 @@ public CancelJobRunResult cancelJobRun(String applicationId, String jobId) { throw new IllegalArgumentException("Job run is not in a cancellable state"); } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -157,8 +172,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Running")); } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -193,9 +215,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; - + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -248,8 +276,15 @@ public CancelJobRunResult cancelJobRun(String applicationId, String jobId) { throw new IllegalArgumentException("Job run is not in a cancellable state"); } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -290,8 +325,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Running")); } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -331,8 +373,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return new GetJobRunResult().withJobRun(new JobRun().withState("Cancelled")); } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -380,8 +429,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -424,8 +480,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -468,8 +531,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -517,8 +587,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -565,8 +642,15 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { return null; } }; + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactory() { + @Override + public EMRServerlessClient getClient() { + return emrsClient; + } + }; AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(emrsClient); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index mockDS.createIndex(); @@ -586,8 +670,9 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { @Test public void concurrentRefreshJobLimitNotApplied() { + EMRServerlessClientFactory emrServerlessClientFactory = new LocalEMRServerlessClientFactory(); AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(new LocalEMRSClient()); + createAsyncQueryExecutorService(emrServerlessClientFactory); // Mock flint index COVERING.createIndex(); @@ -607,8 +692,9 @@ public void concurrentRefreshJobLimitNotApplied() { @Test public void concurrentRefreshJobLimitAppliedToDDLWithAuthRefresh() { + EMRServerlessClientFactory emrServerlessClientFactory = new LocalEMRServerlessClientFactory(); AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(new LocalEMRSClient()); + createAsyncQueryExecutorService(emrServerlessClientFactory); setConcurrentRefreshJob(1); @@ -633,8 +719,9 @@ public void concurrentRefreshJobLimitAppliedToDDLWithAuthRefresh() { @Test public void concurrentRefreshJobLimitAppliedToRefresh() { + EMRServerlessClientFactory emrServerlessClientFactory = new LocalEMRServerlessClientFactory(); AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(new LocalEMRSClient()); + createAsyncQueryExecutorService(emrServerlessClientFactory); setConcurrentRefreshJob(1); @@ -658,9 +745,9 @@ public void concurrentRefreshJobLimitAppliedToRefresh() { @Test public void concurrentRefreshJobLimitNotAppliedToDDL() { String query = "CREATE INDEX covering ON mys3.default.http_logs(l_orderkey, l_quantity)"; - + EMRServerlessClientFactory emrServerlessClientFactory = new LocalEMRServerlessClientFactory(); AsyncQueryExecutorService asyncQueryExecutorService = - createAsyncQueryExecutorService(new LocalEMRSClient()); + createAsyncQueryExecutorService(emrServerlessClientFactory); setConcurrentRefreshJob(1); diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java new file mode 100644 index 0000000000..9bfed9f498 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EMRServerlessClientFactoryImplTest.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import static org.mockito.Mockito.when; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfig; +import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier; +import org.opensearch.sql.spark.constants.TestConstants; + +@ExtendWith(MockitoExtension.class) +public class EMRServerlessClientFactoryImplTest { + + @Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier; + + @Test + public void testGetClient() { + when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig()) + .thenReturn(createSparkExecutionEngineConfig()); + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier); + EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(); + Assertions.assertNotNull(emrserverlessClient); + } + + @Test + public void testGetClientWithChangeInSetting() { + SparkExecutionEngineConfig sparkExecutionEngineConfig = createSparkExecutionEngineConfig(); + when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig()) + .thenReturn(sparkExecutionEngineConfig); + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier); + EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(); + Assertions.assertNotNull(emrserverlessClient); + + EMRServerlessClient emrServerlessClient1 = emrServerlessClientFactory.getClient(); + Assertions.assertEquals(emrServerlessClient1, emrserverlessClient); + + sparkExecutionEngineConfig.setRegion(TestConstants.US_WEST_REGION); + when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig()) + .thenReturn(sparkExecutionEngineConfig); + EMRServerlessClient emrServerlessClient2 = emrServerlessClientFactory.getClient(); + Assertions.assertNotEquals(emrServerlessClient2, emrserverlessClient); + Assertions.assertNotEquals(emrServerlessClient2, emrServerlessClient1); + } + + @Test + public void testGetClientWithException() { + when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig()).thenReturn(null); + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier); + IllegalArgumentException illegalArgumentException = + Assertions.assertThrows( + IllegalArgumentException.class, emrServerlessClientFactory::getClient); + Assertions.assertEquals( + "Async Query APIs are disabled. Please configure plugins.query.executionengine.spark.config" + + " in cluster settings to enable them.", + illegalArgumentException.getMessage()); + } + + @Test + public void testGetClientWithExceptionWithNullRegion() { + SparkExecutionEngineConfig sparkExecutionEngineConfig = new SparkExecutionEngineConfig(); + when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig()) + .thenReturn(sparkExecutionEngineConfig); + EMRServerlessClientFactory emrServerlessClientFactory = + new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier); + IllegalArgumentException illegalArgumentException = + Assertions.assertThrows( + IllegalArgumentException.class, emrServerlessClientFactory::getClient); + Assertions.assertEquals( + "Async Query APIs are disabled. Please configure plugins.query.executionengine.spark.config" + + " in cluster settings to enable them.", + illegalArgumentException.getMessage()); + } + + private SparkExecutionEngineConfig createSparkExecutionEngineConfig() { + SparkExecutionEngineConfig sparkExecutionEngineConfig = new SparkExecutionEngineConfig(); + sparkExecutionEngineConfig.setRegion(TestConstants.US_EAST_REGION); + sparkExecutionEngineConfig.setExecutionRoleARN(TestConstants.EMRS_EXECUTION_ROLE); + sparkExecutionEngineConfig.setSparkSubmitParameters( + SparkSubmitParameters.Builder.builder().build().toString()); + sparkExecutionEngineConfig.setClusterName(TestConstants.TEST_CLUSTER_NAME); + sparkExecutionEngineConfig.setApplicationId(TestConstants.EMRS_APPLICATION_ID); + return sparkExecutionEngineConfig; + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java index cc13061323..b06b2e4cc3 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java +++ b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java @@ -21,4 +21,6 @@ public class TestConstants { public static final String ENTRY_POINT_START_JAR = "file:///home/hadoop/.ivy2/jars/org.opensearch_opensearch-spark-sql-application_2.12-0.1.0-SNAPSHOT.jar"; public static final String DEFAULT_RESULT_INDEX = "query_execution_result_ds1"; + public static final String US_EAST_REGION = "us-east-1"; + public static final String US_WEST_REGION = "us-west-1"; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 4205102cb1..2a499e7d30 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -62,6 +62,7 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; @@ -82,6 +83,7 @@ public class SparkQueryDispatcherTest { @Mock private EMRServerlessClient emrServerlessClient; + @Mock private EMRServerlessClientFactory emrServerlessClientFactory; @Mock private DataSourceService dataSourceService; @Mock private JobExecutionResponseReader jobExecutionResponseReader; @Mock private DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper; @@ -112,7 +114,7 @@ public class SparkQueryDispatcherTest { void setUp() { sparkQueryDispatcher = new SparkQueryDispatcher( - emrServerlessClient, + emrServerlessClientFactory, dataSourceService, dataSourceUserAuthorizationHelper, jobExecutionResponseReader, @@ -121,6 +123,7 @@ void setUp() { sessionManager, leaseManager, stateStore); + when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); } @Test diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index d670fc4ca8..338da431fb 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -23,6 +23,7 @@ import org.opensearch.action.delete.DeleteRequest; import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.test.OpenSearchIntegTestCase; @@ -117,8 +118,9 @@ public void closeNotExistSession() { @Test public void sessionManagerCreateSession() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); TestSession testSession = testSession(session, stateStore); @@ -127,7 +129,9 @@ public void sessionManagerCreateSession() { @Test public void sessionManagerGetSession() { - SessionManager sessionManager = new SessionManager(stateStore, emrsClient, sessionSetting()); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + SessionManager sessionManager = + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()); Session session = sessionManager.createSession(createSessionRequest()); Optional managerSession = sessionManager.getSession(session.getSessionId()); @@ -137,7 +141,9 @@ public void sessionManagerGetSession() { @Test public void sessionManagerGetSessionNotExist() { - SessionManager sessionManager = new SessionManager(stateStore, emrsClient, sessionSetting()); + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; + SessionManager sessionManager = + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()); Optional managerSession = sessionManager.getSession(SessionId.newSessionId("no-exist")); diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java index 44dd5c3a57..d021bc7248 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java @@ -14,17 +14,19 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statestore.StateStore; @ExtendWith(MockitoExtension.class) public class SessionManagerTest { @Mock private StateStore stateStore; - @Mock private EMRServerlessClient emrClient; + + @Mock private EMRServerlessClientFactory emrServerlessClientFactory; @Test public void sessionEnable() { - Assertions.assertTrue(new SessionManager(stateStore, emrClient, sessionSetting()).isEnabled()); + Assertions.assertTrue( + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()).isEnabled()); } public static org.opensearch.sql.common.setting.Settings sessionSetting() { diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index 97f38d37a7..3a69fa01d7 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -24,6 +24,7 @@ import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.session.InteractiveSessionTest; import org.opensearch.sql.spark.execution.session.Session; import org.opensearch.sql.spark.execution.session.SessionId; @@ -258,8 +259,9 @@ public void cancelRunningStatementSuccess() { @Test public void submitStatementInRunningSession() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); // App change state to running @@ -271,8 +273,9 @@ public void submitStatementInRunningSession() { @Test public void submitStatementInNotStartedState() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); StatementId statementId = session.submit(queryRequest()); @@ -281,8 +284,9 @@ public void submitStatementInNotStartedState() { @Test public void failToSubmitStatementInDeadState() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.DEAD); @@ -297,8 +301,9 @@ public void failToSubmitStatementInDeadState() { @Test public void failToSubmitStatementInFailState() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.FAIL); @@ -313,8 +318,9 @@ public void failToSubmitStatementInFailState() { @Test public void newStatementFieldAssert() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); StatementId statementId = session.submit(queryRequest()); Optional statement = session.get(statementId); @@ -331,8 +337,9 @@ public void newStatementFieldAssert() { @Test public void failToSubmitStatementInDeletedSession() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); // other's delete session @@ -347,8 +354,9 @@ public void failToSubmitStatementInDeletedSession() { @Test public void getStatementSuccess() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); // App change state to running updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); @@ -362,8 +370,9 @@ public void getStatementSuccess() { @Test public void getStatementNotExist() { + EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrsClient, sessionSetting()) + new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) .createSession(createSessionRequest()); // App change state to running updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); diff --git a/spark/src/test/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModuleTest.java b/spark/src/test/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModuleTest.java new file mode 100644 index 0000000000..d45950852f --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModuleTest.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.transport.config; + +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Injector; +import org.opensearch.common.inject.ModulesBuilder; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.legacy.metrics.Metrics; +import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; + +@ExtendWith(MockitoExtension.class) +public class AsyncExecutorServiceModuleTest { + + @Mock private NodeClient nodeClient; + + @Mock private ClusterService clusterService; + + @Mock private Settings settings; + + @Mock private DataSourceService dataSourceService; + + @Test + public void testAsyncQueryExecutorService() { + ModulesBuilder modulesBuilder = new ModulesBuilder(); + modulesBuilder.add(new AsyncExecutorServiceModule()); + modulesBuilder.add( + b -> { + b.bind(NodeClient.class).toInstance(nodeClient); + b.bind(org.opensearch.sql.common.setting.Settings.class).toInstance(settings); + b.bind(DataSourceService.class).toInstance(dataSourceService); + b.bind(ClusterService.class).toInstance(clusterService); + }); + Injector injector = modulesBuilder.createInjector(); + assertNotNull(injector.getInstance(AsyncQueryExecutorService.class)); + assertNotNull(Metrics.getInstance().getMetric("active_async_query_sessions_count")); + assertNotNull(Metrics.getInstance().getMetric("active_async_query_statements_count")); + } +}