diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java index fdeab35a158c8..a1cdd1a5824d7 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java @@ -32,7 +32,12 @@ import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.test.AbstractMultiClustersTestCase; +import org.elasticsearch.test.InternalTestCluster; +import org.elasticsearch.test.NodeRoles; +import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; +import org.elasticsearch.transport.TransportService; import org.junit.Before; import java.util.Collection; @@ -82,6 +87,16 @@ public void testProxyConnectionDisconnect() throws Exception { client(LOCAL_CLUSTER).search(searchRequest, future); SearchListenerPlugin.waitSearchStarted(); disconnectFromRemoteClusters(); + // Cancellable tasks on the remote cluster should be cancelled + assertBusy(() -> { + final Iterable transportServices = cluster("cluster_a").getInstances(TransportService.class); + for (TransportService transportService : transportServices) { + Collection cancellableTasks = transportService.getTaskManager().getCancellableTasks().values(); + for (CancellableTask cancellableTask : cancellableTasks) { + assertTrue(cancellableTask.getDescription(), cancellableTask.isCancelled()); + } + } + }); assertBusy(() -> assertTrue(future.isDone())); configureAndConnectsToRemoteClusters(); } finally { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java index fb5980dda03b7..b069ad5d54664 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java @@ -297,20 +297,20 @@ public static void registerRequestHandler(TransportService transportService, Sea boolean freed = searchService.freeReaderContext(request.id()); channel.sendResponse(new SearchFreeContextResponse(freed)); }); - TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_SCROLL_ACTION_NAME, SearchFreeContextResponse::new); + TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_SCROLL_ACTION_NAME, false, SearchFreeContextResponse::new); transportService.registerRequestHandler(FREE_CONTEXT_ACTION_NAME, ThreadPool.Names.SAME, SearchFreeContextRequest::new, (request, channel, task) -> { boolean freed = searchService.freeReaderContext(request.id()); channel.sendResponse(new SearchFreeContextResponse(freed)); }); - TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_ACTION_NAME, SearchFreeContextResponse::new); + TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_ACTION_NAME, false, SearchFreeContextResponse::new); transportService.registerRequestHandler(CLEAR_SCROLL_CONTEXTS_ACTION_NAME, ThreadPool.Names.SAME, TransportRequest.Empty::new, (request, channel, task) -> { searchService.freeAllScrollContexts(); channel.sendResponse(TransportResponse.Empty.INSTANCE); }); - TransportActionProxy.registerProxyAction(transportService, CLEAR_SCROLL_CONTEXTS_ACTION_NAME, + TransportActionProxy.registerProxyAction(transportService, CLEAR_SCROLL_CONTEXTS_ACTION_NAME, false, (in) -> TransportResponse.Empty.INSTANCE); transportService.registerRequestHandler(DFS_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new, @@ -319,14 +319,14 @@ public static void registerRequestHandler(TransportService transportService, Sea new ChannelActionListener<>(channel, DFS_ACTION_NAME, request)) ); - TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, DfsSearchResult::new); + TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, true, DfsSearchResult::new); transportService.registerRequestHandler(QUERY_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new, (request, channel, task) -> { searchService.executeQueryPhase(request, keepStatesInContext(channel.getVersion()), (SearchShardTask) task, new ChannelActionListener<>(channel, QUERY_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyActionWithDynamicResponseType(transportService, QUERY_ACTION_NAME, + TransportActionProxy.registerProxyActionWithDynamicResponseType(transportService, QUERY_ACTION_NAME, true, (request) -> ((ShardSearchRequest)request).numberOfShards() == 1 ? QueryFetchSearchResult::new : QuerySearchResult::new); transportService.registerRequestHandler(QUERY_ID_ACTION_NAME, ThreadPool.Names.SAME, QuerySearchRequest::new, @@ -334,42 +334,42 @@ public static void registerRequestHandler(TransportService transportService, Sea searchService.executeQueryPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel, QUERY_ID_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, QUERY_ID_ACTION_NAME, QuerySearchResult::new); + TransportActionProxy.registerProxyAction(transportService, QUERY_ID_ACTION_NAME, true, QuerySearchResult::new); transportService.registerRequestHandler(QUERY_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, InternalScrollSearchRequest::new, (request, channel, task) -> { searchService.executeQueryPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel, QUERY_SCROLL_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, QUERY_SCROLL_ACTION_NAME, ScrollQuerySearchResult::new); + TransportActionProxy.registerProxyAction(transportService, QUERY_SCROLL_ACTION_NAME, true, ScrollQuerySearchResult::new); transportService.registerRequestHandler(QUERY_FETCH_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, InternalScrollSearchRequest::new, (request, channel, task) -> { searchService.executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel, QUERY_FETCH_SCROLL_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, ScrollQueryFetchSearchResult::new); + TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, true, ScrollQueryFetchSearchResult::new); transportService.registerRequestHandler(FETCH_ID_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, ShardFetchRequest::new, (request, channel, task) -> { searchService.executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel, FETCH_ID_SCROLL_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, FETCH_ID_SCROLL_ACTION_NAME, FetchSearchResult::new); + TransportActionProxy.registerProxyAction(transportService, FETCH_ID_SCROLL_ACTION_NAME, true, FetchSearchResult::new); transportService.registerRequestHandler(FETCH_ID_ACTION_NAME, ThreadPool.Names.SAME, true, true, ShardFetchSearchRequest::new, (request, channel, task) -> { searchService.executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel, FETCH_ID_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, FETCH_ID_ACTION_NAME, FetchSearchResult::new); + TransportActionProxy.registerProxyAction(transportService, FETCH_ID_ACTION_NAME, true, FetchSearchResult::new); // this is cheap, it does not fetch during the rewrite phase, so we can let it quickly execute on a networking thread transportService.registerRequestHandler(QUERY_CAN_MATCH_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new, (request, channel, task) -> { searchService.canMatch(request, new ChannelActionListener<>(channel, QUERY_CAN_MATCH_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NAME, SearchService.CanMatchResponse::new); + TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NAME, true, SearchService.CanMatchResponse::new); } diff --git a/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java b/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java index beb5cfb8084d1..c8cf82ee31c52 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java @@ -22,11 +22,15 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.Map; import java.util.function.Function; /** @@ -55,9 +59,21 @@ private static class ProxyRequestHandler implements Tran public void messageReceived(T request, TransportChannel channel, Task task) throws Exception { DiscoveryNode targetNode = request.targetNode; TransportRequest wrappedRequest = request.wrapped; + assert assertConsistentTaskType(task, wrappedRequest); + TaskId taskId = task.taskInfo(service.localNode.getId(), false).getTaskId(); + wrappedRequest.setParentTask(taskId); service.sendRequest(targetNode, action, wrappedRequest, new ProxyResponseHandler<>(channel, responseFunction.apply(wrappedRequest))); } + + private boolean assertConsistentTaskType(Task proxyTask, TransportRequest wrapped) { + final Task targetTask = + wrapped.createTask(0, proxyTask.getType(), proxyTask.getAction(), TaskId.EMPTY_TASK_ID, Collections.emptyMap()); + assert targetTask instanceof CancellableTask == proxyTask instanceof CancellableTask : + "Cancellable property of proxy action [" + proxyTask.getAction() + "] is configured inconsistently: " + + "expected [" + (targetTask instanceof CancellableTask) + "] actual [" + (proxyTask instanceof CancellableTask) + "]"; + return true; + } } private static class ProxyResponseHandler implements TransportResponseHandler { @@ -117,27 +133,54 @@ public void writeTo(StreamOutput out) throws IOException { } } + private static class CancellableProxyRequest extends ProxyRequest { + CancellableProxyRequest(StreamInput in, Writeable.Reader reader) throws IOException { + super(in, reader); + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, "", parentTaskId, headers) { + @Override + public boolean shouldCancelChildrenOnCancellation() { + return true; + } + + @Override + public String getDescription() { + return "proxy task [" + wrapped.getDescription() + "]"; + } + }; + } + } + /** * Registers a proxy request handler that allows to forward requests for the given action to another node. To be used when the * response type changes based on the upcoming request (quite rare) */ - public static void registerProxyActionWithDynamicResponseType(TransportService service, String action, + public static void registerProxyActionWithDynamicResponseType(TransportService service, String action, boolean cancellable, Function> responseFunction) { RequestHandlerRegistry requestHandler = service.getRequestHandler(action); service.registerRequestHandler(getProxyAction(action), ThreadPool.Names.SAME, true, false, - in -> new ProxyRequest<>(in, requestHandler::newRequest), new ProxyRequestHandler<>(service, action, responseFunction)); + in -> cancellable ? + new CancellableProxyRequest<>(in, requestHandler::newRequest) : + new ProxyRequest<>(in, requestHandler::newRequest), + new ProxyRequestHandler<>(service, action, responseFunction)); } /** * Registers a proxy request handler that allows to forward requests for the given action to another node. To be used when the * response type is always the same (most of the cases). */ - public static void registerProxyAction(TransportService service, String action, + public static void registerProxyAction(TransportService service, String action, boolean cancellable, Writeable.Reader reader) { RequestHandlerRegistry requestHandler = service.getRequestHandler(action); service.registerRequestHandler(getProxyAction(action), ThreadPool.Names.SAME, true, false, - in -> new ProxyRequest<>(in, requestHandler::newRequest), new ProxyRequestHandler<>(service, action, request -> reader)); + in -> cancellable ? + new CancellableProxyRequest<>(in, requestHandler::newRequest) : + new ProxyRequest<>(in, requestHandler::newRequest), + new ProxyRequestHandler<>(service, action, request -> reader)); } private static final String PROXY_ACTION_PREFIX = "internal:transport/proxy/"; diff --git a/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java b/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java index 6c3011b7a6695..9a1f0dae00598 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java @@ -26,6 +26,9 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.internal.io.IOUtils; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.TestThreadPool; @@ -33,8 +36,11 @@ import org.junit.Before; import java.io.IOException; +import java.util.Map; import java.util.concurrent.CountDownLatch; +import static org.hamcrest.Matchers.equalTo; + public class TransportActionProxyTests extends ESTestCase { protected ThreadPool threadPool; // we use always a non-alpha or beta version here otherwise minimumCompatibilityVersion will be different for the two used versions @@ -89,28 +95,32 @@ public void testSendMessage() throws InterruptedException { SimpleTestResponse response = new SimpleTestResponse("TS_A"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceA, "internal:test", SimpleTestResponse::new); + final boolean cancellable = randomBoolean(); + TransportActionProxy.registerProxyAction(serviceA, "internal:test", cancellable, SimpleTestResponse::new); serviceA.connectToNode(nodeB); serviceB.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> { + assertThat(task instanceof CancellableTask, equalTo(cancellable)); assertEquals(request.sourceNode, "TS_A"); SimpleTestResponse response = new SimpleTestResponse("TS_B"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceB, "internal:test", SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceB, "internal:test", cancellable, SimpleTestResponse::new); serviceB.connectToNode(nodeC); serviceC.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> { + assertThat(task instanceof CancellableTask, equalTo(cancellable)); assertEquals(request.sourceNode, "TS_A"); SimpleTestResponse response = new SimpleTestResponse("TS_C"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceC, "internal:test", SimpleTestResponse::new); + + TransportActionProxy.registerProxyAction(serviceC, "internal:test", cancellable, SimpleTestResponse::new); CountDownLatch latch = new CountDownLatch(1); serviceA.sendRequest(nodeB, TransportActionProxy.getProxyAction("internal:test"), TransportActionProxy.wrapRequest(nodeC, - new SimpleTestRequest("TS_A")), new TransportResponseHandler() { + new SimpleTestRequest("TS_A", cancellable)), new TransportResponseHandler() { @Override public SimpleTestResponse read(StreamInput in) throws IOException { return new SimpleTestResponse(in); @@ -138,13 +148,14 @@ public void handleException(TransportException exp) { } public void testException() throws InterruptedException { + boolean cancellable = randomBoolean(); serviceA.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> { assertEquals(request.sourceNode, "TS_A"); SimpleTestResponse response = new SimpleTestResponse("TS_A"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceA, "internal:test", SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceA, "internal:test", cancellable, SimpleTestResponse::new); serviceA.connectToNode(nodeB); serviceB.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, @@ -153,17 +164,17 @@ public void testException() throws InterruptedException { SimpleTestResponse response = new SimpleTestResponse("TS_B"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceB, "internal:test", SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceB, "internal:test", cancellable, SimpleTestResponse::new); serviceB.connectToNode(nodeC); serviceC.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> { throw new ElasticsearchException("greetings from TS_C"); }); - TransportActionProxy.registerProxyAction(serviceC, "internal:test", SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceC, "internal:test", cancellable, SimpleTestResponse::new); CountDownLatch latch = new CountDownLatch(1); serviceA.sendRequest(nodeB, TransportActionProxy.getProxyAction("internal:test"), TransportActionProxy.wrapRequest(nodeC, - new SimpleTestRequest("TS_A")), new TransportResponseHandler() { + new SimpleTestRequest("TS_A", cancellable)), new TransportResponseHandler() { @Override public SimpleTestResponse read(StreamInput in) throws IOException { return new SimpleTestResponse(in); @@ -192,22 +203,39 @@ public void handleException(TransportException exp) { } public static class SimpleTestRequest extends TransportRequest { - String sourceNode; + final boolean cancellable; + final String sourceNode; - public SimpleTestRequest(String sourceNode) { + public SimpleTestRequest(String sourceNode, boolean cancellable) { this.sourceNode = sourceNode; + this.cancellable = cancellable; } - public SimpleTestRequest() {} public SimpleTestRequest(StreamInput in) throws IOException { super(in); sourceNode = in.readString(); + cancellable = in.readBoolean(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(sourceNode); + out.writeBoolean(cancellable); + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + if (cancellable) { + return new CancellableTask(id, type, action, "", parentTaskId, headers) { + @Override + public boolean shouldCancelChildrenOnCancellation() { + return randomBoolean(); + } + }; + } else { + return super.createTask(id, type, action, parentTaskId, headers); + } } } diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/ClearCcrRestoreSessionAction.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/ClearCcrRestoreSessionAction.java index e0ed85883df93..7484fc92e81a6 100644 --- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/ClearCcrRestoreSessionAction.java +++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/ClearCcrRestoreSessionAction.java @@ -36,7 +36,7 @@ public static class TransportDeleteCcrRestoreSessionAction public TransportDeleteCcrRestoreSessionAction(ActionFilters actionFilters, TransportService transportService, CcrRestoreSourceService ccrRestoreService) { super(NAME, transportService, actionFilters, ClearCcrRestoreSessionRequest::new, ThreadPool.Names.GENERIC); - TransportActionProxy.registerProxyAction(transportService, NAME, in -> ActionResponse.Empty.INSTANCE); + TransportActionProxy.registerProxyAction(transportService, NAME, false, in -> ActionResponse.Empty.INSTANCE); this.ccrRestoreService = ccrRestoreService; } diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/GetCcrRestoreFileChunkAction.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/GetCcrRestoreFileChunkAction.java index b858531d7614f..96f795cfc5f95 100644 --- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/GetCcrRestoreFileChunkAction.java +++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/GetCcrRestoreFileChunkAction.java @@ -45,7 +45,7 @@ public static class TransportGetCcrRestoreFileChunkAction public TransportGetCcrRestoreFileChunkAction(BigArrays bigArrays, TransportService transportService, ActionFilters actionFilters, CcrRestoreSourceService restoreSourceService) { super(NAME, transportService, actionFilters, GetCcrRestoreFileChunkRequest::new, ThreadPool.Names.GENERIC); - TransportActionProxy.registerProxyAction(transportService, NAME, GetCcrRestoreFileChunkResponse::new); + TransportActionProxy.registerProxyAction(transportService, NAME, false, GetCcrRestoreFileChunkResponse::new); this.restoreSourceService = restoreSourceService; this.bigArrays = bigArrays; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/TransportOpenPointInTimeAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/TransportOpenPointInTimeAction.java index 75ae1c01a5b5b..d65c1f3b8d256 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/TransportOpenPointInTimeAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/TransportOpenPointInTimeAction.java @@ -61,6 +61,7 @@ public TransportOpenPointInTimeAction( TransportActionProxy.registerProxyAction( transportService, OPEN_SHARD_READER_CONTEXT_NAME, + false, TransportOpenPointInTimeAction.ShardOpenReaderResponse::new ); }