diff --git a/qa/smoke-test-http/src/test/java/org/elasticsearch/http/RestSnapshotsStatusCancellationIT.java b/qa/smoke-test-http/src/test/java/org/elasticsearch/http/RestSnapshotsStatusCancellationIT.java new file mode 100644 index 0000000000000..08429bdbff3de --- /dev/null +++ b/qa/smoke-test-http/src/test/java/org/elasticsearch/http/RestSnapshotsStatusCancellationIT.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.http; + +import org.apache.http.client.methods.HttpGet; +import org.elasticsearch.action.admin.cluster.snapshots.status.SnapshotsStatusAction; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.Cancellable; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.client.ResponseListener; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.snapshots.AbstractSnapshotIntegTestCase; +import org.elasticsearch.snapshots.SnapshotState; +import org.elasticsearch.snapshots.mockstore.MockRepository; +import org.elasticsearch.test.ESIntegTestCase; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.CancellationException; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.test.TaskAssertions.assertAllCancellableTasksAreCancelled; +import static org.elasticsearch.test.TaskAssertions.assertAllTasksHaveFinished; +import static org.elasticsearch.test.TaskAssertions.awaitTaskWithPrefix; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; + +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0, numClientNodes = 0) +public class RestSnapshotsStatusCancellationIT extends HttpSmokeTestCase { + + @Override + protected Collection> nodePlugins() { + return CollectionUtils.appendToCopy(super.nodePlugins(), MockRepository.Plugin.class); + } + + public void testSnapshotStatusCancellation() throws Exception { + internalCluster().startMasterOnlyNode(); + internalCluster().startDataOnlyNode(); + ensureStableCluster(2); + + createIndex("test-idx"); + final String repoName = "test-repo"; + assertAcked( + client().admin().cluster().preparePutRepository(repoName) + .setType("mock").setSettings(Settings.builder().put("location", randomRepoPath()))); + + final int snapshotCount = randomIntBetween(1, 5); + final Collection snapshotNames = new ArrayList<>(); + for (int i = 0; i < snapshotCount; i++) { + final String snapshotName = "snapshot-" + i; + snapshotNames.add(snapshotName); + assertEquals( + SnapshotState.SUCCESS, + client().admin().cluster().prepareCreateSnapshot(repoName, "snapshot-" + i).setWaitForCompletion(true) + .get().getSnapshotInfo().state() + ); + } + + final MockRepository repository = AbstractSnapshotIntegTestCase.getRepositoryOnMaster(repoName); + repository.setBlockOnAnyFiles(); + + final Request request = new Request( + HttpGet.METHOD_NAME, + "/_snapshot/" + repoName + "/" + + String.join(",", randomSubsetOf(randomIntBetween(1, snapshotCount), snapshotNames)) + + "/_status" + ); + final PlainActionFuture future = new PlainActionFuture<>(); + final Cancellable cancellable = getRestClient().performRequestAsync(request, new ResponseListener() { + @Override + public void onSuccess(Response response) { + future.onResponse(null); + } + + @Override + public void onFailure(Exception exception) { + future.onFailure(exception); + } + }); + + assertFalse(future.isDone()); + awaitTaskWithPrefix(SnapshotsStatusAction.NAME); + assertBusy(() -> assertTrue(repository.blocked()), 30L, TimeUnit.SECONDS); + cancellable.cancel(); + assertAllCancellableTasksAreCancelled(SnapshotsStatusAction.NAME); + repository.unblock(); + expectThrows(CancellationException.class, future::actionGet); + + assertAllTasksHaveFinished(SnapshotsStatusAction.NAME); + } +} diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/status/SnapshotsStatusRequest.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/status/SnapshotsStatusRequest.java index c99a1602c7b63..b6fc25e7e2655 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/status/SnapshotsStatusRequest.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/status/SnapshotsStatusRequest.java @@ -13,8 +13,12 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import java.io.IOException; +import java.util.Map; import static org.elasticsearch.action.ValidateActions.addValidationError; @@ -79,6 +83,11 @@ public ActionRequestValidationException validate() { return validationException; } + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers); + } + /** * Sets repository name * diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/status/TransportSnapshotsStatusAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/status/TransportSnapshotsStatusAction.java index 8f60ed3ebdac2..353464150512b 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/status/TransportSnapshotsStatusAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/status/TransportSnapshotsStatusAction.java @@ -42,6 +42,9 @@ import org.elasticsearch.snapshots.SnapshotShardsService; import org.elasticsearch.snapshots.SnapshotState; import org.elasticsearch.snapshots.SnapshotsService; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -84,14 +87,24 @@ protected ClusterBlockException checkBlock(SnapshotsStatusRequest request, Clust } @Override - protected void masterOperation(final SnapshotsStatusRequest request, + protected void masterOperation(SnapshotsStatusRequest request, ClusterState state, + ActionListener listener) throws Exception { + throw new UnsupportedOperationException("The task parameter is required"); + } + + @Override + protected void masterOperation(final Task task, + final SnapshotsStatusRequest request, final ClusterState state, final ActionListener listener) throws Exception { + assert task instanceof CancellableTask : task + " not cancellable"; + final CancellableTask cancellableTask = (CancellableTask) task; + final SnapshotsInProgress snapshotsInProgress = state.custom(SnapshotsInProgress.TYPE, SnapshotsInProgress.EMPTY); List currentSnapshots = SnapshotsService.currentSnapshots(snapshotsInProgress, request.repository(), Arrays.asList(request.snapshots())); if (currentSnapshots.isEmpty()) { - buildResponse(snapshotsInProgress, request, currentSnapshots, null, listener); + buildResponse(snapshotsInProgress, request, currentSnapshots, null, cancellableTask, listener); return; } @@ -115,18 +128,20 @@ protected void masterOperation(final SnapshotsStatusRequest request, .snapshots(snapshots).timeout(request.masterNodeTimeout()), ActionListener.wrap(nodeSnapshotStatuses -> threadPool.generic().execute( ActionRunnable.wrap(listener, - l -> buildResponse(snapshotsInProgress, request, currentSnapshots, nodeSnapshotStatuses, l)) + l -> buildResponse(snapshotsInProgress, request, currentSnapshots, nodeSnapshotStatuses, cancellableTask, l)) ), listener::onFailure)); } else { // We don't have any in-progress shards, just return current stats - buildResponse(snapshotsInProgress, request, currentSnapshots, null, listener); + buildResponse(snapshotsInProgress, request, currentSnapshots, null, cancellableTask, listener); } } - private void buildResponse(SnapshotsInProgress snapshotsInProgress, SnapshotsStatusRequest request, + private void buildResponse(SnapshotsInProgress snapshotsInProgress, + SnapshotsStatusRequest request, List currentSnapshotEntries, TransportNodesSnapshotsStatus.NodesSnapshotStatus nodeSnapshotStatuses, + CancellableTask task, ActionListener listener) { // First process snapshot that are currently processed List builder = new ArrayList<>(); @@ -212,19 +227,24 @@ private void buildResponse(SnapshotsInProgress snapshotsInProgress, SnapshotsSta // Now add snapshots on disk that are not currently running final String repositoryName = request.repository(); if (Strings.hasText(repositoryName) && CollectionUtils.isEmpty(request.snapshots()) == false) { - loadRepositoryData(snapshotsInProgress, request, builder, currentSnapshotNames, repositoryName, listener); + loadRepositoryData(snapshotsInProgress, request, builder, currentSnapshotNames, repositoryName, task, listener); } else { listener.onResponse(new SnapshotsStatusResponse(Collections.unmodifiableList(builder))); } } - private void loadRepositoryData(SnapshotsInProgress snapshotsInProgress, SnapshotsStatusRequest request, - List builder, Set currentSnapshotNames, String repositoryName, + private void loadRepositoryData(SnapshotsInProgress snapshotsInProgress, + SnapshotsStatusRequest request, + List builder, + Set currentSnapshotNames, + String repositoryName, + CancellableTask task, ActionListener listener) { final Set requestedSnapshotNames = Sets.newHashSet(request.snapshots()); final ListenableFuture repositoryDataListener = new ListenableFuture<>(); repositoriesService.getRepositoryData(repositoryName, repositoryDataListener); repositoryDataListener.addListener(ActionListener.wrap(repositoryData -> { + ensureNotCancelled(task); final Map matchedSnapshotIds = repositoryData.getSnapshotIds().stream() .filter(s -> requestedSnapshotNames.contains(s.getName())) .collect(Collectors.toMap(SnapshotId::getName, Function.identity())); @@ -248,7 +268,8 @@ private void loadRepositoryData(SnapshotsInProgress snapshotsInProgress, Snapsho SnapshotInfo snapshotInfo = snapshot(snapshotsInProgress, repositoryName, snapshotId); List shardStatusBuilder = new ArrayList<>(); if (snapshotInfo.state().completed()) { - Map shardStatuses = snapshotShards(repositoryName, repositoryData, snapshotInfo); + Map shardStatuses = + snapshotShards(repositoryName, repositoryData, task, snapshotInfo); for (Map.Entry shardStatus : shardStatuses.entrySet()) { IndexShardSnapshotStatus.Copy lastSnapshotStatus = shardStatus.getValue().asCopy(); shardStatusBuilder.add(new SnapshotIndexShardStatus(shardStatus.getKey(), lastSnapshotStatus)); @@ -313,12 +334,14 @@ private SnapshotInfo snapshot(SnapshotsInProgress snapshotsInProgress, String re * @return map of shard id to snapshot status */ private Map snapshotShards(final String repositoryName, - final RepositoryData repositoryData, - final SnapshotInfo snapshotInfo) throws IOException { + final RepositoryData repositoryData, + final CancellableTask task, + final SnapshotInfo snapshotInfo) throws IOException { final Repository repository = repositoriesService.repository(repositoryName); final Map shardStatus = new HashMap<>(); for (String index : snapshotInfo.indices()) { IndexId indexId = repositoryData.resolveIndexId(index); + ensureNotCancelled(task); IndexMetadata indexMetadata = repository.getSnapshotIndexMetaData(repositoryData, snapshotInfo.snapshotId(), indexId); if (indexMetadata != null) { int numberOfShards = indexMetadata.getNumberOfShards(); @@ -339,6 +362,7 @@ private Map snapshotShards(final String repos // could not be taken due to partial being set to false. shardSnapshotStatus = IndexShardSnapshotStatus.newFailed("skipped"); } else { + ensureNotCancelled(task); shardSnapshotStatus = repository.getShardSnapshotStatus( snapshotInfo.snapshotId(), indexId, @@ -352,6 +376,12 @@ private Map snapshotShards(final String repos return unmodifiableMap(shardStatus); } + private static void ensureNotCancelled(CancellableTask task) { + if (task.isCancelled()) { + throw new TaskCancelledException("task cancelled"); + } + } + private static SnapshotShardFailure findShardFailure(List shardFailures, ShardId shardId) { for (SnapshotShardFailure shardFailure : shardFailures) { if (shardId.getIndexName().equals(shardFailure.index()) && shardId.getId() == shardFailure.shardId()) { diff --git a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestSnapshotsStatusAction.java b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestSnapshotsStatusAction.java index 963baceb077e4..59d6aeb389747 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestSnapshotsStatusAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestSnapshotsStatusAction.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestCancellableNodeClient; import org.elasticsearch.rest.action.RestToXContentListener; import java.io.IOException; @@ -52,6 +53,7 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC snapshotsStatusRequest.ignoreUnavailable(request.paramAsBoolean("ignore_unavailable", snapshotsStatusRequest.ignoreUnavailable())); snapshotsStatusRequest.masterNodeTimeout(request.paramAsTime("master_timeout", snapshotsStatusRequest.masterNodeTimeout())); - return channel -> client.admin().cluster().snapshotsStatus(snapshotsStatusRequest, new RestToXContentListener<>(channel)); + return channel -> new RestCancellableNodeClient(client, request.getHttpChannel()) + .admin().cluster().snapshotsStatus(snapshotsStatusRequest, new RestToXContentListener<>(channel)); } }