From c43cdfec0f8818e214042deaa799a8e6a4b90c36 Mon Sep 17 00:00:00 2001 From: Sagar <99425694+sgup432@users.noreply.github.com> Date: Sat, 10 Jun 2023 00:38:28 -0700 Subject: [PATCH] Task cancellation monitoring service (#7642) * Task cancellation monitoring service Signed-off-by: Sagar Upadhyaya Signed-off-by: Rishab Nahata --- CHANGELOG.md | 1 + .../admin/cluster/node/stats/NodeStats.java | 24 +- .../cluster/node/stats/NodesStatsRequest.java | 3 +- .../node/stats/TransportNodesStatsAction.java | 3 +- .../stats/TransportClusterStatsAction.java | 1 + .../common/settings/ClusterSettings.java | 7 +- .../main/java/org/opensearch/node/Node.java | 17 +- .../java/org/opensearch/node/NodeService.java | 16 +- .../SearchShardTaskCancellationStats.java | 75 ++++ .../TaskCancellationMonitoringService.java | 179 +++++++++ .../TaskCancellationMonitoringSettings.java | 93 +++++ .../tasks/TaskCancellationStats.java | 64 +++ .../org/opensearch/tasks/TaskManager.java | 34 ++ .../cluster/node/stats/NodeStatsTests.java | 1 + .../node/tasks/CancellableTasksTests.java | 31 ++ .../opensearch/cluster/DiskUsageTests.java | 6 + ...SearchShardTaskCancellationStatsTests.java | 28 ++ ...askCancellationMonitoringServiceTests.java | 371 ++++++++++++++++++ ...skCancellationMonitoringSettingsTests.java | 39 ++ .../tasks/TaskCancellationStatsTests.java | 28 ++ .../MockInternalClusterInfoService.java | 3 +- .../opensearch/test/InternalTestCluster.java | 1 + 22 files changed, 1016 insertions(+), 9 deletions(-) create mode 100644 server/src/main/java/org/opensearch/tasks/SearchShardTaskCancellationStats.java create mode 100644 server/src/main/java/org/opensearch/tasks/TaskCancellationMonitoringService.java create mode 100644 server/src/main/java/org/opensearch/tasks/TaskCancellationMonitoringSettings.java create mode 100644 server/src/main/java/org/opensearch/tasks/TaskCancellationStats.java create mode 100644 server/src/test/java/org/opensearch/tasks/SearchShardTaskCancellationStatsTests.java create mode 100644 server/src/test/java/org/opensearch/tasks/TaskCancellationMonitoringServiceTests.java create mode 100644 server/src/test/java/org/opensearch/tasks/TaskCancellationMonitoringSettingsTests.java create mode 100644 server/src/test/java/org/opensearch/tasks/TaskCancellationStatsTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 69ac64e06e38c..fe3c012dab387 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,6 +91,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x] ### Added +- Add task cancellation monitoring service ([#7642](https://github.com/opensearch-project/OpenSearch/pull/7642)) - Add TokenManager Interface ([#7452](https://github.com/opensearch-project/OpenSearch/pull/7452)) - Add Remote store as a segment replication source ([#7653](https://github.com/opensearch-project/OpenSearch/pull/7653)) diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodeStats.java b/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodeStats.java index f92963af1681a..d03011774bb83 100644 --- a/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodeStats.java +++ b/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodeStats.java @@ -59,6 +59,7 @@ import org.opensearch.script.ScriptCacheStats; import org.opensearch.script.ScriptStats; import org.opensearch.search.backpressure.stats.SearchBackpressureStats; +import org.opensearch.tasks.TaskCancellationStats; import org.opensearch.threadpool.ThreadPoolStats; import org.opensearch.transport.TransportStats; @@ -134,6 +135,9 @@ public class NodeStats extends BaseNodeResponse implements ToXContentFragment { @Nullable private FileCacheStats fileCacheStats; + @Nullable + private TaskCancellationStats taskCancellationStats; + public NodeStats(StreamInput in) throws IOException { super(in); timestamp = in.readVLong(); @@ -180,6 +184,11 @@ public NodeStats(StreamInput in) throws IOException { } else { fileCacheStats = null; } + if (in.getVersion().onOrAfter(Version.V_3_0_0)) { + taskCancellationStats = in.readOptionalWriteable(TaskCancellationStats::new); + } else { + taskCancellationStats = null; + } } public NodeStats( @@ -204,7 +213,8 @@ public NodeStats( @Nullable SearchBackpressureStats searchBackpressureStats, @Nullable ClusterManagerThrottlingStats clusterManagerThrottlingStats, @Nullable WeightedRoutingStats weightedRoutingStats, - @Nullable FileCacheStats fileCacheStats + @Nullable FileCacheStats fileCacheStats, + @Nullable TaskCancellationStats taskCancellationStats ) { super(node); this.timestamp = timestamp; @@ -228,6 +238,7 @@ public NodeStats( this.clusterManagerThrottlingStats = clusterManagerThrottlingStats; this.weightedRoutingStats = weightedRoutingStats; this.fileCacheStats = fileCacheStats; + this.taskCancellationStats = taskCancellationStats; } public long getTimestamp() { @@ -355,6 +366,11 @@ public FileCacheStats getFileCacheStats() { return fileCacheStats; } + @Nullable + public TaskCancellationStats getTaskCancellationStats() { + return taskCancellationStats; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -392,6 +408,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getVersion().onOrAfter(Version.V_2_7_0)) { out.writeOptionalWriteable(fileCacheStats); } + if (out.getVersion().onOrAfter(Version.V_3_0_0)) { + out.writeOptionalWriteable(taskCancellationStats); + } } @Override @@ -476,6 +495,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (getFileCacheStats() != null) { getFileCacheStats().toXContent(builder, params); } + if (getTaskCancellationStats() != null) { + getTaskCancellationStats().toXContent(builder, params); + } return builder; } diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodesStatsRequest.java b/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodesStatsRequest.java index a9c58ac803590..68f391b91507c 100644 --- a/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodesStatsRequest.java +++ b/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/NodesStatsRequest.java @@ -210,7 +210,8 @@ public enum Metric { SEARCH_BACKPRESSURE("search_backpressure"), CLUSTER_MANAGER_THROTTLING("cluster_manager_throttling"), WEIGHTED_ROUTING_STATS("weighted_routing"), - FILE_CACHE_STATS("file_cache"); + FILE_CACHE_STATS("file_cache"), + TASK_CANCELLATION("task_cancellation"); private String metricName; diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/TransportNodesStatsAction.java b/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/TransportNodesStatsAction.java index 02b5ceef2c7e4..6aadf546d30f7 100644 --- a/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/TransportNodesStatsAction.java +++ b/server/src/main/java/org/opensearch/action/admin/cluster/node/stats/TransportNodesStatsAction.java @@ -122,7 +122,8 @@ protected NodeStats nodeOperation(NodeStatsRequest nodeStatsRequest) { NodesStatsRequest.Metric.SEARCH_BACKPRESSURE.containedIn(metrics), NodesStatsRequest.Metric.CLUSTER_MANAGER_THROTTLING.containedIn(metrics), NodesStatsRequest.Metric.WEIGHTED_ROUTING_STATS.containedIn(metrics), - NodesStatsRequest.Metric.FILE_CACHE_STATS.containedIn(metrics) + NodesStatsRequest.Metric.FILE_CACHE_STATS.containedIn(metrics), + NodesStatsRequest.Metric.TASK_CANCELLATION.containedIn(metrics) ); } diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/stats/TransportClusterStatsAction.java b/server/src/main/java/org/opensearch/action/admin/cluster/stats/TransportClusterStatsAction.java index 26332f762bdf2..726f8a0de19ae 100644 --- a/server/src/main/java/org/opensearch/action/admin/cluster/stats/TransportClusterStatsAction.java +++ b/server/src/main/java/org/opensearch/action/admin/cluster/stats/TransportClusterStatsAction.java @@ -166,6 +166,7 @@ protected ClusterStatsNodeResponse nodeOperation(ClusterStatsNodeRequest nodeReq false, false, false, + false, false ); List shardsStats = new ArrayList<>(); diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java index c12eb87ddbcb5..6dfa705b12896 100644 --- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java @@ -48,6 +48,7 @@ import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; import org.opensearch.search.backpressure.settings.SearchShardTaskSettings; import org.opensearch.search.backpressure.settings.SearchTaskSettings; +import org.opensearch.tasks.TaskCancellationMonitoringSettings; import org.opensearch.tasks.TaskManager; import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.tasks.consumer.TopNSearchTasksLogger; @@ -649,7 +650,11 @@ public void apply(Settings value, Settings current, Settings previous) { RemoteRefreshSegmentPressureSettings.MIN_CONSECUTIVE_FAILURES_LIMIT, RemoteRefreshSegmentPressureSettings.UPLOAD_BYTES_MOVING_AVERAGE_WINDOW_SIZE, RemoteRefreshSegmentPressureSettings.UPLOAD_BYTES_PER_SEC_MOVING_AVERAGE_WINDOW_SIZE, - RemoteRefreshSegmentPressureSettings.UPLOAD_TIME_MOVING_AVERAGE_WINDOW_SIZE + RemoteRefreshSegmentPressureSettings.UPLOAD_TIME_MOVING_AVERAGE_WINDOW_SIZE, + + // Related to monitoring of task cancellation + TaskCancellationMonitoringSettings.IS_ENABLED_SETTING, + TaskCancellationMonitoringSettings.DURATION_MILLIS_SETTING ) ) ); diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index dd205ad87812b..eb1fc2008df06 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -59,6 +59,8 @@ import org.opensearch.search.backpressure.SearchBackpressureService; import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; import org.opensearch.search.pipeline.SearchPipelineService; +import org.opensearch.tasks.TaskCancellationMonitoringService; +import org.opensearch.tasks.TaskCancellationMonitoringSettings; import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.tasks.consumer.TopNSearchTasksLogger; import org.opensearch.threadpool.RunnableTaskExecutionListener; @@ -972,6 +974,15 @@ protected Node( client, FeatureFlags.isEnabled(SEARCH_PIPELINE) ); + final TaskCancellationMonitoringSettings taskCancellationMonitoringSettings = new TaskCancellationMonitoringSettings( + settings, + clusterService.getClusterSettings() + ); + final TaskCancellationMonitoringService taskCancellationMonitoringService = new TaskCancellationMonitoringService( + threadPool, + transportService.getTaskManager(), + taskCancellationMonitoringSettings + ); this.nodeService = new NodeService( settings, threadPool, @@ -992,7 +1003,8 @@ protected Node( searchModule.getValuesSourceRegistry().getUsageService(), searchBackpressureService, searchPipelineService, - fileCache + fileCache, + taskCancellationMonitoringService ); final SearchService searchService = newSearchService( @@ -1222,6 +1234,7 @@ public Node start() throws NodeValidationException { injector.getInstance(FsHealthService.class).start(); nodeService.getMonitorService().start(); nodeService.getSearchBackpressureService().start(); + nodeService.getTaskCancellationMonitoringService().start(); final ClusterService clusterService = injector.getInstance(ClusterService.class); @@ -1380,6 +1393,7 @@ private Node stop() { injector.getInstance(GatewayService.class).stop(); injector.getInstance(SearchService.class).stop(); injector.getInstance(TransportService.class).stop(); + nodeService.getTaskCancellationMonitoringService().stop(); pluginLifecycleComponents.forEach(LifecycleComponent::stop); // we should stop this last since it waits for resources to get released @@ -1443,6 +1457,7 @@ public synchronized void close() throws IOException { toClose.add(injector.getInstance(SearchService.class)); toClose.add(() -> stopWatch.stop().start("transport")); toClose.add(injector.getInstance(TransportService.class)); + toClose.add(nodeService.getTaskCancellationMonitoringService()); for (LifecycleComponent plugin : pluginLifecycleComponents) { toClose.add(() -> stopWatch.stop().start("plugin(" + plugin.getClass().getName() + ")")); diff --git a/server/src/main/java/org/opensearch/node/NodeService.java b/server/src/main/java/org/opensearch/node/NodeService.java index 0eab742a8da7d..9382746081c18 100644 --- a/server/src/main/java/org/opensearch/node/NodeService.java +++ b/server/src/main/java/org/opensearch/node/NodeService.java @@ -57,6 +57,7 @@ import org.opensearch.search.aggregations.support.AggregationUsageService; import org.opensearch.search.backpressure.SearchBackpressureService; import org.opensearch.search.pipeline.SearchPipelineService; +import org.opensearch.tasks.TaskCancellationMonitoringService; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -90,6 +91,7 @@ public class NodeService implements Closeable { private final ClusterService clusterService; private final Discovery discovery; private final FileCache fileCache; + private final TaskCancellationMonitoringService taskCancellationMonitoringService; NodeService( Settings settings, @@ -111,7 +113,8 @@ public class NodeService implements Closeable { AggregationUsageService aggregationUsageService, SearchBackpressureService searchBackpressureService, SearchPipelineService searchPipelineService, - FileCache fileCache + FileCache fileCache, + TaskCancellationMonitoringService taskCancellationMonitoringService ) { this.settings = settings; this.threadPool = threadPool; @@ -133,6 +136,7 @@ public class NodeService implements Closeable { this.searchPipelineService = searchPipelineService; this.clusterService = clusterService; this.fileCache = fileCache; + this.taskCancellationMonitoringService = taskCancellationMonitoringService; clusterService.addStateApplier(ingestService); clusterService.addStateApplier(searchPipelineService); } @@ -211,7 +215,8 @@ public NodeStats stats( boolean searchBackpressure, boolean clusterManagerThrottling, boolean weightedRoutingStats, - boolean fileCacheStats + boolean fileCacheStats, + boolean taskCancellation ) { // for indices stats we want to include previous allocated shards stats as well (it will // only be applied to the sensible ones to use, like refresh/merge/flush/indexing stats) @@ -237,7 +242,8 @@ public NodeStats stats( searchBackpressure ? this.searchBackpressureService.nodeStats() : null, clusterManagerThrottling ? this.clusterService.getClusterManagerService().getThrottlingStats() : null, weightedRoutingStats ? WeightedRoutingStats.getInstance() : null, - fileCacheStats && fileCache != null ? fileCache.fileCacheStats() : null + fileCacheStats && fileCache != null ? fileCache.fileCacheStats() : null, + taskCancellation ? this.taskCancellationMonitoringService.stats() : null ); } @@ -253,6 +259,10 @@ public SearchBackpressureService getSearchBackpressureService() { return searchBackpressureService; } + public TaskCancellationMonitoringService getTaskCancellationMonitoringService() { + return taskCancellationMonitoringService; + } + @Override public void close() throws IOException { IOUtils.close(indicesService); diff --git a/server/src/main/java/org/opensearch/tasks/SearchShardTaskCancellationStats.java b/server/src/main/java/org/opensearch/tasks/SearchShardTaskCancellationStats.java new file mode 100644 index 0000000000000..d78a4480700da --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/SearchShardTaskCancellationStats.java @@ -0,0 +1,75 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +/** + * Holds monitoring service stats specific to search shard task. + */ +public class SearchShardTaskCancellationStats implements ToXContentObject, Writeable { + + private final long currentLongRunningCancelledTaskCount; + private final long totalLongRunningCancelledTaskCount; + + public SearchShardTaskCancellationStats(long currentTaskCount, long totalTaskCount) { + this.currentLongRunningCancelledTaskCount = currentTaskCount; + this.totalLongRunningCancelledTaskCount = totalTaskCount; + } + + public SearchShardTaskCancellationStats(StreamInput in) throws IOException { + this.currentLongRunningCancelledTaskCount = in.readVLong(); + this.totalLongRunningCancelledTaskCount = in.readVLong(); + } + + // package private for testing + protected long getCurrentLongRunningCancelledTaskCount() { + return this.currentLongRunningCancelledTaskCount; + } + + // package private for testing + protected long getTotalLongRunningCancelledTaskCount() { + return this.totalLongRunningCancelledTaskCount; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("current_count_post_cancel", currentLongRunningCancelledTaskCount); + builder.field("total_count_post_cancel", totalLongRunningCancelledTaskCount); + return builder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(currentLongRunningCancelledTaskCount); + out.writeVLong(totalLongRunningCancelledTaskCount); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SearchShardTaskCancellationStats that = (SearchShardTaskCancellationStats) o; + return currentLongRunningCancelledTaskCount == that.currentLongRunningCancelledTaskCount + && totalLongRunningCancelledTaskCount == that.totalLongRunningCancelledTaskCount; + } + + @Override + public int hashCode() { + return Objects.hash(currentLongRunningCancelledTaskCount, totalLongRunningCancelledTaskCount); + } +} diff --git a/server/src/main/java/org/opensearch/tasks/TaskCancellationMonitoringService.java b/server/src/main/java/org/opensearch/tasks/TaskCancellationMonitoringService.java new file mode 100644 index 0000000000000..5b512af56e195 --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/TaskCancellationMonitoringService.java @@ -0,0 +1,179 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchShardTask; +import org.opensearch.common.component.AbstractLifecycleComponent; +import org.opensearch.common.metrics.CounterMetric; +import org.opensearch.threadpool.Scheduler; +import org.opensearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +/** + * This monitoring service is responsible to track long-running(defined by a threshold) cancelled tasks as part of + * node stats. + */ +public class TaskCancellationMonitoringService extends AbstractLifecycleComponent implements TaskManager.TaskEventListeners { + + private static final Logger logger = LogManager.getLogger(TaskCancellationMonitoringService.class); + private final static List> TASKS_TO_TRACK = Arrays.asList(SearchShardTask.class); + + private volatile Scheduler.Cancellable scheduledFuture; + private final ThreadPool threadPool; + private final TaskManager taskManager; + /** + * This is to keep track of currently running cancelled tasks. This is needed to accurately calculate cumulative + * sum(from genesis) of cancelled tasks which have been running beyond a threshold and avoid double count + * problem. + * For example: + * A task M was cancelled at some point of time and continues to run for long. This Monitoring service sees this + * M for the first time and adds it as part of stats. In next iteration of monitoring service, it might see + * this M(if still running) again, but using below map we will not double count this task as part of our cumulative + * metric. + */ + private final Map cancelledTaskTracker; + /** + * This map holds statistics for each cancellable task type. + */ + private final Map, TaskCancellationStatsHolder> cancellationStatsHolder; + private final TaskCancellationMonitoringSettings taskCancellationMonitoringSettings; + + public TaskCancellationMonitoringService( + ThreadPool threadPool, + TaskManager taskManager, + TaskCancellationMonitoringSettings taskCancellationMonitoringSettings + ) { + this.threadPool = threadPool; + this.taskManager = taskManager; + this.taskCancellationMonitoringSettings = taskCancellationMonitoringSettings; + this.cancelledTaskTracker = new ConcurrentHashMap<>(); + cancellationStatsHolder = TASKS_TO_TRACK.stream() + .collect(Collectors.toConcurrentMap(task -> task, task -> new TaskCancellationStatsHolder())); + taskManager.addTaskEventListeners(this); + } + + void doRun() { + if (!taskCancellationMonitoringSettings.isEnabled() || this.cancelledTaskTracker.isEmpty()) { + return; + } + Map, List> taskCancellationListByType = getCurrentRunningTasksPostCancellation(); + taskCancellationListByType.forEach((key, value) -> { + long uniqueTasksRunningCount = value.stream().filter(task -> { + if (this.cancelledTaskTracker.containsKey(task.getId()) && !this.cancelledTaskTracker.get(task.getId())) { + // Mark it as seen by the stats logic. + this.cancelledTaskTracker.put(task.getId(), true); + return true; + } else { + return false; + } + }).count(); + cancellationStatsHolder.get(key).totalLongRunningCancelledTaskCount.inc(uniqueTasksRunningCount); + }); + } + + @Override + protected void doStart() { + scheduledFuture = threadPool.scheduleWithFixedDelay(() -> { + try { + doRun(); + } catch (Exception e) { + logger.debug("Exception occurred in Task monitoring service", e); + } + }, taskCancellationMonitoringSettings.getInterval(), ThreadPool.Names.GENERIC); + } + + @Override + protected void doStop() { + if (scheduledFuture != null) { + scheduledFuture.cancel(); + } + } + + @Override + protected void doClose() throws IOException { + + } + + // For testing + protected Map getCancelledTaskTracker() { + return this.cancelledTaskTracker; + } + + /** + * Invoked when a task is completed. This helps us to disable monitoring service when there are no cancelled tasks + * running to avoid wasteful work. + * @param task task which got completed. + */ + @Override + public void onTaskCompleted(Task task) { + if (!TASKS_TO_TRACK.contains(task.getClass())) { + return; + } + this.cancelledTaskTracker.entrySet().removeIf(entry -> entry.getKey() == task.getId()); + } + + /** + * Invoked when a task is cancelled. This is to keep track of tasks being cancelled. More importantly also helps + * us to enable this monitoring service only when needed. + * @param task task which got cancelled. + */ + @Override + public void onTaskCancelled(CancellableTask task) { + if (!TASKS_TO_TRACK.contains(task.getClass())) { + return; + } + // Add task to tracker and mark it as not seen(false) yet by the stats logic. + this.cancelledTaskTracker.putIfAbsent(task.getId(), false); + } + + public TaskCancellationStats stats() { + Map, List> currentRunningCancelledTasks = + getCurrentRunningTasksPostCancellation(); + return new TaskCancellationStats( + new SearchShardTaskCancellationStats( + Optional.of(currentRunningCancelledTasks).map(mapper -> mapper.get(SearchShardTask.class)).map(List::size).orElse(0), + cancellationStatsHolder.get(SearchShardTask.class).totalLongRunningCancelledTaskCount.count() + ) + ); + } + + private Map, List> getCurrentRunningTasksPostCancellation() { + long currentTimeInNanos = System.nanoTime(); + + return taskManager.getCancellableTasks() + .values() + .stream() + .filter(task -> TASKS_TO_TRACK.contains(task.getClass())) + .filter(CancellableTask::isCancelled) + .filter(task -> { + long runningTimeSinceCancellationSeconds = TimeUnit.NANOSECONDS.toSeconds( + currentTimeInNanos - task.getCancellationStartTimeNanos() + ); + return runningTimeSinceCancellationSeconds >= taskCancellationMonitoringSettings.getDuration().getSeconds(); + }) + .collect(Collectors.groupingBy(CancellableTask::getClass, Collectors.toList())); + } + + /** + * Holds stats related to monitoring service + */ + public static class TaskCancellationStatsHolder { + CounterMetric totalLongRunningCancelledTaskCount = new CounterMetric(); + } +} diff --git a/server/src/main/java/org/opensearch/tasks/TaskCancellationMonitoringSettings.java b/server/src/main/java/org/opensearch/tasks/TaskCancellationMonitoringSettings.java new file mode 100644 index 0000000000000..d4ec99873d584 --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/TaskCancellationMonitoringSettings.java @@ -0,0 +1,93 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; + +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Settings related to task cancellation monitoring service. + */ +public class TaskCancellationMonitoringSettings { + + public static final long INTERVAL_MILLIS_SETTING_DEFAULT_VALUE = 5000; + public static final long DURATION_MILLIS_SETTING_DEFAULT_VALUE = 10000; + public static final boolean IS_ENABLED_SETTING_DEFAULT_VALUE = true; + + /** + * Defines the interval(in millis) at which task cancellation service monitors and gather stats. + */ + public static final Setting INTERVAL_MILLIS_SETTING = Setting.longSetting( + "task_cancellation.interval_millis", + INTERVAL_MILLIS_SETTING_DEFAULT_VALUE, + 1, + Setting.Property.NodeScope + ); + + /** + * Setting which defines the duration threshold(in millis) of current running cancelled tasks above which they + * are tracked as part of stats. + */ + public static final Setting DURATION_MILLIS_SETTING = Setting.longSetting( + "task_cancellation.duration_millis", + DURATION_MILLIS_SETTING_DEFAULT_VALUE, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Setting to enable/disable monitoring service. + */ + public static final Setting IS_ENABLED_SETTING = Setting.boolSetting( + "task_cancellation.enabled", + IS_ENABLED_SETTING_DEFAULT_VALUE, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + private final TimeValue interval; + private TimeValue duration; + private final AtomicBoolean isEnabled; + private final Settings settings; + private final ClusterSettings clusterSettings; + + public TaskCancellationMonitoringSettings(Settings settings, ClusterSettings clusterSettings) { + this.settings = settings; + this.clusterSettings = clusterSettings; + this.interval = new TimeValue(INTERVAL_MILLIS_SETTING.get(settings)); + this.duration = new TimeValue(DURATION_MILLIS_SETTING.get(settings)); + this.isEnabled = new AtomicBoolean(IS_ENABLED_SETTING.get(settings)); + clusterSettings.addSettingsUpdateConsumer(IS_ENABLED_SETTING, this::setIsEnabled); + clusterSettings.addSettingsUpdateConsumer(DURATION_MILLIS_SETTING, this::setDurationMillis); + } + + public TimeValue getInterval() { + return this.interval; + } + + public TimeValue getDuration() { + return this.duration; + } + + public void setDurationMillis(long durationMillis) { + this.duration = new TimeValue(durationMillis); + } + + public boolean isEnabled() { + return isEnabled.get(); + } + + public void setIsEnabled(boolean isEnabled) { + this.isEnabled.set(isEnabled); + } +} diff --git a/server/src/main/java/org/opensearch/tasks/TaskCancellationStats.java b/server/src/main/java/org/opensearch/tasks/TaskCancellationStats.java new file mode 100644 index 0000000000000..2ccb3738b1235 --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/TaskCancellationStats.java @@ -0,0 +1,64 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +/** + * Holds stats related to task cancellation. + */ +public class TaskCancellationStats implements ToXContentFragment, Writeable { + + private final SearchShardTaskCancellationStats searchShardTaskCancellationStats; + + public TaskCancellationStats(SearchShardTaskCancellationStats searchShardTaskCancellationStats) { + this.searchShardTaskCancellationStats = searchShardTaskCancellationStats; + } + + public TaskCancellationStats(StreamInput in) throws IOException { + searchShardTaskCancellationStats = new SearchShardTaskCancellationStats(in); + } + + // package private for testing + protected SearchShardTaskCancellationStats getSearchShardTaskCancellationStats() { + return this.searchShardTaskCancellationStats; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject("task_cancellation"); + builder.field("search_shard_task", searchShardTaskCancellationStats); + return builder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + searchShardTaskCancellationStats.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TaskCancellationStats that = (TaskCancellationStats) o; + return Objects.equals(searchShardTaskCancellationStats, that.searchShardTaskCancellationStats); + } + + @Override + public int hashCode() { + return Objects.hash(searchShardTaskCancellationStats); + } +} diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index 1d0e19e7a557b..6aeba47766842 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -132,6 +132,7 @@ public class TaskManager implements ClusterStateApplier { private volatile boolean taskResourceConsumersEnabled; private final Set> taskResourceConsumer; + private final List taskEventListeners = new ArrayList<>(); public static TaskManager createTaskManagerWithClusterSettings( Settings settings, @@ -152,6 +153,19 @@ public TaskManager(Settings settings, ThreadPool threadPool, Set taskHea taskResourceConsumer = new HashSet<>(); } + /** + * Listener that gets invoked during an event such as task cancellation/completion. + */ + public interface TaskEventListeners { + default void onTaskCancelled(CancellableTask task) {} + + default void onTaskCompleted(Task task) {} + } + + public void addTaskEventListeners(TaskEventListeners taskEventListeners) { + this.taskEventListeners.add(taskEventListeners); + } + public void registerTaskResourceConsumer(Consumer consumer) { taskResourceConsumer.add(consumer); } @@ -261,6 +275,17 @@ private void registerCancellableTask(Task task) { */ public void cancel(CancellableTask task, String reason, Runnable listener) { CancellableTaskHolder holder = cancellableTasks.get(task.getId()); + List exceptions = new ArrayList<>(); + for (TaskEventListeners taskEventListener : taskEventListeners) { + try { + taskEventListener.onTaskCancelled(task); + } catch (Exception e) { + exceptions.add(e); + } + } + // Throwing exception in case any of the cancellation listener results into exception. + // Should we just swallow such exceptions? + ExceptionsHelper.maybeThrowRuntimeAndSuppress(exceptions); if (holder != null) { logger.trace("cancelling task with id {}", task.getId()); holder.cancel(reason, listener); @@ -274,6 +299,15 @@ public void cancel(CancellableTask task, String reason, Runnable listener) { */ public Task unregister(Task task) { logger.trace("unregister task for id: {}", task.getId()); + List exceptions = new ArrayList<>(); + for (TaskEventListeners taskEventListener : taskEventListeners) { + try { + taskEventListener.onTaskCompleted(task); + } catch (Exception e) { + exceptions.add(e); + } + } + ExceptionsHelper.maybeThrowRuntimeAndSuppress(exceptions); // Decrement the task's self-thread as part of unregistration. task.decrementResourceTrackingThreads(); diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/stats/NodeStatsTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/stats/NodeStatsTests.java index 473ab3d26a05c..d99b93b780140 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/stats/NodeStatsTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/stats/NodeStatsTests.java @@ -749,6 +749,7 @@ public static NodeStats createNodeStats() { null, clusterManagerThrottlingStats, weightedRoutingStats, + null, null ); } diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/CancellableTasksTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/CancellableTasksTests.java index ffd3a66ad1d48..e7026e9bc34cb 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/CancellableTasksTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/CancellableTasksTests.java @@ -33,6 +33,7 @@ import com.carrotsearch.randomizedtesting.RandomizedContext; import com.carrotsearch.randomizedtesting.generators.RandomNumbers; +import org.opensearch.OpenSearchException; import org.opensearch.action.ActionListener; import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksAction; import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; @@ -563,6 +564,19 @@ public void testNonExistingTaskCancellation() throws Exception { public void testCancelConcurrently() throws Exception { setupTestNodes(Settings.EMPTY); final TaskManager taskManager = testNodes[0].transportService.getTaskManager(); + AtomicBoolean onTaskCancelled = new AtomicBoolean(); + AtomicBoolean onTaskCompleted = new AtomicBoolean(); + taskManager.addTaskEventListeners(new TaskManager.TaskEventListeners() { + @Override + public void onTaskCancelled(CancellableTask task) { + onTaskCancelled.set(true); + } + + @Override + public void onTaskCompleted(Task task) { + onTaskCompleted.set(true); + } + }); int numTasks = randomIntBetween(1, 10); List tasks = new ArrayList<>(numTasks); for (int i = 0; i < numTasks; i++) { @@ -577,11 +591,13 @@ public void testCancelConcurrently() throws Exception { threads[i] = new Thread(() -> { phaser.arriveAndAwaitAdvance(); taskManager.cancel(cancellingTask, "test", () -> assertTrue(notified.compareAndSet(idx, 0, 1))); + assertTrue(onTaskCancelled.get()); }); threads[i].start(); } phaser.arriveAndAwaitAdvance(); taskManager.unregister(cancellingTask); + assertTrue(onTaskCompleted.get()); for (int i = 0; i < threads.length; i++) { threads[i].join(); assertThat(notified.get(i), equalTo(1)); @@ -591,6 +607,21 @@ public void testCancelConcurrently() throws Exception { assertTrue(called.get()); } + public void testCancelWithCancellationListenerThrowingException() { + setupTestNodes(Settings.EMPTY); + final TaskManager taskManager = testNodes[0].transportService.getTaskManager(); + taskManager.addTaskEventListeners(new TaskManager.TaskEventListeners() { + @Override + public void onTaskCancelled(CancellableTask task) { + throw new OpenSearchException("Exception"); + } + }); + CancellableTask cancellableTask = (CancellableTask) taskManager.register("type-0", "action-0", new CancellableNodeRequest()); + AtomicBoolean taskCompleted = new AtomicBoolean(); + assertThrows(OpenSearchException.class, () -> taskManager.cancel(cancellableTask, "test", () -> taskCompleted.set(true))); + assertFalse(taskCompleted.get()); + } + private static void debugDelay(String name) { // Introduce an additional pseudo random repeatable race conditions String delayName = RandomizedContext.current().getRunnerSeedAsString() + ":" + name; diff --git a/server/src/test/java/org/opensearch/cluster/DiskUsageTests.java b/server/src/test/java/org/opensearch/cluster/DiskUsageTests.java index 1fad9ad5086d8..73349d45bd5c7 100644 --- a/server/src/test/java/org/opensearch/cluster/DiskUsageTests.java +++ b/server/src/test/java/org/opensearch/cluster/DiskUsageTests.java @@ -188,6 +188,7 @@ public void testFillDiskUsage() { null, null, null, + null, null ), new NodeStats( @@ -212,6 +213,7 @@ public void testFillDiskUsage() { null, null, null, + null, null ), new NodeStats( @@ -236,6 +238,7 @@ public void testFillDiskUsage() { null, null, null, + null, null ) ); @@ -291,6 +294,7 @@ public void testFillDiskUsageSomeInvalidValues() { null, null, null, + null, null ), new NodeStats( @@ -315,6 +319,7 @@ public void testFillDiskUsageSomeInvalidValues() { null, null, null, + null, null ), new NodeStats( @@ -339,6 +344,7 @@ public void testFillDiskUsageSomeInvalidValues() { null, null, null, + null, null ) ); diff --git a/server/src/test/java/org/opensearch/tasks/SearchShardTaskCancellationStatsTests.java b/server/src/test/java/org/opensearch/tasks/SearchShardTaskCancellationStatsTests.java new file mode 100644 index 0000000000000..4bab365536a49 --- /dev/null +++ b/server/src/test/java/org/opensearch/tasks/SearchShardTaskCancellationStatsTests.java @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.test.AbstractWireSerializingTestCase; + +public class SearchShardTaskCancellationStatsTests extends AbstractWireSerializingTestCase { + @Override + protected Writeable.Reader instanceReader() { + return SearchShardTaskCancellationStats::new; + } + + @Override + protected SearchShardTaskCancellationStats createTestInstance() { + return randomInstance(); + } + + public static SearchShardTaskCancellationStats randomInstance() { + return new SearchShardTaskCancellationStats(randomNonNegativeLong(), randomNonNegativeLong()); + } +} diff --git a/server/src/test/java/org/opensearch/tasks/TaskCancellationMonitoringServiceTests.java b/server/src/test/java/org/opensearch/tasks/TaskCancellationMonitoringServiceTests.java new file mode 100644 index 0000000000000..e068e6ee6e319 --- /dev/null +++ b/server/src/test/java/org/opensearch/tasks/TaskCancellationMonitoringServiceTests.java @@ -0,0 +1,371 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.Version; +import org.opensearch.action.search.SearchShardTask; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.transport.MockTransportService; +import org.opensearch.threadpool.Scheduler; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportRequest; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Phaser; +import java.util.concurrent.TimeUnit; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.tasks.TaskCancellationMonitoringSettings.DURATION_MILLIS_SETTING; + +public class TaskCancellationMonitoringServiceTests extends OpenSearchTestCase { + + MockTransportService transportService; + TaskManager taskManager; + ThreadPool threadPool; + + @Before + public void setup() { + threadPool = new TestThreadPool(getClass().getName()); + transportService = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool); + transportService.start(); + transportService.acceptIncomingRequests(); + taskManager = transportService.getTaskManager(); + taskManager.setTaskCancellationService(new TaskCancellationService(transportService)); + } + + @After + public void cleanup() { + transportService.close(); + ThreadPool.terminate(threadPool, 5, TimeUnit.SECONDS); + } + + public void testWithNoCurrentRunningCancelledTasks() { + TaskCancellationMonitoringSettings settings = new TaskCancellationMonitoringSettings( + Settings.EMPTY, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + TaskManager mockTaskManager = mock(TaskManager.class); + TaskCancellationMonitoringService taskCancellationMonitoringService = new TaskCancellationMonitoringService( + threadPool, + mockTaskManager, + settings + ); + + taskCancellationMonitoringService.doRun(); + // Task manager should not be invoked. + verify(mockTaskManager, times(0)).getTasks(); + } + + public void testWithNonZeroCancelledSearchShardTasksRunning() throws InterruptedException { + Settings settings = Settings.builder() + .put(DURATION_MILLIS_SETTING.getKey(), 0) // Setting to zero for testing + .build(); + TaskCancellationMonitoringSettings taskCancellationMonitoringSettings = new TaskCancellationMonitoringSettings( + settings, + new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + + TaskCancellationMonitoringService taskCancellationMonitoringService = new TaskCancellationMonitoringService( + threadPool, + taskManager, + taskCancellationMonitoringSettings + ); + int numTasks = randomIntBetween(5, 50); + List tasks = createTasks(numTasks); + + int cancelFromIdx = randomIntBetween(0, numTasks - 1); + int cancelTillIdx = randomIntBetween(cancelFromIdx, numTasks - 1); + + int numberOfTasksCancelled = cancelTillIdx - cancelFromIdx + 1; + CountDownLatch countDownLatch = cancelTasksConcurrently(tasks, cancelFromIdx, cancelTillIdx); + + countDownLatch.await(); // Wait for all threads execution. + taskCancellationMonitoringService.doRun(); // 1st run to verify whether we are able to track running cancelled + // tasks. + TaskCancellationStats stats = taskCancellationMonitoringService.stats(); + assertEquals(numberOfTasksCancelled, stats.getSearchShardTaskCancellationStats().getCurrentLongRunningCancelledTaskCount()); + assertEquals(numberOfTasksCancelled, stats.getSearchShardTaskCancellationStats().getTotalLongRunningCancelledTaskCount()); + + taskCancellationMonitoringService.doRun(); // 2nd run. Verify same. + stats = taskCancellationMonitoringService.stats(); + assertEquals(numberOfTasksCancelled, stats.getSearchShardTaskCancellationStats().getCurrentLongRunningCancelledTaskCount()); + assertEquals(numberOfTasksCancelled, stats.getSearchShardTaskCancellationStats().getTotalLongRunningCancelledTaskCount()); + completeTasksConcurrently(tasks, 0, tasks.size() - 1).await(); + taskCancellationMonitoringService.doRun(); // 3rd run to verify current count is 0 and total remains the same. + stats = taskCancellationMonitoringService.stats(); + assertTrue(taskCancellationMonitoringService.getCancelledTaskTracker().isEmpty()); + assertEquals(0, stats.getSearchShardTaskCancellationStats().getCurrentLongRunningCancelledTaskCount()); + assertEquals(numberOfTasksCancelled, stats.getSearchShardTaskCancellationStats().getTotalLongRunningCancelledTaskCount()); + } + + public void testShouldRunGetsDisabledAfterTaskCompletion() throws InterruptedException { + Settings settings = Settings.builder() + .put(DURATION_MILLIS_SETTING.getKey(), 0) // Setting to zero for testing + .build(); + TaskCancellationMonitoringSettings taskCancellationMonitoringSettings = new TaskCancellationMonitoringSettings( + settings, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + TaskCancellationMonitoringService taskCancellationMonitoringService = new TaskCancellationMonitoringService( + threadPool, + taskManager, + taskCancellationMonitoringSettings + ); + assertTrue(taskCancellationMonitoringService.getCancelledTaskTracker().isEmpty()); + assertEquals(0, taskCancellationMonitoringService.getCancelledTaskTracker().size()); + + // Start few tasks. + int numTasks = randomIntBetween(5, 50); + List tasks = createTasks(numTasks); + + taskCancellationMonitoringService.doRun(); + TaskCancellationStats stats = taskCancellationMonitoringService.stats(); + // verify no cancelled tasks currently being recorded + assertEquals(0, stats.getSearchShardTaskCancellationStats().getCurrentLongRunningCancelledTaskCount()); + assertEquals(0, stats.getSearchShardTaskCancellationStats().getTotalLongRunningCancelledTaskCount()); + cancelTasksConcurrently(tasks, 0, tasks.size() - 1).await(); + taskCancellationMonitoringService.doRun(); + stats = taskCancellationMonitoringService.stats(); + assertFalse(taskCancellationMonitoringService.getCancelledTaskTracker().isEmpty()); + assertEquals(numTasks, stats.getSearchShardTaskCancellationStats().getCurrentLongRunningCancelledTaskCount()); + assertEquals(numTasks, stats.getSearchShardTaskCancellationStats().getTotalLongRunningCancelledTaskCount()); + + completeTasksConcurrently(tasks, 0, tasks.size() - 1).await(); + stats = taskCancellationMonitoringService.stats(); + assertTrue(taskCancellationMonitoringService.getCancelledTaskTracker().isEmpty()); + assertEquals(0, stats.getSearchShardTaskCancellationStats().getCurrentLongRunningCancelledTaskCount()); + assertEquals(numTasks, stats.getSearchShardTaskCancellationStats().getTotalLongRunningCancelledTaskCount()); + } + + public void testWithVaryingCancelledTasksDuration() throws InterruptedException { + long cancelledTaskDurationThresholdMilis = 2000; + Settings settings = Settings.builder() + .put(DURATION_MILLIS_SETTING.getKey(), cancelledTaskDurationThresholdMilis) // Setting to one for testing + .build(); + TaskCancellationMonitoringSettings taskCancellationMonitoringSettings = new TaskCancellationMonitoringSettings( + settings, + new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + + TaskCancellationMonitoringService taskCancellationMonitoringService = new TaskCancellationMonitoringService( + threadPool, + taskManager, + taskCancellationMonitoringSettings + ); + + int numTasks = randomIntBetween(5, 50); + List tasks = createTasks(numTasks); + + int numTasksToBeCancelledInFirstIteration = randomIntBetween(1, numTasks - 1); + CountDownLatch countDownLatch = cancelTasksConcurrently(tasks, 0, numTasksToBeCancelledInFirstIteration - 1); + countDownLatch.await(); // Wait for all tasks to be cancelled in first iteration + + Thread.sleep(cancelledTaskDurationThresholdMilis); // Sleep, so we later verify whether above tasks are being + // captured as part of stats. + + taskCancellationMonitoringService.doRun(); + TaskCancellationStats stats = taskCancellationMonitoringService.stats(); + // Verify only tasks that were cancelled as part of first iteration is being captured as part of stats as + // they have been running longer as per threshold. + assertEquals( + numTasksToBeCancelledInFirstIteration, + stats.getSearchShardTaskCancellationStats().getCurrentLongRunningCancelledTaskCount() + ); + assertEquals( + numTasksToBeCancelledInFirstIteration, + stats.getSearchShardTaskCancellationStats().getTotalLongRunningCancelledTaskCount() + ); + + countDownLatch = cancelTasksConcurrently(tasks, numTasksToBeCancelledInFirstIteration, numTasks - 1); + countDownLatch.await(); // Wait for rest of tasks to be cancelled. + + Thread.sleep(cancelledTaskDurationThresholdMilis); // Sleep again, so we now verify whether all tasks are + // being captured as part of stats. + taskCancellationMonitoringService.doRun(); + stats = taskCancellationMonitoringService.stats(); + assertEquals(numTasks, stats.getSearchShardTaskCancellationStats().getCurrentLongRunningCancelledTaskCount()); + assertEquals(numTasks, stats.getSearchShardTaskCancellationStats().getTotalLongRunningCancelledTaskCount()); + + completeTasksConcurrently(tasks, 0, tasks.size() - 1).await(); + taskCancellationMonitoringService.doRun(); + stats = taskCancellationMonitoringService.stats(); + // Verify no current running tasks + assertEquals(0, stats.getSearchShardTaskCancellationStats().getCurrentLongRunningCancelledTaskCount()); + assertEquals(numTasks, stats.getSearchShardTaskCancellationStats().getTotalLongRunningCancelledTaskCount()); + } + + public void testTasksAreGettingEvictedCorrectlyAfterCompletion() throws InterruptedException { + Settings settings = Settings.builder() + .put(DURATION_MILLIS_SETTING.getKey(), 0) // Setting to one for testing + .build(); + TaskCancellationMonitoringSettings taskCancellationMonitoringSettings = new TaskCancellationMonitoringSettings( + settings, + new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + + TaskCancellationMonitoringService taskCancellationMonitoringService = new TaskCancellationMonitoringService( + threadPool, + taskManager, + taskCancellationMonitoringSettings + ); + + // Start few tasks. + int numTasks = randomIntBetween(5, 50); + List tasks = createTasks(numTasks); + assertTrue(taskCancellationMonitoringService.getCancelledTaskTracker().isEmpty()); + int numTasksToBeCancelledInFirstIteration = randomIntBetween(2, numTasks - 1); + CountDownLatch countDownLatch = cancelTasksConcurrently(tasks, 0, numTasksToBeCancelledInFirstIteration - 1); + countDownLatch.await(); // Wait for all tasks to be cancelled in first iteration + + assertEquals(numTasksToBeCancelledInFirstIteration, taskCancellationMonitoringService.getCancelledTaskTracker().size()); + // Verify desired task ids are present. + for (int itr = 0; itr < numTasksToBeCancelledInFirstIteration; itr++) { + assertTrue(taskCancellationMonitoringService.getCancelledTaskTracker().containsKey(tasks.get(itr).getId())); + } + // Cancel rest of the tasks + cancelTasksConcurrently(tasks, numTasksToBeCancelledInFirstIteration, numTasks - 1).await(); + for (int itr = 0; itr < tasks.size(); itr++) { + assertTrue(taskCancellationMonitoringService.getCancelledTaskTracker().containsKey(tasks.get(itr).getId())); + } + // Complete one task to start with. + completeTasksConcurrently(tasks, 0, 0).await(); + assertFalse(taskCancellationMonitoringService.getCancelledTaskTracker().containsKey(tasks.get(0).getId())); + // Verify rest of the tasks are still present in tracker + for (int itr = 1; itr < tasks.size(); itr++) { + assertTrue(taskCancellationMonitoringService.getCancelledTaskTracker().containsKey(tasks.get(itr).getId())); + } + // Complete first iteration tasks + completeTasksConcurrently(tasks, 1, numTasksToBeCancelledInFirstIteration - 1).await(); + // Verify desired tasks were evicted from tracker map + for (int itr = 0; itr < numTasksToBeCancelledInFirstIteration; itr++) { + assertFalse(taskCancellationMonitoringService.getCancelledTaskTracker().containsKey(tasks.get(0).getId())); + } + // Verify rest of the tasks are still present in tracker + for (int itr = numTasksToBeCancelledInFirstIteration; itr < tasks.size(); itr++) { + assertTrue(taskCancellationMonitoringService.getCancelledTaskTracker().containsKey(tasks.get(itr).getId())); + } + // Complete all of them finally + completeTasksConcurrently(tasks, numTasksToBeCancelledInFirstIteration, tasks.size() - 1).await(); + assertTrue(taskCancellationMonitoringService.getCancelledTaskTracker().isEmpty()); + for (int itr = 0; itr < tasks.size(); itr++) { + assertFalse(taskCancellationMonitoringService.getCancelledTaskTracker().containsKey(tasks.get(itr).getId())); + } + } + + public void testDoStartAndStop() { + TaskCancellationMonitoringSettings settings = new TaskCancellationMonitoringSettings( + Settings.EMPTY, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + ThreadPool mockThreadPool = mock(ThreadPool.class); + Scheduler.Cancellable scheduleFuture = mock(Scheduler.Cancellable.class); + when(scheduleFuture.cancel()).thenReturn(true); + when(mockThreadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(scheduleFuture); + TaskCancellationMonitoringService taskCancellationMonitoringService = new TaskCancellationMonitoringService( + mockThreadPool, + taskManager, + settings + ); + + taskCancellationMonitoringService.doStart(); + taskCancellationMonitoringService.doStop(); + verify(scheduleFuture, times(1)).cancel(); + } + + private List createTasks(int numTasks) { + List tasks = new ArrayList<>(numTasks); + for (int i = 0; i < numTasks; i++) { + tasks.add((SearchShardTask) taskManager.register("type-" + i, "action-" + i, new MockQuerySearchRequest())); + } + return tasks; + } + + // Caller can this with the list of tasks specifically mentioning which ones to cancel. And can call CountDownLatch + // .await() to wait for all tasks be cancelled. + private CountDownLatch cancelTasksConcurrently(List tasks, int cancelFromIdx, int cancelTillIdx) { + assert cancelFromIdx >= 0; + assert cancelTillIdx <= tasks.size() - 1; + assert cancelTillIdx >= cancelFromIdx; + int totalTasksToBeCancelled = cancelTillIdx - cancelFromIdx + 1; + Thread[] threads = new Thread[totalTasksToBeCancelled]; + Phaser phaser = new Phaser(totalTasksToBeCancelled + 1); // Used to concurrently cancel tasks by multiple threads. + CountDownLatch countDownLatch = new CountDownLatch(totalTasksToBeCancelled); // To wait for all threads to finish. + for (int i = 0; i < totalTasksToBeCancelled; i++) { + int idx = i + cancelFromIdx; + threads[i] = new Thread(() -> { + phaser.arriveAndAwaitAdvance(); + taskManager.cancel(tasks.get(idx), "test", () -> {}); + countDownLatch.countDown(); + }); + threads[i].start(); + } + phaser.arriveAndAwaitAdvance(); + return countDownLatch; + } + + private CountDownLatch completeTasksConcurrently(List tasks, int completeFromIdx, int completeTillIdx) { + assert completeFromIdx >= 0; + assert completeTillIdx <= tasks.size() - 1; + assert completeTillIdx >= completeFromIdx; + int totalTasksToBeCompleted = completeTillIdx - completeFromIdx + 1; + Thread[] threads = new Thread[totalTasksToBeCompleted]; + Phaser phaser = new Phaser(totalTasksToBeCompleted + 1); + CountDownLatch countDownLatch = new CountDownLatch(totalTasksToBeCompleted); + for (int i = 0; i < totalTasksToBeCompleted; i++) { + int idx = i + completeFromIdx; + threads[i] = new Thread(() -> { + phaser.arriveAndAwaitAdvance(); + taskManager.unregister(tasks.get(idx)); + countDownLatch.countDown(); + }); + threads[i].start(); + } + phaser.arriveAndAwaitAdvance(); + return countDownLatch; + } + + public static class MockQuerySearchRequest extends TransportRequest { + protected String requestName; + + public MockQuerySearchRequest() { + super(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(requestName); + } + + @Override + public String getDescription() { + return "MockQuerySearchRequest[" + requestName + "]"; + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new SearchShardTask(id, type, action, getDescription(), parentTaskId, headers); + } + } + +} diff --git a/server/src/test/java/org/opensearch/tasks/TaskCancellationMonitoringSettingsTests.java b/server/src/test/java/org/opensearch/tasks/TaskCancellationMonitoringSettingsTests.java new file mode 100644 index 0000000000000..245f410b2609a --- /dev/null +++ b/server/src/test/java/org/opensearch/tasks/TaskCancellationMonitoringSettingsTests.java @@ -0,0 +1,39 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.OpenSearchTestCase; + +public class TaskCancellationMonitoringSettingsTests extends OpenSearchTestCase { + + public void testDefaults() { + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + TaskCancellationMonitoringSettings settings = new TaskCancellationMonitoringSettings(Settings.EMPTY, clusterSettings); + assertEquals(TaskCancellationMonitoringSettings.DURATION_MILLIS_SETTING_DEFAULT_VALUE, settings.getDuration().millis()); + assertEquals(TaskCancellationMonitoringSettings.INTERVAL_MILLIS_SETTING_DEFAULT_VALUE, settings.getInterval().millis()); + assertEquals(TaskCancellationMonitoringSettings.IS_ENABLED_SETTING_DEFAULT_VALUE, settings.isEnabled()); + } + + public void testUpdate() { + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + TaskCancellationMonitoringSettings settings = new TaskCancellationMonitoringSettings(Settings.EMPTY, clusterSettings); + + Settings newSettings = Settings.builder() + .put(TaskCancellationMonitoringSettings.DURATION_MILLIS_SETTING.getKey(), 20000) + .put(TaskCancellationMonitoringSettings.INTERVAL_MILLIS_SETTING.getKey(), 2000) + .put(TaskCancellationMonitoringSettings.IS_ENABLED_SETTING.getKey(), false) + .build(); + clusterSettings.applySettings(newSettings); + assertEquals(20000, settings.getDuration().millis()); + assertFalse(settings.isEnabled()); + assertNotEquals(2000, settings.getInterval().millis()); + } +} diff --git a/server/src/test/java/org/opensearch/tasks/TaskCancellationStatsTests.java b/server/src/test/java/org/opensearch/tasks/TaskCancellationStatsTests.java new file mode 100644 index 0000000000000..a81110b59e98a --- /dev/null +++ b/server/src/test/java/org/opensearch/tasks/TaskCancellationStatsTests.java @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.test.AbstractWireSerializingTestCase; + +public class TaskCancellationStatsTests extends AbstractWireSerializingTestCase { + @Override + protected Writeable.Reader instanceReader() { + return TaskCancellationStats::new; + } + + @Override + protected TaskCancellationStats createTestInstance() { + return randomInstance(); + } + + public static TaskCancellationStats randomInstance() { + return new TaskCancellationStats(SearchShardTaskCancellationStatsTests.randomInstance()); + } +} diff --git a/test/framework/src/main/java/org/opensearch/cluster/MockInternalClusterInfoService.java b/test/framework/src/main/java/org/opensearch/cluster/MockInternalClusterInfoService.java index 007e717149a62..cf5f6613c3ea1 100644 --- a/test/framework/src/main/java/org/opensearch/cluster/MockInternalClusterInfoService.java +++ b/test/framework/src/main/java/org/opensearch/cluster/MockInternalClusterInfoService.java @@ -118,7 +118,8 @@ List adjustNodesStats(List nodesStats) { nodeStats.getSearchBackpressureStats(), nodeStats.getClusterManagerThrottlingStats(), nodeStats.getWeightedRoutingStats(), - nodeStats.getFileCacheStats() + nodeStats.getFileCacheStats(), + nodeStats.getTaskCancellationStats() ); }).collect(Collectors.toList()); } diff --git a/test/framework/src/main/java/org/opensearch/test/InternalTestCluster.java b/test/framework/src/main/java/org/opensearch/test/InternalTestCluster.java index 3faf13c373720..a3612167f16c3 100644 --- a/test/framework/src/main/java/org/opensearch/test/InternalTestCluster.java +++ b/test/framework/src/main/java/org/opensearch/test/InternalTestCluster.java @@ -2684,6 +2684,7 @@ public void ensureEstimatedStats() { false, false, false, + false, false ); assertThat(