From d5a900c74bcf73a8796e13849227ebf8a67b539b Mon Sep 17 00:00:00 2001 From: David Turner Date: Thu, 17 Feb 2022 08:42:24 +0000 Subject: [PATCH] Preserve context in ResultDeduplicator (#84038) (#84085) Today the `ResultDeduplicator` may complete a collection of listeners in contexts different from the ones in which they were submitted. This commit makes sure that the context is preserved in the listener. --- docs/changelog/84038.yaml | 6 ++++ .../action/ResultDeduplicator.java | 10 +++++- .../action/shard/ShardStateAction.java | 3 +- .../blobstore/BlobStoreRepository.java | 3 +- .../snapshots/SnapshotShardsService.java | 4 +-- .../tasks/TaskCancellationService.java | 3 +- .../elasticsearch/tasks/TaskManagerTests.java | 13 +++++-- .../transport/ResultDeduplicatorTests.java | 35 ++++++++++++------- 8 files changed, 56 insertions(+), 21 deletions(-) create mode 100644 docs/changelog/84038.yaml diff --git a/docs/changelog/84038.yaml b/docs/changelog/84038.yaml new file mode 100644 index 0000000000000..c4f07f6d3aefa --- /dev/null +++ b/docs/changelog/84038.yaml @@ -0,0 +1,6 @@ +pr: 84038 +summary: Preserve context in `ResultDeduplicator` +area: Infra/Core +type: bug +issues: + - 84036 diff --git a/server/src/main/java/org/elasticsearch/action/ResultDeduplicator.java b/server/src/main/java/org/elasticsearch/action/ResultDeduplicator.java index 8f3e7ee60b242..b63eeaf64e505 100644 --- a/server/src/main/java/org/elasticsearch/action/ResultDeduplicator.java +++ b/server/src/main/java/org/elasticsearch/action/ResultDeduplicator.java @@ -8,7 +8,9 @@ package org.elasticsearch.action; +import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.common.util.concurrent.ThreadContext; import java.util.ArrayList; import java.util.List; @@ -22,8 +24,13 @@ */ public final class ResultDeduplicator { + private final ThreadContext threadContext; private final ConcurrentMap requests = ConcurrentCollections.newConcurrentMap(); + public ResultDeduplicator(ThreadContext threadContext) { + this.threadContext = threadContext; + } + /** * Ensures a given request not executed multiple times when another equal request is already in-flight. * If the request is not yet known to the deduplicator it will invoke the passed callback with an {@link ActionListener} @@ -35,7 +42,8 @@ public final class ResultDeduplicator { * @param callback Callback to be invoked with request and completion listener the first time the request is added to the deduplicator */ public void executeOnce(T request, ActionListener listener, BiConsumer> callback) { - ActionListener completionListener = requests.computeIfAbsent(request, CompositeListener::new).addListener(listener); + ActionListener completionListener = requests.computeIfAbsent(request, CompositeListener::new) + .addListener(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext)); if (completionListener != null) { callback.accept(request, completionListener); } diff --git a/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java b/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java index ac92b489ebb48..9f3152863f17a 100644 --- a/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java +++ b/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java @@ -81,7 +81,7 @@ public class ShardStateAction { private final ThreadPool threadPool; // we deduplicate these shard state requests in order to avoid sending duplicate failed/started shard requests for a shard - private final ResultDeduplicator remoteShardStateUpdateDeduplicator = new ResultDeduplicator<>(); + private final ResultDeduplicator remoteShardStateUpdateDeduplicator; @Inject public ShardStateAction( @@ -94,6 +94,7 @@ public ShardStateAction( this.transportService = transportService; this.clusterService = clusterService; this.threadPool = threadPool; + this.remoteShardStateUpdateDeduplicator = new ResultDeduplicator<>(threadPool.getThreadContext()); transportService.registerRequestHandler( SHARD_STARTED_ACTION_NAME, diff --git a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java index 28e1897a0272d..b80a0124bc5d9 100644 --- a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java +++ b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java @@ -403,6 +403,7 @@ protected BlobStoreRepository( this.namedXContentRegistry = namedXContentRegistry; this.basePath = basePath; this.maxSnapshotCount = MAX_SNAPSHOTS_SETTING.get(metadata.settings()); + this.repoDataDeduplicator = new ResultDeduplicator<>(threadPool.getThreadContext()); } @Override @@ -1866,7 +1867,7 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState) * {@link #bestEffortConsistency} must be {@code false}, in which case we can assume that the {@link RepositoryData} loaded is * unique for a given value of {@link #metadata} at any point in time. */ - private final ResultDeduplicator repoDataDeduplicator = new ResultDeduplicator<>(); + private final ResultDeduplicator repoDataDeduplicator; private void doGetRepositoryData(ActionListener listener) { // Retry loading RepositoryData in a loop in case we run into concurrent modifications of the repository. diff --git a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java index 646df885cb48c..4223a4239c3a3 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java @@ -82,8 +82,7 @@ public class SnapshotShardsService extends AbstractLifecycleComponent implements private final Map> shardSnapshots = new HashMap<>(); // A map of snapshots to the shardIds that we already reported to the master as failed - private final ResultDeduplicator remoteFailedRequestDeduplicator = - new ResultDeduplicator<>(); + private final ResultDeduplicator remoteFailedRequestDeduplicator; public SnapshotShardsService( Settings settings, @@ -97,6 +96,7 @@ public SnapshotShardsService( this.transportService = transportService; this.clusterService = clusterService; this.threadPool = transportService.getThreadPool(); + this.remoteFailedRequestDeduplicator = new ResultDeduplicator<>(threadPool.getThreadContext()); if (DiscoveryNode.canContainData(settings)) { // this is only useful on the nodes that can hold data clusterService.addListener(this); diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java index bd6078ec558e5..cd5bbd56a315a 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java @@ -44,11 +44,12 @@ public class TaskCancellationService { private static final Logger logger = LogManager.getLogger(TaskCancellationService.class); private final TransportService transportService; private final TaskManager taskManager; - private final ResultDeduplicator deduplicator = new ResultDeduplicator<>(); + private final ResultDeduplicator deduplicator; public TaskCancellationService(TransportService transportService) { this.transportService = transportService; this.taskManager = transportService.getTaskManager(); + this.deduplicator = new ResultDeduplicator<>(transportService.getThreadPool().getThreadContext()); transportService.registerRequestHandler( BAN_PARENT_ACTION_NAME, ThreadPool.Names.SAME, diff --git a/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java b/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java index 9e8fc5c8983a6..6e40e9434141e 100644 --- a/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java @@ -46,6 +46,7 @@ import static org.hamcrest.Matchers.everyItem; import static org.hamcrest.Matchers.in; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class TaskManagerTests extends ESTestCase { private ThreadPool threadPool; @@ -76,7 +77,9 @@ public void testResultsServiceRetryTotalTime() { public void testTrackingChannelTask() throws Exception { final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of()); Set cancelledTasks = ConcurrentCollections.newConcurrentSet(); - taskManager.setTaskCancellationService(new TaskCancellationService(mock(TransportService.class)) { + final var transportServiceMock = mock(TransportService.class); + when(transportServiceMock.getThreadPool()).thenReturn(threadPool); + taskManager.setTaskCancellationService(new TaskCancellationService(transportServiceMock) { @Override void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener listener) { assertThat(reason, equalTo("channel was closed")); @@ -124,7 +127,9 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF public void testTrackingTaskAndCloseChannelConcurrently() throws Exception { final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of()); Set cancelledTasks = ConcurrentCollections.newConcurrentSet(); - taskManager.setTaskCancellationService(new TaskCancellationService(mock(TransportService.class)) { + final var transportServiceMock = mock(TransportService.class); + when(transportServiceMock.getThreadPool()).thenReturn(threadPool); + taskManager.setTaskCancellationService(new TaskCancellationService(transportServiceMock) { @Override void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener listener) { assertTrue("task [" + task + "] was cancelled already", cancelledTasks.add(task)); @@ -180,7 +185,9 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF public void testRemoveBansOnChannelDisconnects() throws Exception { final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of()); - taskManager.setTaskCancellationService(new TaskCancellationService(mock(TransportService.class)) { + final var transportServiceMock = mock(TransportService.class); + when(transportServiceMock.getThreadPool()).thenReturn(threadPool); + taskManager.setTaskCancellationService(new TaskCancellationService(transportServiceMock) { @Override void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener listener) {} }); diff --git a/server/src/test/java/org/elasticsearch/transport/ResultDeduplicatorTests.java b/server/src/test/java/org/elasticsearch/transport/ResultDeduplicatorTests.java index 2bdfa3cc7865c..2d9fa940d5d5a 100644 --- a/server/src/test/java/org/elasticsearch/transport/ResultDeduplicatorTests.java +++ b/server/src/test/java/org/elasticsearch/transport/ResultDeduplicatorTests.java @@ -10,6 +10,8 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ResultDeduplicator; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; @@ -29,8 +31,11 @@ public void testRequestDeduplication() throws Exception { @Override public void setParentTask(final TaskId taskId) {} }; - final ResultDeduplicator deduplicator = new ResultDeduplicator<>(); + final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + final ResultDeduplicator deduplicator = new ResultDeduplicator<>(threadContext); final SetOnce> listenerHolder = new SetOnce<>(); + final var headerName = "thread-context-header"; + final var headerGenerator = new AtomicInteger(); int iterationsPerThread = scaledRandomIntBetween(100, 1000); Thread[] threads = new Thread[between(1, 4)]; Phaser barrier = new Phaser(threads.length + 1); @@ -38,18 +43,24 @@ public void setParentTask(final TaskId taskId) {} threads[i] = new Thread(() -> { barrier.arriveAndAwaitAdvance(); for (int n = 0; n < iterationsPerThread; n++) { - deduplicator.executeOnce(request, new ActionListener() { - @Override - public void onResponse(Void aVoid) { - successCount.incrementAndGet(); - } + final var headerValue = Integer.toString(headerGenerator.incrementAndGet()); + try (var ignored = threadContext.stashContext()) { + threadContext.putHeader(headerName, headerValue); + deduplicator.executeOnce(request, new ActionListener<>() { + @Override + public void onResponse(Void aVoid) { + assertThat(threadContext.getHeader(headerName), equalTo(headerValue)); + successCount.incrementAndGet(); + } - @Override - public void onFailure(Exception e) { - assertThat(e, sameInstance(failure)); - failureCount.incrementAndGet(); - } - }, (req, reqListener) -> listenerHolder.set(reqListener)); + @Override + public void onFailure(Exception e) { + assertThat(threadContext.getHeader(headerName), equalTo(headerValue)); + assertThat(e, sameInstance(failure)); + failureCount.incrementAndGet(); + } + }, (req, reqListener) -> listenerHolder.set(reqListener)); + } } }); threads[i].start();