Skip to content

Commit

Permalink
Improve scalability of BroadcastReplicationActions (#92902)
Browse files Browse the repository at this point in the history
BroadcastReplicationAction derivatives (`POST /<indices>/_refresh` and
`POST /<indices>/_flush`) are pretty inefficient when targeting high
shard counts due to how `TransportBroadcastReplicationAction` works:

- It computes the list of all target shards up-front on the calling (transport) thread.

- It accumulates responses in a `CopyOnWriteArrayList`, which takes quadratic work to populate (every `add` copies the whole backing array), even though nothing reads this list until it's fully populated.

- It then mostly discards the accumulated responses, keeping only the total number of shards, the number of successful shards, and a list of any failures.

- Each failure is wrapped up in a `ReplicationResponse.ShardInfo.Failure` but then unwrapped at the end to be re-wrapped in a `DefaultShardOperationFailedException`.

This commit fixes all of these problems:

- The computation of the list of shards, and the sending of the per-shard requests, now happens on the relevant threadpool (`REFRESH` or `FLUSH`) rather than a transport thread.

- The failures are tracked in a regular `ArrayList`, avoiding the accidentally-quadratic complexity.

- Rather than accumulating the full responses for later processing, we track the counts and failures directly.

- The failures are tracked in their final form, skipping the unwrap-and-rewrap step at the end.

Relates #77466, relates #92729
  • Loading branch information
DaveCTurner authored Jan 13, 2023
1 parent 1a9150d commit 4aa4a0d
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;

import java.util.List;
Expand Down Expand Up @@ -46,15 +47,11 @@ public TransportFlushAction(
client,
actionFilters,
indexNameExpressionResolver,
TransportShardFlushAction.TYPE
TransportShardFlushAction.TYPE,
ThreadPool.Names.FLUSH
);
}

@Override
protected ReplicationResponse newShardResponse() {
return new ReplicationResponse();
}

