Skip to content

Commit

Permalink
Send ban parent per outstanding child connection (#65443)
Browse files Browse the repository at this point in the history
This commit sends a parent-task ban for each connection so that we can
reply on channel disconnect instead of node leave events to remove bans.

Backport #65443
  • Loading branch information
dnhatn authored Dec 12, 2020
1 parent 68fce39 commit fa31cb0
Show file tree
Hide file tree
Showing 18 changed files with 130 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
Expand Down Expand Up @@ -123,6 +124,7 @@
import static org.hamcrest.Matchers.hasToString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.mockito.Mockito.mock;

public class AsyncBulkByScrollActionTests extends ESTestCase {
private MyMockClient client;
Expand Down Expand Up @@ -761,7 +763,7 @@ private static class DummyTransportAsyncBulkByScrollAction


protected DummyTransportAsyncBulkByScrollAction(String actionName, ActionFilters actionFilters, TaskManager taskManager) {
super(actionName, actionFilters, taskManager);
super(actionName, actionFilters, mock(Transport.Connection.class), taskManager);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ protected HandledTransportAction(String actionName, boolean canTripCircuitBreake
protected HandledTransportAction(String actionName, boolean canTripCircuitBreaker,
TransportService transportService, ActionFilters actionFilters,
Writeable.Reader<Request> requestReader, String executor) {
super(actionName, actionFilters, transportService.getTaskManager());
super(actionName, actionFilters, transportService.getLocalNodeConnection(), transportService.getTaskManager());
transportService.registerRequestHandler(actionName, executor, false, canTripCircuitBreaker, requestReader,
new TransportHandler());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskListener;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.transport.Transport;

import java.util.concurrent.atomic.AtomicInteger;

Expand All @@ -39,21 +40,25 @@ public abstract class TransportAction<Request extends ActionRequest, Response ex
public final String actionName;
private final ActionFilter[] filters;
protected final TaskManager taskManager;
protected final Transport.Connection localConnection;

/**
* @deprecated declare your own logger.
*/
@Deprecated
protected Logger logger = LogManager.getLogger(getClass());

protected TransportAction(String actionName, ActionFilters actionFilters, TaskManager taskManager) {
protected TransportAction(String actionName, ActionFilters actionFilters,
Transport.Connection localConnection, TaskManager taskManager) {
this.actionName = actionName;
this.filters = actionFilters.filters();
this.localConnection = localConnection;
this.taskManager = taskManager;
}

private Releasable registerChildNode(TaskId parentTask) {
if (parentTask.isSet()) {
return taskManager.registerChildNode(parentTask.getId(), taskManager.localNode());
return taskManager.registerChildConnection(parentTask.getId(), localConnection);
} else {
return () -> {};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ protected TransportReplicationAction(Settings settings, String actionName, Trans
ActionFilters actionFilters, Writeable.Reader<Request> requestReader,
Writeable.Reader<ReplicaRequest> replicaRequestReader, String executor,
boolean syncGlobalCheckpointAfterOperation, boolean forceExecutionOnPrimary) {
super(actionName, actionFilters, transportService.getTaskManager());
super(actionName, actionFilters, transportService.getLocalNodeConnection(), transportService.getTaskManager());
this.threadPool = threadPool;
this.transportService = transportService;
this.clusterService = clusterService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ protected TransportSingleShardAction(String actionName, ThreadPool threadPool, C
TransportService transportService, ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver, Writeable.Reader<Request> request,
String executor) {
super(actionName, actionFilters, transportService.getTaskManager());
super(actionName, actionFilters, transportService.getLocalNodeConnection(), transportService.getTaskManager());
this.threadPool = threadPool;
this.clusterService = clusterService;
this.transportService = transportService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,31 @@
import org.elasticsearch.action.StepListener;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.EmptyTransportResponseHandler;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestDeduplicator;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Objects;

public class TaskCancellationService {
public static final String BAN_PARENT_ACTION_NAME = "internal:admin/tasks/ban";
private static final Logger logger = LogManager.getLogger(TaskCancellationService.class);
private final TransportService transportService;
private final TaskManager taskManager;
private final TransportRequestDeduplicator<CancelRequest> deduplicator = new TransportRequestDeduplicator<>();

public TaskCancellationService(TransportService transportService) {
this.transportService = transportService;
Expand All @@ -61,35 +65,63 @@ private String localNodeId() {
return transportService.getLocalNode().getId();
}

void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
private static class CancelRequest {
final CancellableTask task;
final boolean waitForCompletion;

CancelRequest(CancellableTask task, boolean waitForCompletion) {
this.task = task;
this.waitForCompletion = waitForCompletion;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final CancelRequest that = (CancelRequest) o;
return waitForCompletion == that.waitForCompletion && Objects.equals(task, that.task);
}

@Override
public int hashCode() {
return Objects.hash(task, waitForCompletion);
}
}

void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> finalListener) {
deduplicator.executeOnce(new CancelRequest(task, waitForCompletion), finalListener,
(r, listener) -> doCancelTaskAndDescendants(task, reason, waitForCompletion, listener));
}

void doCancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
final TaskId taskId = task.taskInfo(localNodeId(), false).getTaskId();
if (task.shouldCancelChildrenOnCancellation()) {
logger.trace("cancelling task [{}] and its descendants", taskId);
StepListener<Void> completedListener = new StepListener<>();
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(completedListener.map(r -> null), 3);
Collection<DiscoveryNode> childrenNodes = taskManager.startBanOnChildrenNodes(task.getId(), () -> {
Collection<Transport.Connection> childConnections = taskManager.startBanOnChildTasks(task.getId(), () -> {
logger.trace("child tasks of parent [{}] are completed", taskId);
groupedListener.onResponse(null);
});
taskManager.cancel(task, reason, () -> {
logger.trace("task [{}] is cancelled", taskId);
groupedListener.onResponse(null);
});
StepListener<Void> banOnNodesListener = new StepListener<>();
setBanOnNodes(reason, waitForCompletion, task, childrenNodes, banOnNodesListener);
banOnNodesListener.whenComplete(groupedListener::onResponse, groupedListener::onFailure);
StepListener<Void> setBanListener = new StepListener<>();
setBanOnChildConnections(reason, waitForCompletion, task, childConnections, setBanListener);
setBanListener.whenComplete(groupedListener::onResponse, groupedListener::onFailure);
// If we start unbanning when the last child task completed and that child task executed with a specific user, then unban
// requests are denied because internal requests can't run with a user. We need to remove bans with the current thread context.
final Runnable removeBansRunnable = transportService.getThreadPool().getThreadContext()
.preserveContext(() -> removeBanOnNodes(task, childrenNodes));
// We remove bans after all child tasks are completed although in theory we can do it on a per-node basis.
.preserveContext(() -> removeBanOnChildConnections(task, childConnections));
// We remove bans after all child tasks are completed although in theory we can do it on a per-connection basis.
completedListener.whenComplete(r -> removeBansRunnable.run(), e -> removeBansRunnable.run());
// if wait_for_completion is true, then only return when (1) bans are placed on child nodes, (2) child tasks are
// completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child nodes.
// if wait_for_completion is true, then only return when (1) bans are placed on child connections, (2) child tasks are
// completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child connections.
if (waitForCompletion) {
completedListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
} else {
banOnNodesListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
setBanListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
}
} else {
logger.trace("task [{}] doesn't have any children that should be cancelled", taskId);
Expand All @@ -102,47 +134,48 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF
}
}

private void setBanOnNodes(String reason, boolean waitForCompletion, CancellableTask task,
Collection<DiscoveryNode> childNodes, ActionListener<Void> listener) {
if (childNodes.isEmpty()) {
private void setBanOnChildConnections(String reason, boolean waitForCompletion, CancellableTask task,
Collection<Transport.Connection> childConnections, ActionListener<Void> listener) {
if (childConnections.isEmpty()) {
listener.onResponse(null);
return;
}
final TaskId taskId = new TaskId(localNodeId(), task.getId());
logger.trace("cancelling child tasks of [{}] on child nodes {}", taskId, childNodes);
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(listener.map(r -> null), childNodes.size());
logger.trace("cancelling child tasks of [{}] on child connections {}", taskId, childConnections);
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(listener.map(r -> null), childConnections.size());
final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(taskId, reason, waitForCompletion);
for (DiscoveryNode node : childNodes) {
transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, banRequest,
for (Transport.Connection connection : childConnections) {
transportService.sendRequest(connection, BAN_PARENT_ACTION_NAME, banRequest, TransportRequestOptions.EMPTY,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override
public void handleResponse(TransportResponse.Empty response) {
logger.trace("sent ban for tasks with the parent [{}] to the node [{}]", taskId, node);
logger.trace("sent ban for tasks with the parent [{}] for connection [{}]", taskId, connection);
groupedListener.onResponse(null);
}

@Override
public void handleException(TransportException exp) {
assert ExceptionsHelper.unwrapCause(exp) instanceof ElasticsearchSecurityException == false;
logger.warn("Cannot send ban for tasks with the parent [{}] to the node [{}]", taskId, node);
logger.warn("Cannot send ban for tasks with the parent [{}] for connection [{}]", taskId, connection);
groupedListener.onFailure(exp);
}
});
}
}

private void removeBanOnNodes(CancellableTask task, Collection<DiscoveryNode> childNodes) {
private void removeBanOnChildConnections(CancellableTask task, Collection<Transport.Connection> childConnections) {
final BanParentTaskRequest request =
BanParentTaskRequest.createRemoveBanParentTaskRequest(new TaskId(localNodeId(), task.getId()));
for (DiscoveryNode node : childNodes) {
logger.trace("Sending remove ban for tasks with the parent [{}] to the node [{}]", request.parentTaskId, node);
transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, request, new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override
public void handleException(TransportException exp) {
assert ExceptionsHelper.unwrapCause(exp) instanceof ElasticsearchSecurityException == false;
logger.info("failed to remove the parent ban for task {} on node {}", request.parentTaskId, node);
}
});
for (Transport.Connection connection : childConnections) {
logger.trace("Sending remove ban for tasks with the parent [{}] for connection [{}]", request.parentTaskId, connection);
transportService.sendRequest(connection, BAN_PARENT_ACTION_NAME, request, TransportRequestOptions.EMPTY,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override
public void handleException(TransportException exp) {
assert ExceptionsHelper.unwrapCause(exp) instanceof ElasticsearchSecurityException == false;
logger.info("failed to remove the parent ban for task {} for connection {}", request.parentTaskId, connection);
}
});
}
}

Expand Down
Loading

0 comments on commit fa31cb0

Please sign in to comment.