diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/metrics/MetricName.java b/legacy/src/main/java/org/opensearch/sql/legacy/metrics/MetricName.java index 91ade7b038..4c3f24a9a9 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/metrics/MetricName.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/metrics/MetricName.java @@ -47,7 +47,8 @@ public enum MetricName { EMR_CANCEL_JOB_REQUEST_FAILURE_COUNT("emr_cancel_job_request_failure_count"), EMR_STREAMING_QUERY_JOBS_CREATION_COUNT("emr_streaming_jobs_creation_count"), EMR_INTERACTIVE_QUERY_JOBS_CREATION_COUNT("emr_interactive_jobs_creation_count"), - EMR_BATCH_QUERY_JOBS_CREATION_COUNT("emr_batch_jobs_creation_count"); + EMR_BATCH_QUERY_JOBS_CREATION_COUNT("emr_batch_jobs_creation_count"), + STREAMING_JOB_CLEANER_TASK_FAILURE_COUNT("streaming_job_cleaner_task_failure_count"); private String name; @@ -91,6 +92,7 @@ public static List getNames() { .add(ASYNC_QUERY_CREATE_API_REQUEST_COUNT) .add(ASYNC_QUERY_GET_API_REQUEST_COUNT) .add(ASYNC_QUERY_CANCEL_API_REQUEST_COUNT) + .add(STREAMING_JOB_CLEANER_TASK_FAILURE_COUNT) .build(); public boolean isNumerical() { diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index 2b75a8b2c9..8487f30716 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -79,7 +79,10 @@ import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory; import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.cluster.ClusterManagerEventListener; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.flint.FlintIndexMetadataService; import 
org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; import org.opensearch.sql.spark.storage.SparkStorageFactory; import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; @@ -221,7 +224,11 @@ public Collection createComponents( OpenSearchSettings.SESSION_INDEX_TTL_SETTING, OpenSearchSettings.RESULT_INDEX_TTL_SETTING, OpenSearchSettings.AUTO_INDEX_MANAGEMENT_ENABLED_SETTING, - environment.settings()); + environment.settings(), + dataSourceService, + injector.getInstance(FlintIndexMetadataService.class), + injector.getInstance(StateStore.class), + injector.getInstance(EMRServerlessClientFactory.class)); return ImmutableList.of( dataSourceService, injector.getInstance(AsyncQueryExecutorService.class), diff --git a/spark/src/main/java/org/opensearch/sql/spark/cluster/ClusterManagerEventListener.java b/spark/src/main/java/org/opensearch/sql/spark/cluster/ClusterManagerEventListener.java index 3d004b548f..9bea1f4b0d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/cluster/ClusterManagerEventListener.java +++ b/spark/src/main/java/org/opensearch/sql/spark/cluster/ClusterManagerEventListener.java @@ -19,17 +19,26 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.spark.client.EMRServerlessClientFactory; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.flint.FlintIndexMetadataService; import org.opensearch.threadpool.Scheduler.Cancellable; import org.opensearch.threadpool.ThreadPool; public class ClusterManagerEventListener implements LocalNodeClusterManagerListener { private Cancellable flintIndexRetentionCron; + private Cancellable flintStreamingJobCleanerCron; private ClusterService clusterService; private ThreadPool threadPool; private Client 
client; private Clock clock; + private DataSourceService dataSourceService; + private FlintIndexMetadataService flintIndexMetadataService; + private StateStore stateStore; + private EMRServerlessClientFactory emrServerlessClientFactory; private Duration sessionTtlDuration; private Duration resultTtlDuration; private boolean isAutoIndexManagementEnabled; @@ -42,13 +51,20 @@ public ClusterManagerEventListener( Setting sessionTtl, Setting resultTtl, Setting isAutoIndexManagementEnabledSetting, - Settings settings) { + Settings settings, + DataSourceService dataSourceService, + FlintIndexMetadataService flintIndexMetadataService, + StateStore stateStore, + EMRServerlessClientFactory emrServerlessClientFactory) { this.clusterService = clusterService; this.threadPool = threadPool; this.client = client; this.clusterService.addLocalNodeClusterManagerListener(this); this.clock = clock; - + this.dataSourceService = dataSourceService; + this.flintIndexMetadataService = flintIndexMetadataService; + this.stateStore = stateStore; + this.emrServerlessClientFactory = emrServerlessClientFactory; this.sessionTtlDuration = toDuration(sessionTtl.get(settings)); this.resultTtlDuration = toDuration(resultTtl.get(settings)); @@ -104,6 +120,19 @@ public void beforeStop() { } }); } + initializeStreamingJobCleanerCron(); + } + + private void initializeStreamingJobCleanerCron() { + flintStreamingJobCleanerCron = + threadPool.scheduleWithFixedDelay( + new FlintStreamingJobCleanerTask( + dataSourceService, + flintIndexMetadataService, + stateStore, + emrServerlessClientFactory), + TimeValue.timeValueMinutes(15), + executorName()); } private void reInitializeFlintIndexRetention() { @@ -125,6 +154,8 @@ private void reInitializeFlintIndexRetention() { public void offClusterManager() { cancel(flintIndexRetentionCron); flintIndexRetentionCron = null; + cancel(flintStreamingJobCleanerCron); + flintStreamingJobCleanerCron = null; } private void cancel(Cancellable cron) { diff --git 
// File: spark/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobCleanerTask.java
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.sql.spark.cluster;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasource.model.DataSourceStatus;
import org.opensearch.sql.datasource.model.DataSourceType;
import org.opensearch.sql.legacy.metrics.MetricName;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadata;
import org.opensearch.sql.spark.flint.FlintIndexMetadataService;
import org.opensearch.sql.spark.flint.operation.FlintIndexOpAlter;
import org.opensearch.sql.spark.flint.operation.FlintIndexOpCancel;

/**
 * Cleaner task which alters the active streaming jobs of a disabled datasource.
 *
 * <p>{@link #run()} visits every S3 Glue data source whose status is {@code DISABLED}, looks up
 * the Flint indices matching {@code flint_<dataSourceName>_*} that have auto refresh enabled, and
 * applies a {@link FlintIndexOpAlter} with {@link FlintIndexOptions#AUTO_REFRESH} set to {@code
 * "false"} to each of them.
 *
 * <p>Only one execution is active at a time: a static {@link AtomicBoolean} guard makes an
 * invocation that overlaps a running one log a message and return immediately. Neither entry
 * point lets a failure escape; failures are logged and counted in {@link
 * MetricName#STREAMING_JOB_CLEANER_TASK_FAILURE_COUNT}.
 */
@RequiredArgsConstructor
public class FlintStreamingJobCleanerTask implements Runnable {

  private static final Logger LOGGER = LogManager.getLogger(FlintStreamingJobCleanerTask.class);

  /** Index pattern used by {@link #run2()} to list every Flint index. */
  private static final String ALL_FLINT_INDICES_PATTERN = "flint_*";

  /**
   * Guard shared by all instances (it is static): {@code true} while an execution is in progress.
   * Kept {@code protected} so that tests in this package can set and inspect it.
   */
  protected static final AtomicBoolean isRunning = new AtomicBoolean(false);

  // The order of these fields defines the parameter order of the Lombok-generated constructor.
  private final DataSourceService dataSourceService;
  private final FlintIndexMetadataService flintIndexMetadataService;
  private final StateStore stateStore;
  private final EMRServerlessClientFactory emrServerlessClientFactory;

  /**
   * Disables auto refresh on every auto-refresh Flint index of every disabled S3 Glue data source.
   *
   * <p>A failure on one index is logged and counted, and the remaining indices are still
   * processed. A failure outside the per-index step (for example while listing data sources or
   * indices) ends the current execution; it is logged with its stack trace and counted as well.
   */
  @Override
  public void run() {
    runExclusively(this::alterIndicesOfDisabledDataSources);
  }

  /**
   * Variant of {@link #run()} that walks all auto-refresh Flint indices once. Indices whose data
   * source is a disabled S3 Glue data source get auto refresh switched off; indices whose data
   * source is not among the known S3 Glue data sources get their job cancelled through {@link
   * FlintIndexOpCancel}.
   *
   * <p>Indices whose name does not yield a data source name are skipped: an empty name can never
   * match a known data source, so without the skip such an index would be cancelled as if its
   * data source had been deleted.
   *
   * <p>NOTE(review): the data source name is taken from the second {@code _}-separated token of
   * the index name. If data source names may themselves contain {@code _}, the parsed name will
   * not match the real one and the job of a live data source would be cancelled - confirm the
   * naming rules before scheduling this variant.
   */
  public void run2() {
    runExclusively(this::alterOrCancelAllAutoRefreshIndices);
  }

  /**
   * Runs {@code body} unless another execution is already in progress, and always releases the
   * guard afterwards.
   */
  private void runExclusively(Runnable body) {
    if (!isRunning.compareAndSet(false, true)) {
      LOGGER.info("Previous task is still running. Skipping this execution.");
      return;
    }
    try {
      body.run();
    } catch (Throwable error) {
      // Deliberately broad: nothing is allowed to escape from this task.
      LOGGER.error(
          "Error while running the streaming job cleaner task: {}", error.getMessage(), error);
      incrementFailureMetric();
    } finally {
      isRunning.set(false);
    }
  }

  /** Body of {@link #run()}: one pass per disabled S3 Glue data source. */
  private void alterIndicesOfDisabledDataSources() {
    LOGGER.info("Starting the cleaner task for disabled data sources.");
    List<DataSourceMetadata> disabledDataSources = getS3GlueDisabledDataSources();
    LOGGER.info("Found {} disabled data sources to process.", disabledDataSources.size());
    for (DataSourceMetadata dataSourceMetadata : disabledDataSources) {
      String dataSourceName = dataSourceMetadata.getName();
      LOGGER.info("Processing disabled data source: {}", dataSourceName);
      Map<String, FlintIndexMetadata> autoRefreshIndices =
          getAutoRefreshIndicesOfDataSource(dataSourceMetadata);
      LOGGER.info(
          "Found {} auto-refresh indices to alter for data source: {}",
          autoRefreshIndices.size(),
          dataSourceName);
      autoRefreshIndices.forEach(
          (indexName, indexMetadata) -> {
            try {
              disableAutoRefresh(dataSourceName, indexName, indexMetadata);
            } catch (Exception exception) {
              recordIndexFailure("alter", indexName, exception);
            }
          });
    }
  }

  /** Body of {@link #run2()}: one pass over all auto-refresh Flint indices. */
  private void alterOrCancelAllAutoRefreshIndices() {
    LOGGER.info("Starting the cleaner task for disabled and deleted data sources.");
    List<DataSourceMetadata> s3GlueDataSources = getS3GlueDataSources();
    Set<String> disabledDataSourceNames =
        s3GlueDataSources.stream()
            .filter(FlintStreamingJobCleanerTask::isDisabled)
            .map(DataSourceMetadata::getName)
            .collect(Collectors.toSet());
    Set<String> knownDataSourceNames =
        s3GlueDataSources.stream().map(DataSourceMetadata::getName).collect(Collectors.toSet());
    getAllAutoRefreshIndices()
        .forEach(
            (indexName, indexMetadata) -> {
              try {
                String dataSourceName = getDataSourceName(indexName);
                if (dataSourceName.isEmpty()) {
                  LOGGER.warn(
                      "Skipping index {}: unable to derive its data source name.", indexName);
                } else if (disabledDataSourceNames.contains(dataSourceName)) {
                  disableAutoRefresh(dataSourceName, indexName, indexMetadata);
                } else if (!knownDataSourceNames.contains(dataSourceName)) {
                  cancelStreamingJob(dataSourceName, indexName, indexMetadata);
                }
              } catch (Exception exception) {
                recordIndexFailure("alter/cancel", indexName, exception);
              }
            });
  }

  /** Applies an alter operation that sets auto refresh to {@code "false"} on one index. */
  private void disableAutoRefresh(
      String dataSourceName, String indexName, FlintIndexMetadata indexMetadata) {
    LOGGER.debug("Attempting to alter index: {}", indexName);
    FlintIndexOptions flintIndexOptions = new FlintIndexOptions();
    flintIndexOptions.setOption(FlintIndexOptions.AUTO_REFRESH, "false");
    FlintIndexOpAlter flintIndexOpAlter =
        new FlintIndexOpAlter(
            flintIndexOptions,
            stateStore,
            dataSourceName,
            emrServerlessClientFactory.getClient(),
            flintIndexMetadataService);
    flintIndexOpAlter.apply(indexMetadata);
    LOGGER.info("Successfully altered index: {}", indexName);
  }

  /** Applies a cancel operation on one index. */
  private void cancelStreamingJob(
      String dataSourceName, String indexName, FlintIndexMetadata indexMetadata) {
    LOGGER.debug("Attempting to cancel auto refresh index: {}", indexName);
    FlintIndexOpCancel flintIndexOpCancel =
        new FlintIndexOpCancel(
            stateStore, dataSourceName, emrServerlessClientFactory.getClient(), true);
    flintIndexOpCancel.apply(indexMetadata);
    LOGGER.info("Successfully cancelled index: {}", indexName);
  }

  /** Logs a per-index failure with its stack trace and counts it. */
  private void recordIndexFailure(String action, String indexName, Exception exception) {
    LOGGER.error(
        "Failed to {} index {}: {}", action, indexName, exception.getMessage(), exception);
    incrementFailureMetric();
  }

  /**
   * Increments the failure counter. Best effort: a problem with the metric itself is logged and
   * swallowed so that it can neither abort the processing of the remaining indices nor escape
   * from the task.
   */
  private void incrementFailureMetric() {
    try {
      Metrics.getInstance()
          .getNumericalMetric(MetricName.STREAMING_JOB_CLEANER_TASK_FAILURE_COUNT)
          .increment();
    } catch (RuntimeException metricError) {
      LOGGER.warn(
          "Unable to record the streaming job cleaner failure metric: {}",
          metricError.getMessage(),
          metricError);
    }
  }

  /**
   * Extracts the second {@code _}-separated token of an index name.
   *
   * @param autoRefreshIndex index name, expected to look like {@code flint_<dataSource>_...}
   * @return the token, or an empty string when the name has fewer than two tokens
   */
  private String getDataSourceName(String autoRefreshIndex) {
    String[] split = autoRefreshIndex.split("_");
    return split.length > 1 ? split[1] : StringUtils.EMPTY;
  }

  /** Returns the auto-refresh indices matching {@code flint_<dataSourceName>_*}. */
  private Map<String, FlintIndexMetadata> getAutoRefreshIndicesOfDataSource(
      DataSourceMetadata dataSourceMetadata) {
    return getAutoRefreshIndices("flint_" + dataSourceMetadata.getName() + "_*");
  }

  /** Returns the auto-refresh indices among all Flint indices. */
  private Map<String, FlintIndexMetadata> getAllAutoRefreshIndices() {
    return getAutoRefreshIndices(ALL_FLINT_INDICES_PATTERN);
  }

  /** Fetches index metadata for {@code indexPattern} and keeps entries with auto refresh on. */
  private Map<String, FlintIndexMetadata> getAutoRefreshIndices(String indexPattern) {
    Map<String, FlintIndexMetadata> flintIndexMetadataMap =
        flintIndexMetadataService.getFlintIndexMetadata(indexPattern);
    return flintIndexMetadataMap.entrySet().stream()
        .filter(entry -> entry.getValue().getFlintIndexOptions().autoRefresh())
        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
  }

  /** Returns the S3 Glue data sources whose status is {@code DISABLED}. */
  private List<DataSourceMetadata> getS3GlueDisabledDataSources() {
    return getS3GlueDataSources().stream()
        .filter(FlintStreamingJobCleanerTask::isDisabled)
        .collect(Collectors.toList());
  }

  /** Returns every data source whose connector is S3 Glue, whatever its status. */
  private List<DataSourceMetadata> getS3GlueDataSources() {
    return this.dataSourceService.getDataSourceMetadata(false).stream()
        .filter(dataSourceMetadata -> dataSourceMetadata.getConnector() == DataSourceType.S3GLUE)
        .collect(Collectors.toList());
  }

  private static boolean isDisabled(DataSourceMetadata dataSourceMetadata) {
    return dataSourceMetadata.getStatus() == DataSourceStatus.DISABLED;
  }
}

// File: spark/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobCleanerTaskTest.java
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.sql.spark.cluster;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

import java.util.Map;
import java.util.Set;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasource.model.DataSourceStatus;
import org.opensearch.sql.datasource.model.DataSourceType;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadata;
import org.opensearch.sql.spark.flint.FlintIndexMetadataService;

/**
 * Unit tests for {@link FlintStreamingJobCleanerTask#run()}.
 *
 * <p>All collaborators and model objects are Mockito mocks that stub only the accessors the task
 * is seen to call, so the tests do not depend on constructors that are not visible here.
 *
 * <p>NOTE(review): the tests do not assert on the failure counter itself. Doing so requires the
 * metric registry to be set up (or {@code Metrics.getInstance()} to be statically mocked), and
 * how that is done in this module could not be confirmed from the visible code.
 */
public class FlintStreamingJobCleanerTaskTest {

  private static final String DATA_SOURCE = "testDataSource";
  private static final String INDEX_PATTERN = "flint_testDataSource_*";

  private DataSourceService dataSourceService;
  private FlintIndexMetadataService flintIndexMetadataService;
  private StateStore stateStore;
  private EMRServerlessClientFactory emrServerlessClientFactory;
  private FlintStreamingJobCleanerTask task;

  @BeforeEach
  void setUp() {
    FlintStreamingJobCleanerTask.isRunning.set(false);
    dataSourceService = mock(DataSourceService.class);
    flintIndexMetadataService = mock(FlintIndexMetadataService.class);
    stateStore = mock(StateStore.class);
    emrServerlessClientFactory = mock(EMRServerlessClientFactory.class);
    task =
        new FlintStreamingJobCleanerTask(
            dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory);
  }

  @AfterEach
  void resetGuard() {
    // The guard is static; never leak a "running" state into other tests.
    FlintStreamingJobCleanerTask.isRunning.set(false);
  }

  @Test
  void shouldSkipExecutionIfPreviousTaskIsRunning() {
    FlintStreamingJobCleanerTask.isRunning.set(true);

    task.run();

    verifyNoInteractions(dataSourceService);
  }

  @Test
  void shouldProcessDisabledDataSources() {
    stubDataSources(disabledS3GlueDataSource(DATA_SOURCE));
    when(flintIndexMetadataService.getFlintIndexMetadata(anyString())).thenReturn(Map.of());

    task.run();

    verify(dataSourceService).getDataSourceMetadata(false);
    verify(flintIndexMetadataService).getFlintIndexMetadata(INDEX_PATTERN);
    // No auto-refresh index was returned, so no alter operation may have been prepared.
    verifyNoInteractions(emrServerlessClientFactory);
    assertFalse(FlintStreamingJobCleanerTask.isRunning.get());
  }

  @Test
  void shouldAttemptToAlterAutoRefreshIndices() {
    stubDataSources(disabledS3GlueDataSource(DATA_SOURCE));
    FlintIndexMetadata indexMetadata = autoRefreshIndex();
    when(flintIndexMetadataService.getFlintIndexMetadata(anyString()))
        .thenReturn(Map.of("flint_testDataSource_default_orders_index", indexMetadata));
    // Fail while the alter operation is being prepared so that no real operation runs on mocks.
    when(emrServerlessClientFactory.getClient()).thenThrow(new RuntimeException("no client"));

    assertDoesNotThrow(task::run);

    verify(flintIndexMetadataService).getFlintIndexMetadata(INDEX_PATTERN);
    // The client is requested exactly when the task prepares the alter operation for an index.
    verify(emrServerlessClientFactory).getClient();
  }

  @Test
  void shouldHandleExceptionsDuringIndexLookupGracefully() {
    stubDataSources(disabledS3GlueDataSource(DATA_SOURCE));
    when(flintIndexMetadataService.getFlintIndexMetadata(anyString()))
        .thenThrow(new RuntimeException("lookup failed"));

    assertDoesNotThrow(task::run);

    // The guard must be released even when the execution ends with a failure.
    assertFalse(FlintStreamingJobCleanerTask.isRunning.get());
    verifyNoInteractions(emrServerlessClientFactory);
  }

  @Test
  void shouldKeepProcessingRemainingIndicesWhenOneAlterationFails() {
    stubDataSources(disabledS3GlueDataSource(DATA_SOURCE));
    FlintIndexMetadata first = autoRefreshIndex();
    FlintIndexMetadata second = autoRefreshIndex();
    when(flintIndexMetadataService.getFlintIndexMetadata(anyString()))
        .thenReturn(
            Map.of(
                "flint_testDataSource_default_orders_index", first,
                "flint_testDataSource_default_users_index", second));
    when(emrServerlessClientFactory.getClient()).thenThrow(new RuntimeException("no client"));

    assertDoesNotThrow(task::run);

    // Both indices were attempted although each attempt failed.
    verify(emrServerlessClientFactory, times(2)).getClient();
    assertFalse(FlintStreamingJobCleanerTask.isRunning.get());
  }

  /**
   * Stubs the data source listing. {@code doReturn} is used because the declared collection type
   * of {@code getDataSourceMetadata} is not visible here. NOTE(review): confirm that the method
   * may return a {@link Set}; switch to {@code List.of(...)} otherwise.
   */
  private void stubDataSources(DataSourceMetadata... dataSources) {
    doReturn(Set.of(dataSources)).when(dataSourceService).getDataSourceMetadata(false);
  }

  /** Creates a mocked S3 Glue data source in {@code DISABLED} status. */
  private static DataSourceMetadata disabledS3GlueDataSource(String name) {
    DataSourceMetadata dataSourceMetadata = mock(DataSourceMetadata.class);
    when(dataSourceMetadata.getName()).thenReturn(name);
    when(dataSourceMetadata.getConnector()).thenReturn(DataSourceType.S3GLUE);
    when(dataSourceMetadata.getStatus()).thenReturn(DataSourceStatus.DISABLED);
    return dataSourceMetadata;
  }

  /** Creates mocked index metadata whose options report auto refresh as enabled. */
  private static FlintIndexMetadata autoRefreshIndex() {
    FlintIndexOptions flintIndexOptions = mock(FlintIndexOptions.class);
    when(flintIndexOptions.autoRefresh()).thenReturn(true);
    FlintIndexMetadata flintIndexMetadata = mock(FlintIndexMetadata.class);
    when(flintIndexMetadata.getFlintIndexOptions()).thenReturn(flintIndexOptions);
    return flintIndexMetadata;
  }
}