From 5b04545cacb4d243bc567fc5e0029c0ceb847255 Mon Sep 17 00:00:00 2001 From: Iraklis Psaroudakis Date: Thu, 22 Dec 2022 11:50:33 +0200 Subject: [PATCH] Child requests proactively cancel children tasks To make this possible we modify the CancellableTasksTracker to track children tasks by the Request ID as well. That way, we can send an Action to cancel a child based on the parent task and the Request ID. This is especially useful when a parent's child requests time out on the parent's side. Fixes #90353 Relates #66992 --- docs/changelog/92588.yaml | 6 + .../node/tasks/CancellableTasksIT.java | 108 ++++++++++++----- .../TransportReplicationAction.java | 10 ++ .../cluster/service/MasterService.java | 3 + .../PersistentTasksNodeService.java | 5 + .../tasks/CancellableTasksTracker.java | 112 ++++++++++++------ .../elasticsearch/tasks/TaskAwareRequest.java | 12 ++ .../tasks/TaskCancellationService.java | 69 +++++++++++ .../org/elasticsearch/tasks/TaskManager.java | 53 +++++++-- .../transport/InboundHandler.java | 2 + .../transport/TransportRequest.java | 18 +++ .../transport/TransportService.java | 19 ++- .../tasks/BanFailureLoggingTests.java | 5 + .../tasks/CancellableTasksTrackerTests.java | 10 +- .../elasticsearch/tasks/TaskManagerTests.java | 9 ++ ...ortServiceDeserializationFailureTests.java | 5 + .../action/InternalExecutePolicyAction.java | 5 + .../xpack/enrich/EnrichPolicyRunnerTests.java | 6 + .../TrainedModelAssignmentNodeService.java | 5 + .../InferencePyTorchActionTests.java | 3 + .../ql/async/AsyncTaskManagementService.java | 10 ++ .../rest-api-spec/test/10_analyze.yml | 4 +- .../blobstore/testkit/BlobAnalyzeAction.java | 22 ++-- .../testkit/GetBlobChecksumAction.java | 3 +- 24 files changed, 413 insertions(+), 91 deletions(-) create mode 100644 docs/changelog/92588.yaml diff --git a/docs/changelog/92588.yaml b/docs/changelog/92588.yaml new file mode 100644 index 0000000000000..0447207b398b7 --- /dev/null +++ b/docs/changelog/92588.yaml @@ -0,0 +1,6 @@ +pr: 
92588 +summary: Failed tasks proactively cancel children tasks +area: Snapshot/Restore +type: enhancement +issues: + - 90353 diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java index 3bcd7626a5f02..fa46558eb1ada 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java @@ -34,6 +34,7 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.tasks.CancellableTask; @@ -44,8 +45,11 @@ import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.threadpool.ThreadPoolStats; +import org.elasticsearch.transport.ReceiveTimeoutTransportException; import org.elasticsearch.transport.SendRequestTransportException; import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportService; import org.junit.After; @@ -63,6 +67,7 @@ import java.util.stream.Collectors; import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsStringIgnoringCase; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -77,24 +82,25 @@ public class CancellableTasksIT extends ESIntegTestCase { static final Map completedLatches = ConcurrentCollections.newConcurrentMap(); 
@After - public void ensureAllBansRemoved() throws Exception { + public void ensureBansAndCancellationsConsistency() throws Exception { assertBusy(() -> { for (String node : internalCluster().getNodeNames()) { TaskManager taskManager = internalCluster().getInstance(TransportService.class, node).getTaskManager(); assertThat("node " + node, taskManager.getBannedTaskIds(), empty()); + assertThat("node " + node, taskManager.assertCancellableTaskConsistency(), equalTo(true)); } }, 30, TimeUnit.SECONDS); } - static TestRequest generateTestRequest(Set nodes, int level, int maxLevel) { + static TestRequest generateTestRequest(Set nodes, int level, int maxLevel, boolean timeout) { List subRequests = new ArrayList<>(); int lower = level == 0 ? 1 : 0; int upper = 10 / (level + 1); int numOfSubRequests = randomIntBetween(lower, upper); for (int i = 0; i < numOfSubRequests && level <= maxLevel; i++) { - subRequests.add(generateTestRequest(nodes, level + 1, maxLevel)); + subRequests.add(generateTestRequest(nodes, level + 1, maxLevel, timeout)); } - final TestRequest request = new TestRequest(idGenerator++, randomFrom(nodes), subRequests); + final TestRequest request = new TestRequest(idGenerator++, randomFrom(nodes), subRequests, level == 0 ? 
false : timeout); beforeSendLatches.put(request, new CountDownLatch(1)); arrivedLatches.put(request, new CountDownLatch(1)); beforeExecuteLatches.put(request, new CountDownLatch(1)); @@ -157,7 +163,7 @@ public void testBanOnlyNodesWithOutstandingDescendantTasks() throws Exception { internalCluster().startNodes(randomIntBetween(1, 3)); } Set nodes = clusterService().state().nodes().stream().collect(Collectors.toSet()); - final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4)); + final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4), false); ActionFuture rootTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest); Set pendingRequests = allowPartialRequest(rootRequest); TaskId rootTaskId = getRootTaskId(rootRequest); @@ -203,14 +209,14 @@ public void testBanOnlyNodesWithOutstandingDescendantTasks() throws Exception { } finally { allowEntireRequest(rootRequest); cancelFuture.actionGet(); - waitForRootTask(rootTaskFuture); - ensureAllBansRemoved(); + waitForRootTask(rootTaskFuture, false); + ensureBansAndCancellationsConsistency(); } } public void testCancelTaskMultipleTimes() throws Exception { Set nodes = clusterService().state().nodes().stream().collect(Collectors.toSet()); - TestRequest rootRequest = generateTestRequest(nodes, 0, randomIntBetween(1, 3)); + TestRequest rootRequest = generateTestRequest(nodes, 0, randomIntBetween(1, 3), false); ActionFuture mainTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest); TaskId taskId = getRootTaskId(rootRequest); allowPartialRequest(rootRequest); @@ -227,7 +233,7 @@ public void testCancelTaskMultipleTimes() throws Exception { allowEntireRequest(rootRequest); assertThat(cancelFuture.actionGet().getTaskFailures(), empty()); assertThat(cancelFuture.actionGet().getTaskFailures(), empty()); - waitForRootTask(mainTaskFuture); + waitForRootTask(mainTaskFuture, false); CancelTasksResponse cancelError = client().admin() .cluster() .prepareCancelTasks() @@ 
-237,12 +243,12 @@ public void testCancelTaskMultipleTimes() throws Exception { assertThat(cancelError.getNodeFailures(), hasSize(1)); final Throwable notFound = ExceptionsHelper.unwrap(cancelError.getNodeFailures().get(0), ResourceNotFoundException.class); assertThat(notFound.getMessage(), equalTo("task [" + taskId + "] is not found")); - ensureAllBansRemoved(); + ensureBansAndCancellationsConsistency(); } public void testDoNotWaitForCompletion() throws Exception { Set nodes = clusterService().state().nodes().stream().collect(Collectors.toSet()); - TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3)); + TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3), false); ActionFuture mainTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest); TaskId taskId = getRootTaskId(rootRequest); if (randomBoolean()) { @@ -261,34 +267,34 @@ public void testDoNotWaitForCompletion() throws Exception { assertBusy(() -> assertTrue(cancelFuture.isDone())); } allowEntireRequest(rootRequest); - waitForRootTask(mainTaskFuture); + waitForRootTask(mainTaskFuture, false); cancelFuture.actionGet(); - ensureAllBansRemoved(); + ensureBansAndCancellationsConsistency(); } public void testFailedToStartChildTaskAfterCancelled() throws Exception { Set nodes = clusterService().state().nodes().stream().collect(Collectors.toSet()); - TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3)); + TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3), false); ActionFuture rootTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest); TaskId taskId = getRootTaskId(rootRequest); client().admin().cluster().prepareCancelTasks().setTargetTaskId(taskId).waitForCompletion(false).get(); DiscoveryNode nodeWithParentTask = nodes.stream().filter(n -> n.getId().equals(taskId.getNodeId())).findFirst().get(); TransportTestAction mainAction = internalCluster().getInstance(TransportTestAction.class, nodeWithParentTask.getName()); 
PlainActionFuture future = new PlainActionFuture<>(); - TestRequest subRequest = generateTestRequest(nodes, 0, between(0, 1)); + TestRequest subRequest = generateTestRequest(nodes, 0, between(0, 1), false); beforeSendLatches.get(subRequest).countDown(); mainAction.startSubTask(taskId, subRequest, future); TaskCancelledException te = expectThrows(TaskCancelledException.class, future::actionGet); assertThat(te.getMessage(), equalTo("parent task was cancelled [by user request]")); allowEntireRequest(rootRequest); - waitForRootTask(rootTaskFuture); - ensureAllBansRemoved(); + waitForRootTask(rootTaskFuture, false); + ensureBansAndCancellationsConsistency(); } public void testCancelOrphanedTasks() throws Exception { final String nodeWithRootTask = internalCluster().startDataOnlyNode(); Set nodes = clusterService().state().nodes().stream().collect(Collectors.toSet()); - TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3)); + TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3), false); client(nodeWithRootTask).execute(TransportTestAction.ACTION, rootRequest); allowPartialRequest(rootRequest); try { @@ -307,13 +313,13 @@ public void testCancelOrphanedTasks() throws Exception { }, 30, TimeUnit.SECONDS); } finally { allowEntireRequest(rootRequest); - ensureAllBansRemoved(); + ensureBansAndCancellationsConsistency(); } } public void testRemoveBanParentsOnDisconnect() throws Exception { Set nodes = clusterService().state().nodes().stream().collect(Collectors.toSet()); - final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4)); + final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4), false); client().execute(TransportTestAction.ACTION, rootRequest); Set pendingRequests = allowPartialRequest(rootRequest); TaskId rootTaskId = getRootTaskId(rootRequest); @@ -367,10 +373,28 @@ public void testRemoveBanParentsOnDisconnect() throws Exception { } finally { allowEntireRequest(rootRequest); 
cancelFuture.actionGet(); - ensureAllBansRemoved(); + ensureBansAndCancellationsConsistency(); } } + public void testChildrenTasksCancelledOnTimeout() throws Exception { + Set nodes = clusterService().state().nodes().stream().collect(Collectors.toSet()); + final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4), true); + ActionFuture rootTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest); + allowEntireRequest(rootRequest); + waitForRootTask(rootTaskFuture, true); + assertBusy(() -> { + for (DiscoveryNode node : nodes) { + TransportService transportService = internalCluster().getInstance(TransportService.class, node.getName()); + for (ThreadPoolStats.Stats stat : transportService.getThreadPool().stats()) { + assertEquals(0, stat.getActive()); + assertEquals(0, stat.getQueue()); + } + } + }, 60L, TimeUnit.SECONDS); + ensureBansAndCancellationsConsistency(); + } + static TaskId getRootTaskId(TestRequest request) throws Exception { SetOnce taskId = new SetOnce<>(); assertBusy(() -> { @@ -390,19 +414,24 @@ static TaskId getRootTaskId(TestRequest request) throws Exception { return taskId.get(); } - static void waitForRootTask(ActionFuture rootTask) { + static void waitForRootTask(ActionFuture rootTask, boolean timeout) { try { rootTask.actionGet(); } catch (Exception e) { - final Throwable cause = ExceptionsHelper.unwrap(e, TaskCancelledException.class); + final Throwable cause = ExceptionsHelper.unwrap( + e, + timeout ? ReceiveTimeoutTransportException.class : TaskCancelledException.class + ); assertNotNull(cause); assertThat( cause.getMessage(), - anyOf( - equalTo("parent task was cancelled [by user request]"), - equalTo("task cancelled before starting [by user request]"), - equalTo("task cancelled [by user request]") - ) + timeout + ? 
containsStringIgnoringCase("timed out after") + : anyOf( + equalTo("parent task was cancelled [by user request]"), + equalTo("task cancelled before starting [by user request]"), + equalTo("task cancelled [by user request]") + ) ); } } @@ -411,11 +440,13 @@ static class TestRequest extends ActionRequest { final int id; final DiscoveryNode node; final List subRequests; + final boolean timeout; - TestRequest(int id, DiscoveryNode node, List subRequests) { + TestRequest(int id, DiscoveryNode node, List subRequests, boolean timeout) { this.id = id; this.node = node; this.subRequests = subRequests; + this.timeout = timeout; } TestRequest(StreamInput in) throws IOException { @@ -423,6 +454,7 @@ static class TestRequest extends ActionRequest { this.id = in.readInt(); this.node = new DiscoveryNode(in); this.subRequests = in.readList(TestRequest::new); + this.timeout = in.readBoolean(); } List descendants() { @@ -445,6 +477,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInt(id); node.writeTo(out); out.writeList(subRequests); + out.writeBoolean(timeout); } @Override @@ -513,7 +546,16 @@ protected void doExecute(Task task, TestRequest request, ActionListener { assertTrue(beforeExecuteLatches.get(request).await(60, TimeUnit.SECONDS)); - ((CancellableTask) task).ensureNotCancelled(); + if (request.timeout) { + // Simulate working until cancelled + while (((CancellableTask) task).isCancelled() == false) { + try { + Thread.sleep(1); + } catch (InterruptedException e) {} + } + } else { + ((CancellableTask) task).ensureNotCancelled(); + } return new TestResponse(); })); for (TestRequest subRequest : subRequests) { @@ -535,17 +577,21 @@ public void onFailure(Exception e) { @Override protected void doRun() throws Exception { assertTrue(beforeSendLatches.get(subRequest).await(60, TimeUnit.SECONDS)); - if (client.getLocalNodeId().equals(subRequest.node.getId()) && randomBoolean()) { + if (client.getLocalNodeId().equals(subRequest.node.getId()) && 
subRequest.timeout == false && randomBoolean()) { try { client.executeLocally(TransportTestAction.ACTION, subRequest, latchedListener); } catch (TaskCancelledException e) { latchedListener.onFailure(new SendRequestTransportException(subRequest.node, ACTION.name(), e)); } } else { + final TransportRequestOptions transportRequestOptions = subRequest.timeout + ? TransportRequestOptions.timeout(TimeValue.timeValueSeconds(1)) + : TransportRequestOptions.EMPTY; transportService.sendRequest( subRequest.node, ACTION.name(), subRequest, + transportRequestOptions, new ActionListenerResponseHandler(latchedListener, TestResponse::new) ); } diff --git a/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java b/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java index 7e299f5baa256..b55136f52b7cd 100644 --- a/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java @@ -1361,6 +1361,16 @@ public TaskId getParentTask() { return request.getParentTask(); } + @Override + public void setRequestId(long requestId) { + request.setRequestId(requestId); + } + + @Override + public long getRequestId() { + return request.getRequestId(); + } + @Override public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { return request.createTask(id, type, action, parentTaskId, headers); diff --git a/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java b/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java index 06c4df9a87a31..3c4015e88eb52 100644 --- a/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java +++ b/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java @@ -318,6 +318,9 @@ private void publishClusterStateUpdate( @Override public void 
setParentTask(TaskId taskId) {} + @Override + public void setRequestId(long requestId) {} + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; diff --git a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksNodeService.java b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksNodeService.java index dd2bfa489e5b8..42d514ae23b45 100644 --- a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksNodeService.java +++ b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksNodeService.java @@ -175,6 +175,11 @@ public void setParentTask(TaskId taskId) { throw new UnsupportedOperationException("parent task if for persistent tasks shouldn't change"); } + @Override + public void setRequestId(long requestId) { + throw new UnsupportedOperationException("does not have a request ID"); + } + @Override public TaskId getParentTask() { return parentTaskId; diff --git a/server/src/main/java/org/elasticsearch/tasks/CancellableTasksTracker.java b/server/src/main/java/org/elasticsearch/tasks/CancellableTasksTracker.java index c1723e492dde3..a44b653d66fb2 100644 --- a/server/src/main/java/org/elasticsearch/tasks/CancellableTasksTracker.java +++ b/server/src/main/java/org/elasticsearch/tasks/CancellableTasksTracker.java @@ -31,21 +31,45 @@ public CancellableTasksTracker(T[] empty) { } private final Map byTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency(); - private final Map byParentTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency(); + private final Map> byParentTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency(); + + /** + * Gets the cancellable children of a parent task. + * + * Note: children of non-positive request IDs (e.g., -1) may be grouped together. 
+ */ + public Stream getChildrenByRequestId(TaskId parentTaskId, long childRequestId) { + Map byRequestId = byParentTaskId.get(parentTaskId); + if (byRequestId != null) { + T[] children = byRequestId.get(childRequestId); + if (children != null) { + return Arrays.stream(children); + } + } + return Stream.empty(); + } /** * Add an item for the given task. Should only be called once for each task, and {@code item} must be unique per task too. */ - public void put(Task task, T item) { + public void put(Task task, long requestId, T item) { final long taskId = task.getId(); if (task.getParentTaskId().isSet()) { - byParentTaskId.compute(task.getParentTaskId(), (ignored, oldValue) -> { - if (oldValue == null) { - oldValue = empty; + byParentTaskId.compute(task.getParentTaskId(), (taskKey, oldRequestIdMap) -> { + if (oldRequestIdMap == null) { + oldRequestIdMap = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency(); } - final T[] newValue = Arrays.copyOf(oldValue, oldValue.length + 1); - newValue[oldValue.length] = item; - return newValue; + + oldRequestIdMap.compute(requestId, (requestIdKey, oldValue) -> { + if (oldValue == null) { + oldValue = empty; + } + final T[] newValue = Arrays.copyOf(oldValue, oldValue.length + 1); + newValue[oldValue.length] = item; + return newValue; + }); + + return oldRequestIdMap; }); } final T oldItem = byTaskId.put(taskId, item); @@ -60,36 +84,50 @@ public T get(long id) { } /** - * Remove (and return) the item that corresponds with the given task. Return {@code null} if not present. Safe to call multiple times - * for each task. However, {@link #getByParent} may return this task even after a call to this method completes, if the removal is - * actually being completed by a concurrent call that's still ongoing. + * Remove (and return) the item that corresponds with the given task, regardless of its request ID. Return {@code null} if not present. Safe to call + * multiple times for each task. 
However, {@link #getByParent} may return this task even after a call to this method completes, if + * the removal is actually being completed by a concurrent call that's still ongoing. */ public T remove(Task task) { final long taskId = task.getId(); final T oldItem = byTaskId.remove(taskId); if (oldItem != null && task.getParentTaskId().isSet()) { - byParentTaskId.compute(task.getParentTaskId(), (ignored, oldValue) -> { - if (oldValue == null) { + byParentTaskId.compute(task.getParentTaskId(), (taskKey, oldRequestIdMap) -> { + if (oldRequestIdMap == null) { return null; } - if (oldValue.length == 1) { - if (oldValue[0] == oldItem) { - return null; - } else { + + for (Long requestId : oldRequestIdMap.keySet()) { + oldRequestIdMap.compute(requestId, (requestIdKey, oldValue) -> { + if (oldValue == null) { + return null; + } + if (oldValue.length == 1) { + if (oldValue[0] == oldItem) { + return null; + } else { + return oldValue; + } + } + if (oldValue[0] == oldItem) { + return Arrays.copyOfRange(oldValue, 1, oldValue.length); + } + for (int i = 1; i < oldValue.length; i++) { + if (oldValue[i] == oldItem) { + final T[] newValue = Arrays.copyOf(oldValue, oldValue.length - 1); + System.arraycopy(oldValue, i + 1, newValue, i, oldValue.length - i - 1); + return newValue; + } + } return oldValue; - } - } - if (oldValue[0] == oldItem) { - return Arrays.copyOfRange(oldValue, 1, oldValue.length); + }); } - for (int i = 1; i < oldValue.length; i++) { - if (oldValue[i] == oldItem) { - final T[] newValue = Arrays.copyOf(oldValue, oldValue.length - 1); - System.arraycopy(oldValue, i + 1, newValue, i, oldValue.length - i - 1); - return newValue; - } + + if (oldRequestIdMap.keySet().isEmpty()) { + return null; } - return oldValue; + + return oldRequestIdMap; }); } return oldItem; @@ -109,11 +147,11 @@ public Collection values() { * started before this method was called have not completed. 
*/ public Stream getByParent(TaskId parentTaskId) { - final T[] byParent = byParentTaskId.get(parentTaskId); + final Map byParent = byParentTaskId.get(parentTaskId); if (byParent == null) { return Stream.empty(); } - return Arrays.stream(byParent); + return byParent.values().stream().flatMap(Stream::of); } // assertion for tests, not an invariant but should eventually be true @@ -123,12 +161,14 @@ boolean assertConsistent() { // every by-parent value must be tracked by task too; the converse isn't true since we don't track values without a parent final Set byTaskValues = new HashSet<>(byTaskId.values()); - for (T[] byParent : byParentTaskId.values()) { - assert byParent.length > 0; - for (T t : byParent) { - assert byTaskValues.contains(t); - } - } + byParentTaskId.values().forEach(byParentMap -> { + byParentMap.forEach((requestId, byParentArray) -> { + assert byParentArray.length > 0; + for (T t : byParentArray) { + assert byTaskValues.contains(t); + } + }); + }); return true; } diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskAwareRequest.java b/server/src/main/java/org/elasticsearch/tasks/TaskAwareRequest.java index d0f7e3565e233..a791066ea5089 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskAwareRequest.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskAwareRequest.java @@ -26,6 +26,18 @@ default void setParentTask(String parentTaskNode, long parentTaskId) { */ void setParentTask(TaskId taskId); + /** + * Gets the request ID. Defaults to -1, meaning "no request ID is set". + */ + default long getRequestId() { + return -1; + } + + /** + * Set the request ID related to this task. + */ + void setRequestId(long requestId); + /** * Get a reference to the task that created this request. Implementers should default to * {@link TaskId#EMPTY_TASK_ID}, meaning "there is no parent". 
diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java index 96e02bfa4f50f..6be784180e442 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java @@ -43,6 +43,7 @@ public class TaskCancellationService { public static final String BAN_PARENT_ACTION_NAME = "internal:admin/tasks/ban"; + public static final String CANCEL_CHILD_ACTION_NAME = "internal:admin/tasks/cancel_child"; private static final Logger logger = LogManager.getLogger(TaskCancellationService.class); private final TransportService transportService; private final TaskManager taskManager; @@ -58,6 +59,12 @@ public TaskCancellationService(TransportService transportService) { BanParentTaskRequest::new, new BanParentRequestHandler() ); + transportService.registerRequestHandler( + CANCEL_CHILD_ACTION_NAME, + ThreadPool.Names.SAME, + CancelChildRequest::new, + new CancelChildRequestHandler() + ); } private String localNodeId() { @@ -328,4 +335,66 @@ public void messageReceived(final BanParentTaskRequest request, final TransportC } } } + + private static class CancelChildRequest extends TransportRequest { + + private final TaskId parentTaskId; + private final long childRequestId; + private final String reason; + + static CancelChildRequest createCancelChildRequest(TaskId parentTaskId, long childRequestId, String reason) { + return new CancelChildRequest(parentTaskId, childRequestId, reason); + } + + private CancelChildRequest(TaskId parentTaskId, long childRequestId, String reason) { + this.parentTaskId = parentTaskId; + this.childRequestId = childRequestId; + this.reason = reason; + } + + private CancelChildRequest(StreamInput in) throws IOException { + super(in); + parentTaskId = TaskId.readFromStream(in); + childRequestId = in.readLong(); + reason = in.readString(); + } + + @Override + public 
void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + parentTaskId.writeTo(out); + out.writeLong(childRequestId); + out.writeString(reason); + } + } + + private class CancelChildRequestHandler implements TransportRequestHandler { + @Override + public void messageReceived(final CancelChildRequest request, final TransportChannel channel, Task task) throws Exception { + taskManager.cancelChildLocal(request.parentTaskId, request.childRequestId, request.reason); + channel.sendResponse(TransportResponse.Empty.INSTANCE); + } + } + + /** + * Sends an action to cancel a child task, associated with the given request ID and parent task. + */ + public void cancelChildRemote(TaskId parentTask, long childRequestId, Transport.Connection childConnection, String reason) { + logger.debug( + "sending cancellation of child of parent task [{}] with request ID [{}] on the connection [{}] because of [{}]", + parentTask, + childRequestId, + childConnection, + reason + ); + final CancelChildRequest request = CancelChildRequest.createCancelChildRequest(parentTask, childRequestId, reason); + transportService.sendRequest( + childConnection, + CANCEL_CHILD_ACTION_NAME, + request, + TransportRequestOptions.EMPTY, + EmptyTransportResponseHandler.INSTANCE_SAME + ); + } + } diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java index e4cec96f73aba..2dd53ce9835cb 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java @@ -154,7 +154,7 @@ public Task register(String type, String action, TaskAwareRequest request, boole } if (task instanceof CancellableTask) { - registerCancellableTask(task, traceRequest); + registerCancellableTask(task, request.getRequestId(), traceRequest); } else { Task previousTask = tasks.put(task.getId(), task); assert previousTask == null; @@ -212,6 +212,9 @@ public void 
onResponse(Response response) { @Override public void onFailure(Exception e) { try { + if (request.getParentTask().isSet()) { + cancelChildLocal(request.getParentTask(), request.getRequestId(), e.toString()); + } release(); } finally { taskListener.onFailure(e); @@ -231,10 +234,10 @@ private void release() { } } - private void registerCancellableTask(Task task, boolean traceRequest) { + private void registerCancellableTask(Task task, long requestId, boolean traceRequest) { CancellableTask cancellableTask = (CancellableTask) task; CancellableTaskHolder holder = new CancellableTaskHolder(cancellableTask); - cancellableTasks.put(task, holder); + cancellableTasks.put(task, requestId, holder); if (traceRequest) { startTrace(threadPool.getThreadContext(), task); } @@ -253,6 +256,16 @@ private void registerCancellableTask(Task task, boolean traceRequest) { } } + private TaskCancellationService getCancellationService() { + final TaskCancellationService service = cancellationService.get(); + if (service != null) { + return service; + } else { + assert false : "TaskCancellationService is not initialized"; + throw new IllegalStateException("TaskCancellationService is not initialized"); + } + } + /** * Cancels a task *

@@ -270,6 +283,32 @@ public void cancel(CancellableTask task, String reason, Runnable listener) { } } + /** + * Cancels children tasks of the specified parent, with the request ID specified, as long as the request ID is positive. + * + * Note: There may be multiple children for the same request ID. In this edge case all these multiple children are cancelled. + */ + public void cancelChildLocal(TaskId parentTaskId, long childRequestId, String reason) { + if (childRequestId > 0) { + List children = cancellableTasks.getChildrenByRequestId(parentTaskId, childRequestId).toList(); + if (children.isEmpty() == false) { + logger.trace("cancelling children of task [{}] and request ID [{}] with reason [{}]", parentTaskId, childRequestId, reason); + for (CancellableTaskHolder child : children) { + child.cancel(reason); + } + } + } + } + + /** + * Send an Action to cancel children tasks of the specified parent, with the request ID specified. + * + * Note: There may be multiple children for the same request ID. In this edge case all these multiple children are cancelled. 
+ */ + public void cancelChildRemote(TaskId parentTask, long childRequestId, Transport.Connection childConnection, String reason) { + getCancellationService().cancelChildRemote(parentTask, childRequestId, childConnection, reason); + } + /** * Unregister the task */ @@ -778,13 +817,7 @@ protected void doRun() { } public void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener listener) { - final TaskCancellationService service = cancellationService.get(); - if (service != null) { - service.cancelTaskAndDescendants(task, reason, waitForCompletion, listener); - } else { - assert false : "TaskCancellationService is not initialized"; - throw new IllegalStateException("TaskCancellationService is not initialized"); - } + getCancellationService().cancelTaskAndDescendants(task, reason, waitForCompletion, listener); } public List getTaskHeaders() { diff --git a/server/src/main/java/org/elasticsearch/transport/InboundHandler.java b/server/src/main/java/org/elasticsearch/transport/InboundHandler.java index c79f9b0cef3db..999553c2c8828 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundHandler.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundHandler.java @@ -260,6 +260,8 @@ private void handleRequest(TcpChannel channel, Head } try { request.remoteAddress(channel.getRemoteAddress()); + assert requestId > 0; + request.setRequestId(requestId); // in case we throw an exception, i.e. 
when the limit is hit, we don't want to verify final int nextByte = stream.read(); // calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker diff --git a/server/src/main/java/org/elasticsearch/transport/TransportRequest.java b/server/src/main/java/org/elasticsearch/transport/TransportRequest.java index 094d441d8a1c8..382db6a83076d 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportRequest.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportRequest.java @@ -31,6 +31,11 @@ public Empty(StreamInput in) throws IOException { */ private TaskId parentTaskId = TaskId.EMPTY_TASK_ID; + /** + * Request ID. Defaults to -1, meaning "no request ID is set". + */ + private long requestId = -1; + public TransportRequest() {} public TransportRequest(StreamInput in) throws IOException { @@ -53,6 +58,19 @@ public TaskId getParentTask() { return parentTaskId; } + /** + * Set the request ID of this request. + */ + @Override + public void setRequestId(long requestId) { + this.requestId = requestId; + } + + @Override + public long getRequestId() { + return requestId; + } + @Override public void writeTo(StreamOutput out) throws IOException { parentTaskId.writeTo(out); diff --git a/server/src/main/java/org/elasticsearch/transport/TransportService.java b/server/src/main/java/org/elasticsearch/transport/TransportService.java index 51781077e1c37..0977a0e2da3e0 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportService.java @@ -773,7 +773,14 @@ public final void sendRequest( if (unregisterChildNode == null) { delegate = handler; } else { - delegate = new UnregisterChildTransportResponseHandler<>(unregisterChildNode, handler, action); + delegate = new UnregisterChildTransportResponseHandler<>( + unregisterChildNode, + handler, + action, + request, + unwrappedConn, + taskManager + ); } } else { delegate = handler; @@ 
-875,6 +882,7 @@ private void sendRequestInternal( ContextRestoreResponseHandler responseHandler = new ContextRestoreResponseHandler<>(storedContextSupplier, handler); // TODO we can probably fold this entire request ID dance into connection.sendRequest but it will be a bigger refactoring final long requestId = responseHandlers.add(new Transport.ResponseContext<>(responseHandler, connection, action)); + request.setRequestId(requestId); final TimeoutHandler timeoutHandler; if (options.timeout() != null) { timeoutHandler = new TimeoutHandler(requestId, connection.getNode(), action); @@ -895,6 +903,7 @@ private void sendRequestInternal( assert options.timeout() != null; timeoutHandler.scheduleTimeout(options.timeout()); } + logger.trace("sending internal request id [{}] action [{}] request [{}] options [{}]", requestId, action, request, options); connection.sendRequest(requestId, action, request, options); // local node optimization happens upstream } catch (final Exception e) { handleInternalSendException(action, node, requestId, timeoutHandler, e); @@ -1631,7 +1640,10 @@ Releasable withRef() { private record UnregisterChildTransportResponseHandler ( Releasable unregisterChildNode, TransportResponseHandler handler, - String action + String action, + TransportRequest childRequest, + Transport.Connection childConnection, + TaskManager taskManager ) implements TransportResponseHandler { @Override @@ -1642,6 +1654,9 @@ public void handleResponse(T response) { @Override public void handleException(TransportException exp) { + assert childRequest.getParentTask().isSet(); + taskManager.cancelChildRemote(childRequest.getParentTask(), childRequest.getRequestId(), childConnection, exp.toString()); + unregisterChildNode.close(); handler.handleException(exp); } diff --git a/server/src/test/java/org/elasticsearch/tasks/BanFailureLoggingTests.java b/server/src/test/java/org/elasticsearch/tasks/BanFailureLoggingTests.java index 54eb3ca175066..c8f8bf6eb1150 100644 --- 
a/server/src/test/java/org/elasticsearch/tasks/BanFailureLoggingTests.java +++ b/server/src/test/java/org/elasticsearch/tasks/BanFailureLoggingTests.java @@ -202,6 +202,11 @@ public void setParentTask(TaskId taskId) { fail("setParentTask should not be called"); } + @Override + public void setRequestId(long requestId) { + fail("setRequestId should not be called"); + } + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; diff --git a/server/src/test/java/org/elasticsearch/tasks/CancellableTasksTrackerTests.java b/server/src/test/java/org/elasticsearch/tasks/CancellableTasksTrackerTests.java index 7c29405ec7e0d..da40e307bce28 100644 --- a/server/src/test/java/org/elasticsearch/tasks/CancellableTasksTrackerTests.java +++ b/server/src/test/java/org/elasticsearch/tasks/CancellableTasksTrackerTests.java @@ -35,6 +35,7 @@ private static class TestTask { // 0 == before put, 1 == during put, 2 == after put, before remove, 3 == during remove, 4 == after remove private final AtomicInteger state = new AtomicInteger(); private final boolean concurrentRemove = randomBoolean(); + private final long requestId = randomIntBetween(-1, 10); TestTask(Task task, String item, CancellableTasksTracker tracker, Runnable awaitStart) { if (concurrentRemove) { @@ -58,7 +59,7 @@ private static class TestTask { awaitStart.run(); state.incrementAndGet(); - tracker.put(task, item); + tracker.put(task, requestId, item); state.incrementAndGet(); Thread.yield(); @@ -80,6 +81,8 @@ private static class TestTask { final int stateBefore = state.get(); final String getResult = tracker.get(task.getId()); final Set getByParentResult = tracker.getByParent(task.getParentTaskId()).collect(Collectors.toSet()); + final Set getByChildrenResult = tracker.getChildrenByRequestId(task.getParentTaskId(), requestId) + .collect(Collectors.toSet()); final Set values = new HashSet<>(tracker.values()); final int stateAfter = state.get(); @@ -87,11 +90,13 @@ private static class TestTask { if (getResult 
!= null && task.getParentTaskId().isSet() && tracker.get(task.getId()) != null) { assertThat(getByParentResult, hasItem(item)); + assertThat(getByChildrenResult, hasItem(item)); } if (stateAfter == 0) { assertNull(getResult); assertThat(getByParentResult, not(hasItem(item))); + assertThat(getByChildrenResult, not(hasItem(item))); assertThat(values, not(hasItem(item))); } @@ -99,8 +104,10 @@ private static class TestTask { assertSame(item, getResult); if (task.getParentTaskId().isSet()) { assertThat(getByParentResult, hasItem(item)); + assertThat(getByChildrenResult, hasItem(item)); } else { assertThat(getByParentResult, empty()); + assertThat(getByChildrenResult, empty()); } assertThat(values, hasItem(item)); } @@ -109,6 +116,7 @@ private static class TestTask { assertNull(getResult); if (concurrentRemove == false) { assertThat(getByParentResult, not(hasItem(item))); + assertThat(getByChildrenResult, not(hasItem(item))); } // else our remove might have completed but the concurrent one hasn't updated the parent ID map yet assertThat(values, not(hasItem(item))); } diff --git a/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java b/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java index a1e7aad45293b..02e76016c5ccc 100644 --- a/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java @@ -288,6 +288,9 @@ public void testRegisterTaskStartsTracing() { @Override public void setParentTask(TaskId taskId) {} + @Override + public void setRequestId(long requestId) {} + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; @@ -309,6 +312,9 @@ public void testUnregisterTaskStopsTracing() { @Override public void setParentTask(TaskId taskId) {} + @Override + public void setRequestId(long requestId) {} + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; @@ -471,6 +477,9 @@ private TaskAwareRequest makeTaskRequest(boolean 
cancellable, final int parentTa @Override public void setParentTask(TaskId taskId) {} + @Override + public void setRequestId(long requestId) {} + @Override public TaskId getParentTask() { return new TaskId("something", parentTaskNum); diff --git a/server/src/test/java/org/elasticsearch/transport/TransportServiceDeserializationFailureTests.java b/server/src/test/java/org/elasticsearch/transport/TransportServiceDeserializationFailureTests.java index ec1944a65519b..4cfda499f028c 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportServiceDeserializationFailureTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportServiceDeserializationFailureTests.java @@ -125,6 +125,11 @@ public void setParentTask(TaskId taskId) { fail("should not be called"); } + @Override + public void setRequestId(long requestId) { + fail("should not be called"); + } + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; diff --git a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/action/InternalExecutePolicyAction.java b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/action/InternalExecutePolicyAction.java index feeba1a4c3ccf..e99b787926361 100644 --- a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/action/InternalExecutePolicyAction.java +++ b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/action/InternalExecutePolicyAction.java @@ -142,6 +142,11 @@ public void setParentTask(TaskId taskId) { request.setParentTask(taskId); } + @Override + public void setRequestId(long requestId) { + request.setRequestId(requestId); + } + @Override public TaskId getParentTask() { return request.getParentTask(); diff --git a/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichPolicyRunnerTests.java b/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichPolicyRunnerTests.java index 90a9462f770d4..b8e1aaade568e 100644 --- 
a/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichPolicyRunnerTests.java +++ b/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichPolicyRunnerTests.java @@ -1780,6 +1780,9 @@ public void testRunnerWithForceMergeRetry() throws Exception { @Override public void setParentTask(TaskId taskId) {} + @Override + public void setRequestId(long requestId) {} + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; @@ -2025,6 +2028,9 @@ private EnrichPolicyRunner createPolicyRunner( @Override public void setParentTask(TaskId taskId) {} + @Override + public void setRequestId(long requestId) {} + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java index 8a1f818ed22f6..2d6045534bfd3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java @@ -299,6 +299,11 @@ public void setParentTask(TaskId taskId) { throw new UnsupportedOperationException("parent task id for model assignment tasks shouldn't change"); } + @Override + public void setRequestId(long requestId) { + throw new UnsupportedOperationException("does not have request ID"); + } + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java index aa7831bcc03f1..181b6abd5b549 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java @@ -194,6 +194,9 @@ public void testCallingRunAfterParentTaskCancellation() throws Exception { @Override public void setParentTask(TaskId taskId) {} + @Override + public void setRequestId(long requestId) {} + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/async/AsyncTaskManagementService.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/async/AsyncTaskManagementService.java index d42f2619a166a..f20bf52158e85 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/async/AsyncTaskManagementService.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/async/AsyncTaskManagementService.java @@ -107,6 +107,16 @@ public TaskId getParentTask() { return request.getParentTask(); } + @Override + public void setRequestId(long requestId) { + request.setRequestId(requestId); + } + + @Override + public long getRequestId() { + return request.getRequestId(); + } + @Override public Task createTask(long id, String type, String actionName, TaskId parentTaskId, Map headers) { Map originHeaders = ClientHelper.getPersistableSafeSecurityHeaders( diff --git a/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml b/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml index 25c1ed26bcdc7..6223ca8443b0e 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml +++ b/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml @@ -158,8 +158,8 @@ setup: --- "Timeout with large blobs": - skip: - version: all - reason: 
"AwaitsFix https://github.com/elastic/elasticsearch/issues/90353" + version: "- 7.13.99" + reason: "abortWrites flag introduced in 7.14, and mixed-cluster support not required" - do: catch: request diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/BlobAnalyzeAction.java b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/BlobAnalyzeAction.java index 0046f05919071..5ed2e800664b0 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/BlobAnalyzeAction.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/BlobAnalyzeAction.java @@ -32,6 +32,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.CancellableThreads; import org.elasticsearch.core.Nullable; import org.elasticsearch.repositories.RepositoriesService; import org.elasticsearch.repositories.Repository; @@ -39,7 +40,6 @@ import org.elasticsearch.repositories.blobstore.BlobStoreRepository; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; -import org.elasticsearch.tasks.TaskAwareRequest; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportRequestOptions; @@ -274,11 +274,19 @@ static class BlobAnalysis { } void run() { - writeRandomBlob( - request.readEarly || request.getAbortWrite() || (request.targetLength <= MAX_ATOMIC_WRITE_SIZE && random.nextBoolean()), - true, - this::onLastReadForInitialWrite, - write1Step + final CancellableThreads cancellableThreads = new CancellableThreads(); + task.addListener(() -> { + // This interrupts the blob writing thread in case it stuck in a sleep() due to rate limiting. 
+ cancellableThreads.cancel(task.getReasonCancelled()); + }); + + cancellableThreads.execute( + () -> writeRandomBlob( + request.readEarly || request.getAbortWrite() || (request.targetLength <= MAX_ATOMIC_WRITE_SIZE && random.nextBoolean()), + true, + this::onLastReadForInitialWrite, + write1Step + ) ); if (request.writeAndOverwrite) { @@ -621,7 +629,7 @@ private WriteDetails(long bytesWritten, long elapsedNanos, long throttledNanos, } } - public static class Request extends ActionRequest implements TaskAwareRequest { + public static class Request extends ActionRequest { private final String repositoryName; private final String blobPath; private final String blobName; diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/GetBlobChecksumAction.java b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/GetBlobChecksumAction.java index 14e760875c9c7..96828fd5a4c04 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/GetBlobChecksumAction.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/GetBlobChecksumAction.java @@ -28,7 +28,6 @@ import org.elasticsearch.repositories.blobstore.BlobStoreRepository; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; -import org.elasticsearch.tasks.TaskAwareRequest; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -186,7 +185,7 @@ protected void doExecute(Task task, Request request, ActionListener li } - public static class Request extends ActionRequest implements TaskAwareRequest { + public static class Request extends ActionRequest { private final String repositoryName; private final String blobPath;