@Override
protected ShardFlushRequest newShardRequest(FlushRequest request, ShardId shardId) {
return new ShardFlushRequest(request, shardId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;

import java.util.List;
Expand Down Expand Up @@ -48,15 +49,11 @@ public TransportRefreshAction(
client,
actionFilters,
indexNameExpressionResolver,
TransportShardRefreshAction.TYPE
TransportShardRefreshAction.TYPE,
ThreadPool.Names.REFRESH
);
}

@Override
protected ReplicationResponse newShardResponse() {
return new ReplicationResponse();
}

@Override
protected BasicReplicationRequest newShardRequest(RefreshRequest request, ShardId shardId) {
BasicReplicationRequest replicationRequest = new BasicReplicationRequest(shardId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

package org.elasticsearch.action.support.replication;

import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.DefaultShardOperationFailedException;
Expand All @@ -26,14 +26,16 @@
import org.elasticsearch.cluster.routing.IndexRoutingTable;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.transport.Transports;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.Map;

/**
* Base class for requests that should be executed on all shards of an index or several indices.
Expand All @@ -49,6 +51,7 @@ public abstract class TransportBroadcastReplicationAction<
private final ClusterService clusterService;
private final IndexNameExpressionResolver indexNameExpressionResolver;
private final NodeClient client;
private final String executor;

public TransportBroadcastReplicationAction(
String name,
Expand All @@ -58,58 +61,112 @@ public TransportBroadcastReplicationAction(
NodeClient client,
ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver,
ActionType<ShardResponse> replicatedBroadcastShardAction
ActionType<ShardResponse> replicatedBroadcastShardAction,
String executor
) {
super(name, transportService, actionFilters, requestReader);
this.client = client;
this.replicatedBroadcastShardAction = replicatedBroadcastShardAction;
this.clusterService = clusterService;
this.indexNameExpressionResolver = indexNameExpressionResolver;
this.executor = executor;
}

@Override
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
final ClusterState clusterState = clusterService.state();
List<ShardId> shards = shards(request, clusterState);
final CopyOnWriteArrayList<ShardResponse> shardsResponses = new CopyOnWriteArrayList<>();
try (var refs = new RefCountingRunnable(() -> finishAndNotifyListener(listener, shardsResponses))) {
for (final ShardId shardId : shards) {
ActionListener<ShardResponse> shardActionListener = new ActionListener<ShardResponse>() {
@Override
public void onResponse(ShardResponse shardResponse) {
shardsResponses.add(shardResponse);
logger.trace("{}: got response from {}", actionName, shardId);
clusterService.threadPool().executor(executor).execute(ActionRunnable.wrap(listener, createAsyncAction(task, request)));
}

private CheckedConsumer<ActionListener<Response>, Exception> createAsyncAction(Task task, Request request) {
return new CheckedConsumer<ActionListener<Response>, Exception>() {

private int totalShardCopyCount;
private int successShardCopyCount;
private final List<DefaultShardOperationFailedException> allFailures = new ArrayList<>();

@Override
public void accept(ActionListener<Response> listener) {
assert totalShardCopyCount == 0 && successShardCopyCount == 0 && allFailures.isEmpty() : "shouldn't call this twice";

final ClusterState clusterState = clusterService.state();
final List<ShardId> shards = shards(request, clusterState);
final Map<String, IndexMetadata> indexMetadataByName = clusterState.getMetadata().indices();

try (var refs = new RefCountingRunnable(() -> finish(listener))) {
for (final ShardId shardId : shards) {
// NB This sends O(#shards) requests in a tight loop; TODO add some throttling here?
shardExecute(
task,
request,
shardId,
ActionListener.releaseAfter(new ReplicationResponseActionListener(shardId, indexMetadataByName), refs.acquire())
);
}
}
}

private synchronized void addShardResponse(int numCopies, int successful, List<DefaultShardOperationFailedException> failures) {
totalShardCopyCount += numCopies;
successShardCopyCount += successful;
allFailures.addAll(failures);
}

void finish(ActionListener<Response> listener) {
// no need for synchronized here, the RefCountingRunnable guarantees that all the addShardResponse calls happen-before here
logger.trace("{}: got all shard responses", actionName);
listener.onResponse(newResponse(successShardCopyCount, allFailures.size(), totalShardCopyCount, allFailures));
}

class ReplicationResponseActionListener implements ActionListener<ShardResponse> {
private final ShardId shardId;
private final Map<String, IndexMetadata> indexMetadataByName;

ReplicationResponseActionListener(ShardId shardId, Map<String, IndexMetadata> indexMetadataByName) {
this.shardId = shardId;
this.indexMetadataByName = indexMetadataByName;
}

@Override
public void onFailure(Exception e) {
logger.trace("{}: got failure from {}", actionName, shardId);
int totalNumCopies = clusterState.getMetadata().getIndexSafe(shardId.getIndex()).getNumberOfReplicas() + 1;
ShardResponse shardResponse = newShardResponse();
ReplicationResponse.ShardInfo.Failure[] failures;
if (TransportActions.isShardNotAvailableException(e)) {
failures = new ReplicationResponse.ShardInfo.Failure[0];
} else {
ReplicationResponse.ShardInfo.Failure failure = new ReplicationResponse.ShardInfo.Failure(
shardId,
null,
e,
ExceptionsHelper.status(e),
true
);
failures = new ReplicationResponse.ShardInfo.Failure[totalNumCopies];
Arrays.fill(failures, failure);
}
shardResponse.setShardInfo(new ReplicationResponse.ShardInfo(totalNumCopies, 0, failures));
shardsResponses.add(shardResponse);
@Override
public void onResponse(ShardResponse shardResponse) {
assert shardResponse != null;
logger.trace("{}: got response from {}", actionName, shardId);
addShardResponse(
shardResponse.getShardInfo().getTotal(),
shardResponse.getShardInfo().getSuccessful(),
Arrays.stream(shardResponse.getShardInfo().getFailures())
.map(
f -> new DefaultShardOperationFailedException(
new BroadcastShardOperationFailedException(shardId, f.getCause())
)
)
.toList()
);
}

@Override
public void onFailure(Exception e) {
logger.trace("{}: got failure from {}", actionName, shardId);
final int numCopies = indexMetadataByName.get(shardId.getIndexName()).getNumberOfReplicas() + 1;
final List<DefaultShardOperationFailedException> result;
if (TransportActions.isShardNotAvailableException(e)) {
result = List.of();
} else {
final var failures = new DefaultShardOperationFailedException[numCopies];
Arrays.fill(
failures,
new DefaultShardOperationFailedException(new BroadcastShardOperationFailedException(shardId, e))
);
result = Arrays.asList(failures);
}
};
shardExecute(task, request, shardId, ActionListener.releaseAfter(shardActionListener, refs.acquire()));
addShardResponse(numCopies, 0, result);
}
}
}

};
}

protected void shardExecute(Task task, Request request, ShardId shardId, ActionListener<ShardResponse> shardActionListener) {
assert Transports.assertNotTransportThread("may hit all the shards");
ShardRequest shardRequest = newShardRequest(request, shardId);
shardRequest.setParentTask(clusterService.localNode().getId(), task.getId());
client.executeLocally(replicatedBroadcastShardAction, shardRequest, shardActionListener);
Expand All @@ -119,6 +176,7 @@ protected void shardExecute(Task task, Request request, ShardId shardId, ActionL
* @return all shard ids the request should run on
*/
protected List<ShardId> shards(Request request, ClusterState clusterState) {
assert Transports.assertNotTransportThread("may hit all the shards");
List<ShardId> shardIds = new ArrayList<>();
String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(clusterState, request);
for (String index : concreteIndices) {
Expand All @@ -133,43 +191,13 @@ protected List<ShardId> shards(Request request, ClusterState clusterState) {
return shardIds;
}

protected abstract ShardResponse newShardResponse();

protected abstract ShardRequest newShardRequest(Request request, ShardId shardId);

private void finishAndNotifyListener(ActionListener<Response> listener, CopyOnWriteArrayList<ShardResponse> shardsResponses) {
logger.trace("{}: got all shard responses", actionName);
int successfulShards = 0;
int failedShards = 0;
int totalNumCopies = 0;
List<DefaultShardOperationFailedException> shardFailures = null;
for (int i = 0; i < shardsResponses.size(); i++) {
ReplicationResponse shardResponse = shardsResponses.get(i);
if (shardResponse == null) {
// non active shard, ignore
} else {
failedShards += shardResponse.getShardInfo().getFailed();
successfulShards += shardResponse.getShardInfo().getSuccessful();
totalNumCopies += shardResponse.getShardInfo().getTotal();
if (shardFailures == null) {
shardFailures = new ArrayList<>();
}
for (ReplicationResponse.ShardInfo.Failure failure : shardResponse.getShardInfo().getFailures()) {
shardFailures.add(
new DefaultShardOperationFailedException(
new BroadcastShardOperationFailedException(failure.fullShardId(), failure.getCause())
)
);
}
}
}
listener.onResponse(newResponse(successfulShards, failedShards, totalNumCopies, shardFailures));
}

protected abstract Response newResponse(
int successfulShards,
int failedShards,
int totalNumCopies,
List<DefaultShardOperationFailedException> shardFailures
);

}
Original file line number Diff line number Diff line change
Expand Up @@ -254,15 +254,11 @@ private class TestBroadcastReplicationAction extends TransportBroadcastReplicati
null,
actionFilters,
indexNameExpressionResolver,
null
null,
ThreadPool.Names.SAME
);
}

@Override
protected ReplicationResponse newShardResponse() {
return new ReplicationResponse();
}

@Override
protected BasicReplicationRequest newShardRequest(DummyBroadcastRequest request, ShardId shardId) {
return new BasicReplicationRequest(shardId);
Expand Down

0 comments on commit 4aa4a0d

Please sign in to comment.