From 7d11e4163bb5a167b774c16c127711ca828d76f5 Mon Sep 17 00:00:00 2001
From: Artem Prigoda
Date: Wed, 5 Jul 2023 11:00:52 +0200
Subject: [PATCH] [7.17] Preserve context in ResultDeduplicator (#84038) (#96868)

Today the `ResultDeduplicator` may complete a collection of listeners in
contexts different from the ones in which they were submitted. This
commit makes sure that the context is preserved in the listener.

Co-authored-by: David Turner
---
 docs/changelog/84038.yaml                      |  6 ++++
 .../action/ResultDeduplicator.java             | 10 +++++-
 .../action/shard/ShardStateAction.java         |  3 +-
 .../snapshots/SnapshotShardsService.java       |  4 +--
 .../tasks/TaskCancellationService.java         |  3 +-
 .../elasticsearch/tasks/TaskManagerTests.java  | 13 +++++--
 .../transport/ResultDeduplicatorTests.java     | 35 ++++++++++++-------
 7 files changed, 54 insertions(+), 20 deletions(-)
 create mode 100644 docs/changelog/84038.yaml

diff --git a/docs/changelog/84038.yaml b/docs/changelog/84038.yaml
new file mode 100644
index 0000000000000..c4f07f6d3aefa
--- /dev/null
+++ b/docs/changelog/84038.yaml
@@ -0,0 +1,6 @@
+pr: 84038
+summary: Preserve context in `ResultDeduplicator`
+area: Infra/Core
+type: bug
+issues:
+ - 84036
diff --git a/server/src/main/java/org/elasticsearch/action/ResultDeduplicator.java b/server/src/main/java/org/elasticsearch/action/ResultDeduplicator.java
index 8f3e7ee60b242..b63eeaf64e505 100644
--- a/server/src/main/java/org/elasticsearch/action/ResultDeduplicator.java
+++ b/server/src/main/java/org/elasticsearch/action/ResultDeduplicator.java
@@ -8,7 +8,9 @@
 
 package org.elasticsearch.action;
 
+import org.elasticsearch.action.support.ContextPreservingActionListener;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -22,8 +24,13 @@
  */
 public final class ResultDeduplicator {
 
+    private final ThreadContext threadContext;
     private final ConcurrentMap requests = ConcurrentCollections.newConcurrentMap();
 
+    public ResultDeduplicator(ThreadContext threadContext) {
+        this.threadContext = threadContext;
+    }
+
     /**
      * Ensures a given request not executed multiple times when another equal request is already in-flight.
     * If the request is not yet known to the deduplicator it will invoke the passed callback with an {@link ActionListener}
@@ -35,7 +42,8 @@ public final class ResultDeduplicator {
      * @param callback Callback to be invoked with request and completion listener the first time the request is added to the deduplicator
      */
     public void executeOnce(T request, ActionListener listener, BiConsumer> callback) {
-        ActionListener completionListener = requests.computeIfAbsent(request, CompositeListener::new).addListener(listener);
+        ActionListener completionListener = requests.computeIfAbsent(request, CompositeListener::new)
+            .addListener(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext));
         if (completionListener != null) {
             callback.accept(request, completionListener);
         }
diff --git a/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java b/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java
index 8c6b8cc0eb2a8..99450798b0792 100644
--- a/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java
+++ b/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java
@@ -118,7 +118,7 @@ private static Priority parseReroutePriority(String priorityString) {
 
     // a list of shards that failed during replication
     // we keep track of these shards in order to avoid sending duplicate failed shard requests for a single failing shard.
-    private final ResultDeduplicator remoteFailedShardsDeduplicator = new ResultDeduplicator<>();
+    private final ResultDeduplicator remoteFailedShardsDeduplicator;
 
     @Inject
     public ShardStateAction(
@@ -131,6 +131,7 @@ public ShardStateAction(
         this.transportService = transportService;
         this.clusterService = clusterService;
         this.threadPool = threadPool;
+        remoteFailedShardsDeduplicator = new ResultDeduplicator<>(threadPool.getThreadContext());
         followUpRerouteTaskPriority = FOLLOW_UP_REROUTE_PRIORITY_SETTING.get(clusterService.getSettings());
 
         clusterService.getClusterSettings()
diff --git a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java
index 2cb2d231f3597..58e2525ecb9f4 100644
--- a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java
+++ b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java
@@ -85,8 +85,7 @@ public class SnapshotShardsService extends AbstractLifecycleComponent implements
     private final Map> shardSnapshots = new HashMap<>();
 
     // A map of snapshots to the shardIds that we already reported to the master as failed
-    private final ResultDeduplicator remoteFailedRequestDeduplicator =
-        new ResultDeduplicator<>();
+    private final ResultDeduplicator remoteFailedRequestDeduplicator;
 
     public SnapshotShardsService(
         Settings settings,
@@ -100,6 +99,7 @@ public SnapshotShardsService(
         this.transportService = transportService;
         this.clusterService = clusterService;
         this.threadPool = transportService.getThreadPool();
+        this.remoteFailedRequestDeduplicator = new ResultDeduplicator<>(threadPool.getThreadContext());
         if (DiscoveryNode.canContainData(settings)) {
             // this is only useful on the nodes that can hold data
             clusterService.addListener(this);
diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java
index b7eafc7b044dd..f2e9ca821ae7e 100644
--- a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java
+++ b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java
@@ -44,11 +44,12 @@ public class TaskCancellationService {
     private static final Logger logger = LogManager.getLogger(TaskCancellationService.class);
     private final TransportService transportService;
    private final TaskManager taskManager;
-    private final ResultDeduplicator deduplicator = new ResultDeduplicator<>();
+    private final ResultDeduplicator deduplicator;
 
     public TaskCancellationService(TransportService transportService) {
         this.transportService = transportService;
         this.taskManager = transportService.getTaskManager();
+        this.deduplicator = new ResultDeduplicator<>(transportService.getThreadPool().getThreadContext());
         transportService.registerRequestHandler(
             BAN_PARENT_ACTION_NAME,
             ThreadPool.Names.SAME,
diff --git a/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java b/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java
index 58d998f1e0b0e..9d47e61fc1268 100644
--- a/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java
+++ b/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java
@@ -46,6 +46,7 @@
 import static org.hamcrest.Matchers.everyItem;
 import static org.hamcrest.Matchers.in;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class TaskManagerTests extends ESTestCase {
     private ThreadPool threadPool;
@@ -76,7 +77,9 @@ public void testResultsServiceRetryTotalTime() {
     public void testTrackingChannelTask() throws Exception {
         final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
         Set cancelledTasks = ConcurrentCollections.newConcurrentSet();
-        taskManager.setTaskCancellationService(new TaskCancellationService(mock(TransportService.class)) {
+        final TransportService transportServiceMock = mock(TransportService.class);
+        when(transportServiceMock.getThreadPool()).thenReturn(threadPool);
+        taskManager.setTaskCancellationService(new TaskCancellationService(transportServiceMock) {
             @Override
             void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener listener) {
                 assertThat(reason, equalTo("channel was closed"));
@@ -124,7 +127,9 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF
     public void testTrackingTaskAndCloseChannelConcurrently() throws Exception {
         final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
         Set cancelledTasks = ConcurrentCollections.newConcurrentSet();
-        taskManager.setTaskCancellationService(new TaskCancellationService(mock(TransportService.class)) {
+        final TransportService transportServiceMock = mock(TransportService.class);
+        when(transportServiceMock.getThreadPool()).thenReturn(threadPool);
+        taskManager.setTaskCancellationService(new TaskCancellationService(transportServiceMock) {
             @Override
             void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener listener) {
                 assertTrue("task [" + task + "] was cancelled already", cancelledTasks.add(task));
@@ -180,7 +185,9 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF
 
     public void testRemoveBansOnChannelDisconnects() throws Exception {
         final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
-        taskManager.setTaskCancellationService(new TaskCancellationService(mock(TransportService.class)) {
+        final TransportService transportServiceMock = mock(TransportService.class);
+        when(transportServiceMock.getThreadPool()).thenReturn(threadPool);
+        taskManager.setTaskCancellationService(new TaskCancellationService(transportServiceMock) {
             @Override
             void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener listener) {}
         });
diff --git a/server/src/test/java/org/elasticsearch/transport/ResultDeduplicatorTests.java b/server/src/test/java/org/elasticsearch/transport/ResultDeduplicatorTests.java
index 2bdfa3cc7865c..479fff73d1152 100644
--- a/server/src/test/java/org/elasticsearch/transport/ResultDeduplicatorTests.java
+++ b/server/src/test/java/org/elasticsearch/transport/ResultDeduplicatorTests.java
@@ -10,6 +10,8 @@
 import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ResultDeduplicator;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.ESTestCase;
 
@@ -29,8 +31,11 @@ public void testRequestDeduplication() throws Exception {
             @Override
             public void setParentTask(final TaskId taskId) {}
         };
-        final ResultDeduplicator deduplicator = new ResultDeduplicator<>();
+        final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
+        final ResultDeduplicator deduplicator = new ResultDeduplicator<>(threadContext);
         final SetOnce> listenerHolder = new SetOnce<>();
+        final String headerName = "thread-context-header";
+        final AtomicInteger headerGenerator = new AtomicInteger();
         int iterationsPerThread = scaledRandomIntBetween(100, 1000);
         Thread[] threads = new Thread[between(1, 4)];
         Phaser barrier = new Phaser(threads.length + 1);
@@ -38,18 +43,24 @@ public void setParentTask(final TaskId taskId) {}
             threads[i] = new Thread(() -> {
                 barrier.arriveAndAwaitAdvance();
                 for (int n = 0; n < iterationsPerThread; n++) {
-                    deduplicator.executeOnce(request, new ActionListener() {
-                        @Override
-                        public void onResponse(Void aVoid) {
-                            successCount.incrementAndGet();
-                        }
+                    final String headerValue = Integer.toString(headerGenerator.incrementAndGet());
+                    try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
+                        threadContext.putHeader(headerName, headerValue);
+                        deduplicator.executeOnce(request, new ActionListener() {
+                            @Override
+                            public void onResponse(Void aVoid) {
+                                assertThat(threadContext.getHeader(headerName), equalTo(headerValue));
+                                successCount.incrementAndGet();
+                            }
 
-                        @Override
-                        public void onFailure(Exception e) {
-                            assertThat(e, sameInstance(failure));
-                            failureCount.incrementAndGet();
-                        }
-                    }, (req, reqListener) -> listenerHolder.set(reqListener));
+                            @Override
+                            public void onFailure(Exception e) {
+                                assertThat(threadContext.getHeader(headerName), equalTo(headerValue));
+                                assertThat(e, sameInstance(failure));
+                                failureCount.incrementAndGet();
+                            }
+                        }, (req, reqListener) -> listenerHolder.set(reqListener));
+                    }
                 }
             });
             threads[i].start();
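
For readers without the surrounding Elasticsearch sources, here is a minimal, self-contained sketch of the pattern this patch applies: each listener is wrapped at registration time so it captures the submitting thread's context, and that context is restored around completion even though one arbitrary thread finishes the shared request for every waiter. The CONTEXT, Listener, and Deduplicator names below are invented for the illustration; they stand in for Elasticsearch's ThreadContext, ActionListener, and ResultDeduplicator with ContextPreservingActionListener.wrapPreservingContext, and the toy deduplicator deliberately skips the concurrency edge cases the real class handles.

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiConsumer;

public class ContextPreservingDeduplicationSketch {

    /** Stand-in for Elasticsearch's ThreadContext: a single per-thread header. */
    static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    /** Stand-in for a completion listener. */
    interface Listener {
        void onDone();
    }

    /**
     * The fix in miniature: capture the submitter's context when the listener is registered
     * and restore it around completion, regardless of which thread finishes the shared work.
     */
    static Listener preservingContext(Listener delegate) {
        final String captured = CONTEXT.get();      // runs on the submitting thread
        return () -> {
            final String completerContext = CONTEXT.get();
            CONTEXT.set(captured);                  // complete under the submitter's context
            try {
                delegate.onDone();
            } finally {
                CONTEXT.set(completerContext);      // put the completing thread's context back
            }
        };
    }

    /** Toy deduplicator: the first caller for a request runs the work, later callers only wait. */
    static final class Deduplicator<T> {
        private final Map<T, List<Listener>> waiters = new ConcurrentHashMap<>();

        void executeOnce(T request, Listener listener, BiConsumer<T, Listener> work) {
            final List<Listener> fresh = new ArrayList<>();
            final List<Listener> listeners = waiters.computeIfAbsent(request, r -> fresh);
            synchronized (listeners) {
                listeners.add(preservingContext(listener));   // wrap before storing, as in the patch
            }
            if (listeners == fresh) {
                // First caller: run the work once; its completion fans out to every waiter.
                work.accept(request, () -> {
                    final List<Listener> done = waiters.remove(request);
                    synchronized (done) {
                        done.forEach(Listener::onDone);
                    }
                });
            }
        }
    }

    public static void main(String[] args) {
        final Deduplicator<String> deduplicator = new Deduplicator<>();
        CONTEXT.set("caller-A");
        deduplicator.executeOnce(
            "shared-request",
            () -> System.out.println("completed under context: " + CONTEXT.get()),
            (request, completion) -> {
                CONTEXT.set("worker");   // simulate the work finishing under a different context
                completion.onDone();
            }
        );
        // Prints "completed under context: caller-A" because the listener was wrapped at submit time.
    }
}

The design point mirrored from the patch is where the wrapping happens: inside executeOnce, on the submitting thread, before the listener is stored, so the capture sees the caller's context rather than whatever context the completing thread happens to carry.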