Skip to content

Commit

Permalink
Discard intermediate node results when a request is cancelled (#82685)
Browse files Browse the repository at this point in the history
Resolves #82337
  • Loading branch information
gmarouli authored Feb 10, 2022
1 parent d4caeea commit d4655e8
Show file tree
Hide file tree
Showing 9 changed files with 341 additions and 79 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/82685.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 82685
summary: Discard intermediate results upon cancellation for stats endpoints
area: Stats
type: bug
issues:
- 82337
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.action.support;

import java.util.Collection;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReferenceArray;

/**
 * Tracks the intermediate per-node responses that are later combined into the aggregated cluster response of a request. The
 * collected responses can be discarded on demand, for example when the initial request is cancelled, in order to release the
 * resources they hold.
 *
 * <p>Incoming responses keep being counted after a discard, so {@link #trackResponseAndCheckIfLast(int, Object)} still reports when
 * the last expected response has arrived.
 */
public class NodeResponseTracker {

    private final AtomicInteger counter = new AtomicInteger();
    private final int expectedResponsesCount;
    // Becomes null once the responses are discarded; it is never reassigned to a non-null value afterwards.
    private volatile AtomicReferenceArray<Object> responses;
    // Written before `responses` is nulled, so a reader that observes the discard also observes its cause.
    private volatile Exception causeOfDiscarding;

    /**
     * Creates a tracker that expects {@code size} responses, none of which has been received yet.
     * @param size the number of nodes a response is expected from
     */
    public NodeResponseTracker(int size) {
        this.expectedResponsesCount = size;
        this.responses = new AtomicReferenceArray<>(size);
    }

    /**
     * Creates a tracker whose slots are pre-populated with the given elements, in iteration order. The internal counter of received
     * responses is not advanced for these elements.
     * @param array the initial content of the tracker; its size is the expected response count
     */
    public NodeResponseTracker(Collection<?> array) {
        this.expectedResponsesCount = array.size();
        this.responses = new AtomicReferenceArray<>(array.toArray());
    }

    /**
     * Discards the results collected so far to free up the resources. Calling this method after the responses have already been
     * discarded has no effect; the cause of the first discard is kept.
     * @param cause the reason for discarding, this will be communicated to callers that try to access the discarded results
     */
    public void discardIntermediateResponses(Exception cause) {
        if (responses != null) {
            // Publish the cause first: readers check `responses` and only then read `causeOfDiscarding`.
            this.causeOfDiscarding = cause;
            responses = null;
        }
    }

    /**
     * @return true if the intermediate responses have been discarded, else false
     */
    public boolean responsesDiscarded() {
        return responses == null;
    }

    /**
     * Stores a new node response if the intermediate responses haven't been discarded yet. If the responses are not discarded the
     * method asserts that this is the first response encountered from this node to protect from miscounting the responses in case of
     * a double invocation. If the responses have been discarded we accept this risk for simplicity.
     * @param nodeIndex the index that represents a single node of the cluster
     * @param response a response can be either a NodeResponse or an error
     * @return true if all the nodes' responses have been received, else false
     */
    public boolean trackResponseAndCheckIfLast(int nodeIndex, Object response) {
        // Read the volatile field exactly once and work on that snapshot, so the null check and the write below are guaranteed to
        // refer to the same array even if a discard happens concurrently.
        final AtomicReferenceArray<Object> currentResponses = this.responses;
        if (currentResponses != null) {
            boolean firstEncounter = currentResponses.compareAndSet(nodeIndex, null, response);
            assert firstEncounter : "a response should be tracked only once";
        }
        // Responses are counted even after a discard so that completion is still detected.
        return counter.incrementAndGet() == getExpectedResponseCount();
    }

    /**
     * Returns the tracked response or null if the response hasn't been received yet for a specific index that represents a node of
     * the cluster.
     * @param nodeIndex the index that represents a single node of the cluster
     * @return the response stored for the node, or null if none has been stored yet
     * @throws DiscardedResponsesException if the responses have been discarded
     */
    public Object getResponse(int nodeIndex) throws DiscardedResponsesException {
        // Single volatile read; see trackResponseAndCheckIfLast.
        final AtomicReferenceArray<Object> currentResponses = this.responses;
        if (currentResponses == null) {
            throw new DiscardedResponsesException(causeOfDiscarding);
        }
        return currentResponses.get(nodeIndex);
    }

    /**
     * @return the number of responses this tracker was created to expect
     */
    public int getExpectedResponseCount() {
        return expectedResponsesCount;
    }

    /**
     * This exception is thrown when the {@link NodeResponseTracker} is asked to give information about the responses after they have
     * been discarded.
     */
    public static class DiscardedResponsesException extends Exception {

        /**
         * @param cause the reason the responses were discarded, as passed to
         *              {@link NodeResponseTracker#discardIntermediateResponses(Exception)}
         */
        public DiscardedResponsesException(Exception cause) {
            super(cause);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.action.support.DefaultShardOperationFailedException;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.NodeResponseTracker;
import org.elasticsearch.action.support.TransportActions;
import org.elasticsearch.action.support.broadcast.BroadcastRequest;
import org.elasticsearch.action.support.broadcast.BroadcastResponse;
Expand Down Expand Up @@ -51,7 +52,6 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.function.Consumer;

/**
Expand Down Expand Up @@ -118,28 +118,29 @@ public TransportBroadcastByNodeAction(

private Response newResponse(
Request request,
AtomicReferenceArray<?> responses,
NodeResponseTracker nodeResponseTracker,
int unavailableShardCount,
Map<String, List<ShardRouting>> nodes,
ClusterState clusterState
) {
) throws NodeResponseTracker.DiscardedResponsesException {
int totalShards = 0;
int successfulShards = 0;
List<ShardOperationResult> broadcastByNodeResponses = new ArrayList<>();
List<DefaultShardOperationFailedException> exceptions = new ArrayList<>();
for (int i = 0; i < responses.length(); i++) {
if (responses.get(i)instanceof FailedNodeException exception) {
for (int i = 0; i < nodeResponseTracker.getExpectedResponseCount(); i++) {
Object response = nodeResponseTracker.getResponse(i);
if (response instanceof FailedNodeException exception) {
totalShards += nodes.get(exception.nodeId()).size();
for (ShardRouting shard : nodes.get(exception.nodeId())) {
exceptions.add(new DefaultShardOperationFailedException(shard.getIndexName(), shard.getId(), exception));
}
} else {
@SuppressWarnings("unchecked")
NodeResponse response = (NodeResponse) responses.get(i);
broadcastByNodeResponses.addAll(response.results);
totalShards += response.getTotalShards();
successfulShards += response.getSuccessfulShards();
for (BroadcastShardOperationFailedException throwable : response.getExceptions()) {
NodeResponse nodeResponse = (NodeResponse) response;
broadcastByNodeResponses.addAll(nodeResponse.results);
totalShards += nodeResponse.getTotalShards();
successfulShards += nodeResponse.getSuccessfulShards();
for (BroadcastShardOperationFailedException throwable : nodeResponse.getExceptions()) {
if (TransportActions.isShardNotAvailableException(throwable) == false) {
exceptions.add(
new DefaultShardOperationFailedException(
Expand Down Expand Up @@ -256,16 +257,15 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
new AsyncAction(task, request, listener).start();
}

protected class AsyncAction {
protected class AsyncAction implements CancellableTask.CancellationListener {
private final Task task;
private final Request request;
private final ActionListener<Response> listener;
private final ClusterState clusterState;
private final DiscoveryNodes nodes;
private final Map<String, List<ShardRouting>> nodeIds;
private final AtomicReferenceArray<Object> responses;
private final AtomicInteger counter = new AtomicInteger();
private final int unavailableShardCount;
private final NodeResponseTracker nodeResponseTracker;

protected AsyncAction(Task task, Request request, ActionListener<Response> listener) {
this.task = task;
Expand Down Expand Up @@ -312,10 +312,13 @@ protected AsyncAction(Task task, Request request, ActionListener<Response> liste

}
this.unavailableShardCount = unavailableShardCount;
responses = new AtomicReferenceArray<>(nodeIds.size());
nodeResponseTracker = new NodeResponseTracker(nodeIds.size());
}

public void start() {
if (task instanceof CancellableTask cancellableTask) {
cancellableTask.addListener(this);
}
if (nodeIds.size() == 0) {
try {
onCompletion();
Expand Down Expand Up @@ -373,38 +376,34 @@ protected void onNodeResponse(DiscoveryNode node, int nodeIndex, NodeResponse re
logger.trace("received response for [{}] from node [{}]", actionName, node.getId());
}

// this is defensive to protect against the possibility of double invocation
// the current implementation of TransportService#sendRequest guards against this
// but concurrency is hard, safety is important, and the small performance loss here does not matter
if (responses.compareAndSet(nodeIndex, null, response)) {
if (counter.incrementAndGet() == responses.length()) {
onCompletion();
}
if (nodeResponseTracker.trackResponseAndCheckIfLast(nodeIndex, response)) {
onCompletion();
}
}

protected void onNodeFailure(DiscoveryNode node, int nodeIndex, Throwable t) {
String nodeId = node.getId();
logger.debug(new ParameterizedMessage("failed to execute [{}] on node [{}]", actionName, nodeId), t);

// this is defensive to protect against the possibility of double invocation
// the current implementation of TransportService#sendRequest guards against this
// but concurrency is hard, safety is important, and the small performance loss here does not matter
if (responses.compareAndSet(nodeIndex, null, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t))) {
if (counter.incrementAndGet() == responses.length()) {
onCompletion();
}
if (nodeResponseTracker.trackResponseAndCheckIfLast(
nodeIndex,
new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t)
)) {
onCompletion();
}
}

protected void onCompletion() {
if (task instanceof CancellableTask && ((CancellableTask) task).notifyIfCancelled(listener)) {
if ((task instanceof CancellableTask t) && t.notifyIfCancelled(listener)) {
return;
}

Response response = null;
try {
response = newResponse(request, responses, unavailableShardCount, nodeIds, clusterState);
response = newResponse(request, nodeResponseTracker, unavailableShardCount, nodeIds, clusterState);
} catch (NodeResponseTracker.DiscardedResponsesException e) {
// We propagate the reason the results were discarded, in this case the task cancellation, in case the listener needs to take
// follow-up actions
listener.onFailure((Exception) e.getCause());
} catch (Exception e) {
logger.debug("failed to combine responses from nodes", e);
listener.onFailure(e);
Expand All @@ -417,6 +416,21 @@ protected void onCompletion() {
}
}
}

/**
 * Cancellation callback: drops the node responses collected so far, recording the cancellation exception as the reason.
 * The exception is obtained by asking the task to verify that it is not cancelled; if that check does not throw, nothing
 * is discarded.
 */
@Override
public void onCancelled() {
    assert task instanceof CancellableTask : "task must be cancellable";
    final CancellableTask cancellableTask = (CancellableTask) task;
    try {
        cancellableTask.ensureNotCancelled();
    } catch (TaskCancelledException cancellation) {
        // The thrown exception is handed to the tracker as the cause of the discard.
        nodeResponseTracker.discardIntermediateResponses(cancellation);
    }
}

/**
 * Exposes the tracker holding this action's per-node responses. For testing purposes.
 * @return the tracker used by this action
 */
public NodeResponseTracker getNodeResponseTracker() {
    return this.nodeResponseTracker;
}
}

class BroadcastByNodeTransportRequestHandler implements TransportRequestHandler<NodeRequest> {
Expand Down
Loading

0 comments on commit d4655e8

Please sign in to comment.