From 2641e9a6d6466978eb95db2d408af1bb90834ab6 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Tue, 8 Dec 2020 09:10:12 -0500 Subject: [PATCH] Cancel proxy requests when the proxy channel closes (#65850) Since #43332 and #56327 we cancel rest requests when the rest channel closes and transport requests when the transport channel closes. This commit cancels proxy requests and its descendant requests when the proxy channel closes. This change is also required to support cross-clusters task cancellation. Relates #43332 Relates #56327 --- .../search/ccs/CrossClusterSearchIT.java | 15 ++++++ .../action/search/SearchTransportService.java | 22 ++++---- .../transport/TransportActionProxy.java | 51 +++++++++++++++++-- .../transport/TransportActionProxyTests.java | 50 ++++++++++++++---- .../ClearCcrRestoreSessionAction.java | 2 +- .../GetCcrRestoreFileChunkAction.java | 2 +- .../TransportOpenPointInTimeAction.java | 1 + 7 files changed, 115 insertions(+), 28 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java index fdeab35a158c8..a1cdd1a5824d7 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java @@ -32,7 +32,12 @@ import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.test.AbstractMultiClustersTestCase; +import org.elasticsearch.test.InternalTestCluster; +import org.elasticsearch.test.NodeRoles; +import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; +import org.elasticsearch.transport.TransportService; import org.junit.Before; import java.util.Collection; @@ -82,6 +87,16 @@ public void testProxyConnectionDisconnect() throws Exception { client(LOCAL_CLUSTER).search(searchRequest, future); SearchListenerPlugin.waitSearchStarted(); disconnectFromRemoteClusters(); + // Cancellable tasks on the remote cluster should be cancelled + assertBusy(() -> { + final Iterable transportServices = cluster("cluster_a").getInstances(TransportService.class); + for (TransportService transportService : transportServices) { + Collection cancellableTasks = transportService.getTaskManager().getCancellableTasks().values(); + for (CancellableTask cancellableTask : cancellableTasks) { + assertTrue(cancellableTask.getDescription(), cancellableTask.isCancelled()); + } + } + }); assertBusy(() -> assertTrue(future.isDone())); configureAndConnectsToRemoteClusters(); } finally { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java index fb5980dda03b7..b069ad5d54664 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java @@ -297,20 +297,20 @@ public static void registerRequestHandler(TransportService transportService, Sea boolean freed = searchService.freeReaderContext(request.id()); channel.sendResponse(new SearchFreeContextResponse(freed)); }); - TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_SCROLL_ACTION_NAME, SearchFreeContextResponse::new); + TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_SCROLL_ACTION_NAME, false, SearchFreeContextResponse::new); transportService.registerRequestHandler(FREE_CONTEXT_ACTION_NAME, ThreadPool.Names.SAME, SearchFreeContextRequest::new, (request, channel, task) -> { boolean freed = searchService.freeReaderContext(request.id()); channel.sendResponse(new SearchFreeContextResponse(freed)); }); - TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_ACTION_NAME, SearchFreeContextResponse::new); + TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_ACTION_NAME, false, SearchFreeContextResponse::new); transportService.registerRequestHandler(CLEAR_SCROLL_CONTEXTS_ACTION_NAME, ThreadPool.Names.SAME, TransportRequest.Empty::new, (request, channel, task) -> { searchService.freeAllScrollContexts(); channel.sendResponse(TransportResponse.Empty.INSTANCE); }); - TransportActionProxy.registerProxyAction(transportService, CLEAR_SCROLL_CONTEXTS_ACTION_NAME, + TransportActionProxy.registerProxyAction(transportService, CLEAR_SCROLL_CONTEXTS_ACTION_NAME, false, (in) -> TransportResponse.Empty.INSTANCE); transportService.registerRequestHandler(DFS_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new, @@ -319,14 +319,14 @@ public static void registerRequestHandler(TransportService transportService, Sea new ChannelActionListener<>(channel, DFS_ACTION_NAME, request)) ); - TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, DfsSearchResult::new); + TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, true, DfsSearchResult::new); transportService.registerRequestHandler(QUERY_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new, (request, channel, task) -> { searchService.executeQueryPhase(request, keepStatesInContext(channel.getVersion()), (SearchShardTask) task, new ChannelActionListener<>(channel, QUERY_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyActionWithDynamicResponseType(transportService, QUERY_ACTION_NAME, + TransportActionProxy.registerProxyActionWithDynamicResponseType(transportService, QUERY_ACTION_NAME, true, (request) -> ((ShardSearchRequest)request).numberOfShards() == 1 ? QueryFetchSearchResult::new : QuerySearchResult::new); transportService.registerRequestHandler(QUERY_ID_ACTION_NAME, ThreadPool.Names.SAME, QuerySearchRequest::new, @@ -334,42 +334,42 @@ public static void registerRequestHandler(TransportService transportService, Sea searchService.executeQueryPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel, QUERY_ID_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, QUERY_ID_ACTION_NAME, QuerySearchResult::new); + TransportActionProxy.registerProxyAction(transportService, QUERY_ID_ACTION_NAME, true, QuerySearchResult::new); transportService.registerRequestHandler(QUERY_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, InternalScrollSearchRequest::new, (request, channel, task) -> { searchService.executeQueryPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel, QUERY_SCROLL_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, QUERY_SCROLL_ACTION_NAME, ScrollQuerySearchResult::new); + TransportActionProxy.registerProxyAction(transportService, QUERY_SCROLL_ACTION_NAME, true, ScrollQuerySearchResult::new); transportService.registerRequestHandler(QUERY_FETCH_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, InternalScrollSearchRequest::new, (request, channel, task) -> { searchService.executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel, QUERY_FETCH_SCROLL_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, ScrollQueryFetchSearchResult::new); + TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, true, ScrollQueryFetchSearchResult::new); transportService.registerRequestHandler(FETCH_ID_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, ShardFetchRequest::new, (request, channel, task) -> { searchService.executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel, FETCH_ID_SCROLL_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, FETCH_ID_SCROLL_ACTION_NAME, FetchSearchResult::new); + TransportActionProxy.registerProxyAction(transportService, FETCH_ID_SCROLL_ACTION_NAME, true, FetchSearchResult::new); transportService.registerRequestHandler(FETCH_ID_ACTION_NAME, ThreadPool.Names.SAME, true, true, ShardFetchSearchRequest::new, (request, channel, task) -> { searchService.executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel, FETCH_ID_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, FETCH_ID_ACTION_NAME, FetchSearchResult::new); + TransportActionProxy.registerProxyAction(transportService, FETCH_ID_ACTION_NAME, true, FetchSearchResult::new); // this is cheap, it does not fetch during the rewrite phase, so we can let it quickly execute on a networking thread transportService.registerRequestHandler(QUERY_CAN_MATCH_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new, (request, channel, task) -> { searchService.canMatch(request, new ChannelActionListener<>(channel, QUERY_CAN_MATCH_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NAME, SearchService.CanMatchResponse::new); + TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NAME, true, SearchService.CanMatchResponse::new); } diff --git a/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java b/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java index beb5cfb8084d1..c8cf82ee31c52 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java @@ -22,11 +22,15 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.Map; import java.util.function.Function; /** @@ -55,9 +59,21 @@ private static class ProxyRequestHandler implements Tran public void messageReceived(T request, TransportChannel channel, Task task) throws Exception { DiscoveryNode targetNode = request.targetNode; TransportRequest wrappedRequest = request.wrapped; + assert assertConsistentTaskType(task, wrappedRequest); + TaskId taskId = task.taskInfo(service.localNode.getId(), false).getTaskId(); + wrappedRequest.setParentTask(taskId); service.sendRequest(targetNode, action, wrappedRequest, new ProxyResponseHandler<>(channel, responseFunction.apply(wrappedRequest))); } + + private boolean assertConsistentTaskType(Task proxyTask, TransportRequest wrapped) { + final Task targetTask = + wrapped.createTask(0, proxyTask.getType(), proxyTask.getAction(), TaskId.EMPTY_TASK_ID, Collections.emptyMap()); + assert targetTask instanceof CancellableTask == proxyTask instanceof CancellableTask : + "Cancellable property of proxy action [" + proxyTask.getAction() + "] is configured inconsistently: " + + "expected [" + (targetTask instanceof CancellableTask) + "] actual [" + (proxyTask instanceof CancellableTask) + "]"; + return true; + } } private static class ProxyResponseHandler implements TransportResponseHandler { @@ -117,27 +133,54 @@ public void writeTo(StreamOutput out) throws IOException { } } + private static class CancellableProxyRequest extends ProxyRequest { + CancellableProxyRequest(StreamInput in, Writeable.Reader reader) throws IOException { + super(in, reader); + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, "", parentTaskId, headers) { + @Override + public boolean shouldCancelChildrenOnCancellation() { + return true; + } + + @Override + public String getDescription() { + return "proxy task [" + wrapped.getDescription() + "]"; + } + }; + } + } + /** * Registers a proxy request handler that allows to forward requests for the given action to another node. To be used when the * response type changes based on the upcoming request (quite rare) */ - public static void registerProxyActionWithDynamicResponseType(TransportService service, String action, + public static void registerProxyActionWithDynamicResponseType(TransportService service, String action, boolean cancellable, Function> responseFunction) { RequestHandlerRegistry requestHandler = service.getRequestHandler(action); service.registerRequestHandler(getProxyAction(action), ThreadPool.Names.SAME, true, false, - in -> new ProxyRequest<>(in, requestHandler::newRequest), new ProxyRequestHandler<>(service, action, responseFunction)); + in -> cancellable ? + new CancellableProxyRequest<>(in, requestHandler::newRequest) : + new ProxyRequest<>(in, requestHandler::newRequest), + new ProxyRequestHandler<>(service, action, responseFunction)); } /** * Registers a proxy request handler that allows to forward requests for the given action to another node. To be used when the * response type is always the same (most of the cases). */ - public static void registerProxyAction(TransportService service, String action, + public static void registerProxyAction(TransportService service, String action, boolean cancellable, Writeable.Reader reader) { RequestHandlerRegistry requestHandler = service.getRequestHandler(action); service.registerRequestHandler(getProxyAction(action), ThreadPool.Names.SAME, true, false, - in -> new ProxyRequest<>(in, requestHandler::newRequest), new ProxyRequestHandler<>(service, action, request -> reader)); + in -> cancellable ? + new CancellableProxyRequest<>(in, requestHandler::newRequest) : + new ProxyRequest<>(in, requestHandler::newRequest), + new ProxyRequestHandler<>(service, action, request -> reader)); } private static final String PROXY_ACTION_PREFIX = "internal:transport/proxy/"; diff --git a/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java b/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java index 6c3011b7a6695..9a1f0dae00598 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java @@ -26,6 +26,9 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.internal.io.IOUtils; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.TestThreadPool; @@ -33,8 +36,11 @@ import org.junit.Before; import java.io.IOException; +import java.util.Map; import java.util.concurrent.CountDownLatch; +import static org.hamcrest.Matchers.equalTo; + public class TransportActionProxyTests extends ESTestCase { protected ThreadPool threadPool; // we use always a non-alpha or beta version here otherwise minimumCompatibilityVersion will be different for the two used versions @@ -89,28 +95,32 @@ public void testSendMessage() throws InterruptedException { SimpleTestResponse response = new SimpleTestResponse("TS_A"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceA, "internal:test", SimpleTestResponse::new); + final boolean cancellable = randomBoolean(); + TransportActionProxy.registerProxyAction(serviceA, "internal:test", cancellable, SimpleTestResponse::new); serviceA.connectToNode(nodeB); serviceB.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> { + assertThat(task instanceof CancellableTask, equalTo(cancellable)); assertEquals(request.sourceNode, "TS_A"); SimpleTestResponse response = new SimpleTestResponse("TS_B"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceB, "internal:test", SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceB, "internal:test", cancellable, SimpleTestResponse::new); serviceB.connectToNode(nodeC); serviceC.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> { + assertThat(task instanceof CancellableTask, equalTo(cancellable)); assertEquals(request.sourceNode, "TS_A"); SimpleTestResponse response = new SimpleTestResponse("TS_C"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceC, "internal:test", SimpleTestResponse::new); + + TransportActionProxy.registerProxyAction(serviceC, "internal:test", cancellable, SimpleTestResponse::new); CountDownLatch latch = new CountDownLatch(1); serviceA.sendRequest(nodeB, TransportActionProxy.getProxyAction("internal:test"), TransportActionProxy.wrapRequest(nodeC, - new SimpleTestRequest("TS_A")), new TransportResponseHandler() { + new SimpleTestRequest("TS_A", cancellable)), new TransportResponseHandler() { @Override public SimpleTestResponse read(StreamInput in) throws IOException { return new SimpleTestResponse(in); @@ -138,13 +148,14 @@ public void handleException(TransportException exp) { } public void testException() throws InterruptedException { + boolean cancellable = randomBoolean(); serviceA.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> { assertEquals(request.sourceNode, "TS_A"); SimpleTestResponse response = new SimpleTestResponse("TS_A"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceA, "internal:test", SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceA, "internal:test", cancellable, SimpleTestResponse::new); serviceA.connectToNode(nodeB); serviceB.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, @@ -153,17 +164,17 @@ public void testException() throws InterruptedException { SimpleTestResponse response = new SimpleTestResponse("TS_B"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceB, "internal:test", SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceB, "internal:test", cancellable, SimpleTestResponse::new); serviceB.connectToNode(nodeC); serviceC.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> { throw new ElasticsearchException("greetings from TS_C"); }); - TransportActionProxy.registerProxyAction(serviceC, "internal:test", SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceC, "internal:test", cancellable, SimpleTestResponse::new); CountDownLatch latch = new CountDownLatch(1); serviceA.sendRequest(nodeB, TransportActionProxy.getProxyAction("internal:test"), TransportActionProxy.wrapRequest(nodeC, - new SimpleTestRequest("TS_A")), new TransportResponseHandler() { + new SimpleTestRequest("TS_A", cancellable)), new TransportResponseHandler() { @Override public SimpleTestResponse read(StreamInput in) throws IOException { return new SimpleTestResponse(in); @@ -192,22 +203,39 @@ public void handleException(TransportException exp) { } public static class SimpleTestRequest extends TransportRequest { - String sourceNode; + final boolean cancellable; + final String sourceNode; - public SimpleTestRequest(String sourceNode) { + public SimpleTestRequest(String sourceNode, boolean cancellable) { this.sourceNode = sourceNode; + this.cancellable = cancellable; } - public SimpleTestRequest() {} public SimpleTestRequest(StreamInput in) throws IOException { super(in); sourceNode = in.readString(); + cancellable = in.readBoolean(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(sourceNode); + out.writeBoolean(cancellable); + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + if (cancellable) { + return new CancellableTask(id, type, action, "", parentTaskId, headers) { + @Override + public boolean shouldCancelChildrenOnCancellation() { + return randomBoolean(); + } + }; + } else { + return super.createTask(id, type, action, parentTaskId, headers); + } } } diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/ClearCcrRestoreSessionAction.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/ClearCcrRestoreSessionAction.java index e0ed85883df93..7484fc92e81a6 100644 --- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/ClearCcrRestoreSessionAction.java +++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/ClearCcrRestoreSessionAction.java @@ -36,7 +36,7 @@ public static class TransportDeleteCcrRestoreSessionAction public TransportDeleteCcrRestoreSessionAction(ActionFilters actionFilters, TransportService transportService, CcrRestoreSourceService ccrRestoreService) { super(NAME, transportService, actionFilters, ClearCcrRestoreSessionRequest::new, ThreadPool.Names.GENERIC); - TransportActionProxy.registerProxyAction(transportService, NAME, in -> ActionResponse.Empty.INSTANCE); + TransportActionProxy.registerProxyAction(transportService, NAME, false, in -> ActionResponse.Empty.INSTANCE); this.ccrRestoreService = ccrRestoreService; } diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/GetCcrRestoreFileChunkAction.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/GetCcrRestoreFileChunkAction.java index b858531d7614f..96f795cfc5f95 100644 --- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/GetCcrRestoreFileChunkAction.java +++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/GetCcrRestoreFileChunkAction.java @@ -45,7 +45,7 @@ public static class TransportGetCcrRestoreFileChunkAction public TransportGetCcrRestoreFileChunkAction(BigArrays bigArrays, TransportService transportService, ActionFilters actionFilters, CcrRestoreSourceService restoreSourceService) { super(NAME, transportService, actionFilters, GetCcrRestoreFileChunkRequest::new, ThreadPool.Names.GENERIC); - TransportActionProxy.registerProxyAction(transportService, NAME, GetCcrRestoreFileChunkResponse::new); + TransportActionProxy.registerProxyAction(transportService, NAME, false, GetCcrRestoreFileChunkResponse::new); this.restoreSourceService = restoreSourceService; this.bigArrays = bigArrays; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/TransportOpenPointInTimeAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/TransportOpenPointInTimeAction.java index 75ae1c01a5b5b..d65c1f3b8d256 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/TransportOpenPointInTimeAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/TransportOpenPointInTimeAction.java @@ -61,6 +61,7 @@ public TransportOpenPointInTimeAction( TransportActionProxy.registerProxyAction( transportService, OPEN_SHARD_READER_CONTEXT_NAME, + false, TransportOpenPointInTimeAction.ShardOpenReaderResponse::new ); }