Skip to content

Commit

Permalink
Add CountDownActionListener (#92308)
Browse files Browse the repository at this point in the history
  • Loading branch information
joegallo authored Dec 19, 2022
1 parent 8bccf66 commit 2d2b82b
Show file tree
Hide file tree
Showing 11 changed files with 311 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.action.support;

import org.elasticsearch.action.ActionListener;

import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

/**
* Wraps another listener and adds a counter -- each invocation of this listener will decrement the counter, and when the counter has been
* exhausted the final invocation of this listener will delegate to the wrapped listener. Similar to {@link GroupedActionListener}, but for
* the cases where tracking individual results is not useful.
*/
public final class CountDownActionListener extends ActionListener.Delegating<Void, Void> {

private final AtomicInteger countDown;
private final AtomicReference<Exception> failure = new AtomicReference<>();

/**
* Creates a new listener
* @param groupSize the group size
* @param delegate the delegate listener
*/
public CountDownActionListener(int groupSize, ActionListener<Void> delegate) {
super(Objects.requireNonNull(delegate));
if (groupSize <= 0) {
assert false : "illegal group size [" + groupSize + "]";
throw new IllegalArgumentException("groupSize must be greater than 0 but was " + groupSize);
}
countDown = new AtomicInteger(groupSize);
}

/**
* Creates a new listener
* @param groupSize the group size
* @param runnable the runnable
*/
public CountDownActionListener(int groupSize, Runnable runnable) {
this(groupSize, ActionListener.wrap(Objects.requireNonNull(runnable)));
}

private boolean countDown() {
final var result = countDown.getAndUpdate(current -> Math.max(0, current - 1));
assert result > 0;
return result == 1;
}

@Override
public void onResponse(Void element) {
if (countDown()) {
if (failure.get() != null) {
super.onFailure(failure.get());
} else {
delegate.onResponse(element);
}
}
}

@Override
public void onFailure(Exception e) {
if (failure.compareAndSet(null, e) == false) {
failure.accumulateAndGet(e, (current, update) -> {
// we have to avoid self-suppression!
if (update != current) {
current.addSuppressed(update);
}
return current;
});
}
if (countDown()) {
super.onFailure(failure.get());
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.cluster.coordination.FollowersChecker;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
Expand Down Expand Up @@ -97,10 +97,7 @@ public void connectToNodes(DiscoveryNodes discoveryNodes, Runnable onCompletion)
return;
}

final GroupedActionListener<Void> listener = new GroupedActionListener<>(
discoveryNodes.getSize(),
ActionListener.wrap(onCompletion)
);
final CountDownActionListener listener = new CountDownActionListener(discoveryNodes.getSize(), onCompletion);

final List<Runnable> runnables = new ArrayList<>(discoveryNodes.getSize());
synchronized (mutex) {
Expand Down Expand Up @@ -159,10 +156,7 @@ void ensureConnections(Runnable onCompletion) {
runnables.add(onCompletion);
} else {
logger.trace("ensureConnections: {}", targetsByNode);
final GroupedActionListener<Void> listener = new GroupedActionListener<>(
connectionTargets.size(),
ActionListener.wrap(onCompletion)
);
final CountDownActionListener listener = new CountDownActionListener(connectionTargets.size(), onCompletion);
for (final ConnectionTarget connectionTarget : connectionTargets) {
runnables.add(connectionTarget.connect(listener));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterInfo;
import org.elasticsearch.cluster.ClusterState;
Expand Down Expand Up @@ -301,7 +301,7 @@ public void onNewInfo(ClusterInfo info) {
}
}

final ActionListener<Void> listener = new GroupedActionListener<>(3, ActionListener.wrap(this::checkFinished));
final ActionListener<Void> listener = new CountDownActionListener(3, this::checkFinished);

if (reroute) {
logger.debug("rerouting shards: [{}]", explanation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.SingleResultDeduplicator;
import org.elasticsearch.action.StepListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.ListenableActionFuture;
import org.elasticsearch.action.support.PlainActionFuture;
Expand Down Expand Up @@ -964,7 +965,7 @@ private void doDeleteShardSnapshots(
writeUpdatedRepoDataStep.whenComplete(updatedRepoData -> {
listener.onRepositoryDataWritten(updatedRepoData);
// Run unreferenced blobs cleanup in parallel to shard-level snapshot deletion
final ActionListener<Void> afterCleanupsListener = new GroupedActionListener<>(2, ActionListener.wrap(listener::onDone));
final ActionListener<Void> afterCleanupsListener = new CountDownActionListener(2, listener::onDone);
cleanupUnlinkedRootAndIndicesBlobs(snapshotIds, foundIndices, rootBlobs, updatedRepoData, afterCleanupsListener);
asyncCleanupUnlinkedShardLevelBlobs(
repositoryData,
Expand All @@ -978,10 +979,10 @@ private void doDeleteShardSnapshots(
final RepositoryData updatedRepoData = repositoryData.removeSnapshots(snapshotIds, ShardGenerations.EMPTY);
writeIndexGen(updatedRepoData, repositoryStateId, repoMetaVersion, Function.identity(), ActionListener.wrap(newRepoData -> {
// Run unreferenced blobs cleanup in parallel to shard-level snapshot deletion
final ActionListener<Void> afterCleanupsListener = new GroupedActionListener<>(2, ActionListener.wrap(() -> {
final ActionListener<Void> afterCleanupsListener = new CountDownActionListener(2, () -> {
listener.onRepositoryDataWritten(newRepoData);
listener.onDone();
}));
});
cleanupUnlinkedRootAndIndicesBlobs(snapshotIds, foundIndices, rootBlobs, newRepoData, afterCleanupsListener);
final StepListener<Collection<ShardSnapshotMetaDeleteResult>> writeMetaAndComputeDeletesStep = new StepListener<>();
writeUpdatedShardMetaDataAndComputeDeletes(snapshotIds, repositoryData, false, writeMetaAndComputeDeletesStep);
Expand Down Expand Up @@ -1414,7 +1415,7 @@ public void finalizeSnapshot(final FinalizeSnapshotContext finalizeSnapshotConte
indexMetaIdentifiers = null;
}

final ActionListener<Void> allMetaListener = new GroupedActionListener<>(2 + indices.size(), ActionListener.wrap(v -> {
final ActionListener<Void> allMetaListener = new CountDownActionListener(2 + indices.size(), ActionListener.wrap(v -> {
final String slmPolicy = slmPolicy(snapshotInfo);
final SnapshotDetails snapshotDetails = new SnapshotDetails(
snapshotInfo.state(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.StepListener;
import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotRequest;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
Expand Down Expand Up @@ -525,11 +525,11 @@ static void refreshRepositoryUuids(boolean enabled, RepositoriesService reposito
"refreshing repository UUIDs for repositories [{}]",
repositories.stream().map(repository -> repository.getMetadata().name()).collect(Collectors.joining(","))
);
final ActionListener<RepositoryData> groupListener = new GroupedActionListener<>(
final ActionListener<RepositoryData> countDownListener = new CountDownActionListener(
repositories.size(),
new ActionListener<Collection<Void>>() {
new ActionListener<Void>() {
@Override
public void onResponse(Collection<Void> ignored) {
public void onResponse(Void ignored) {
logger.debug("repository UUID refresh completed");
refreshListener.onResponse(null);
}
Expand All @@ -543,7 +543,7 @@ public void onFailure(Exception e) {
).map(repositoryData -> null /* don't collect the RepositoryData */);

for (Repository repository : repositories) {
repository.getRepositoryData(groupListener);
repository.getRepositoryData(countDownListener);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.action.ResultDeduplicator;
import org.elasticsearch.action.StepListener;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -99,18 +100,18 @@ void doCancelTaskAndDescendants(CancellableTask task, String reason, boolean wai
if (task.shouldCancelChildrenOnCancellation()) {
logger.trace("cancelling task [{}] and its descendants", taskId);
StepListener<Void> completedListener = new StepListener<>();
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(3, completedListener.map(r -> null));
CountDownActionListener countDownListener = new CountDownActionListener(3, completedListener);
Collection<Transport.Connection> childConnections = taskManager.startBanOnChildTasks(task.getId(), reason, () -> {
logger.trace("child tasks of parent [{}] are completed", taskId);
groupedListener.onResponse(null);
countDownListener.onResponse(null);
});
taskManager.cancel(task, reason, () -> {
logger.trace("task [{}] is cancelled", taskId);
groupedListener.onResponse(null);
countDownListener.onResponse(null);
});
StepListener<Void> setBanListener = new StepListener<>();
setBanOnChildConnections(reason, waitForCompletion, task, childConnections, setBanListener);
setBanListener.addListener(groupedListener);
setBanListener.addListener(countDownListener);
// If we start unbanning when the last child task completed and that child task executed with a specific user, then unban
// requests are denied because internal requests can't run with a user. We need to remove bans with the current thread context.
final Runnable removeBansRunnable = transportService.getThreadPool()
Expand Down Expand Up @@ -149,7 +150,7 @@ private void setBanOnChildConnections(
}
final TaskId taskId = new TaskId(localNodeId(), task.getId());
logger.trace("cancelling child tasks of [{}] on child connections {}", taskId, childConnections);
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(childConnections.size(), listener.map(r -> null));
CountDownActionListener countDownListener = new CountDownActionListener(childConnections.size(), listener);
final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(taskId, reason, waitForCompletion);
for (Transport.Connection connection : childConnections) {
assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped";
Expand All @@ -162,7 +163,7 @@ private void setBanOnChildConnections(
@Override
public void handleResponse(TransportResponse.Empty response) {
logger.trace("sent ban for tasks with the parent [{}] for connection [{}]", taskId, connection);
groupedListener.onResponse(null);
countDownListener.onResponse(null);
}

@Override
Expand All @@ -188,7 +189,7 @@ public void handleException(TransportException exp) {
);
}

groupedListener.onFailure(exp);
countDownListener.onFailure(exp);
}
}
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.Client;
Expand Down Expand Up @@ -332,14 +332,14 @@ synchronized void updateRemoteCluster(String clusterAlias, Settings newSettings,
*/
void initializeRemoteClusters() {
final TimeValue timeValue = REMOTE_INITIAL_CONNECTION_TIMEOUT_SETTING.get(settings);
final PlainActionFuture<Collection<Void>> future = new PlainActionFuture<>();
final PlainActionFuture<Void> future = new PlainActionFuture<>();
Set<String> enabledClusters = RemoteClusterAware.getEnabledRemoteClusters(settings);

if (enabledClusters.isEmpty()) {
return;
}

GroupedActionListener<Void> listener = new GroupedActionListener<>(enabledClusters.size(), future);
CountDownActionListener listener = new CountDownActionListener(enabledClusters.size(), future);
for (String clusterAlias : enabledClusters) {
updateRemoteCluster(clusterAlias, settings, listener);
}
Expand Down
Loading

0 comments on commit 2d2b82b

Please sign in to comment.