Skip to content

Commit

Permalink
Make SnapshotStatusAction Cancellable (elastic#73818)
Browse files Browse the repository at this point in the history
Same as elastic#72644. This is a much longer-running action than a normal
get snapshots request, so it should definitely be cancellable.
Parallelization for this action will be introduced in a separate PR.
  • Loading branch information
original-brownbear committed Jun 7, 2021
1 parent 30da196 commit 44f762f
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.http;

import org.apache.http.client.methods.HttpGet;
import org.elasticsearch.action.admin.cluster.snapshots.status.SnapshotsStatusAction;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Cancellable;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseListener;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.snapshots.AbstractSnapshotIntegTestCase;
import org.elasticsearch.snapshots.SnapshotState;
import org.elasticsearch.snapshots.mockstore.MockRepository;
import org.elasticsearch.test.ESIntegTestCase;

import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.CancellationException;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.test.TaskAssertions.assertAllCancellableTasksAreCancelled;
import static org.elasticsearch.test.TaskAssertions.assertAllTasksHaveFinished;
import static org.elasticsearch.test.TaskAssertions.awaitTaskWithPrefix;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;

/**
 * Integration test checking that a snapshot status request issued over the REST layer is cancelled
 * on the cluster once the HTTP client cancels the request.
 * <p>
 * The cluster scope starts with no nodes; the test starts one master-only and one data-only node itself.
 */
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0, numClientNodes = 0)
public class RestSnapshotsStatusCancellationIT extends HttpSmokeTestCase {

    /**
     * Returns the plugins of the parent test case plus {@link MockRepository.Plugin}, which the test
     * relies on to create a repository of type {@code mock}.
     */
    @Override
    protected Collection<Class<? extends Plugin>> nodePlugins() {
        return CollectionUtils.appendToCopy(super.nodePlugins(), MockRepository.Plugin.class);
    }

    /**
     * Creates a handful of successful snapshots, blocks the mock repository on the master, sends a
     * snapshot status request through the REST client and cancels it while the repository is blocked.
     * The request is expected to fail with a {@link CancellationException} and no snapshot status
     * task may be left behind.
     */
    public void testSnapshotStatusCancellation() throws Exception {
        internalCluster().startMasterOnlyNode();
        internalCluster().startDataOnlyNode();
        ensureStableCluster(2);

        createIndex("test-idx");
        final String repoName = "test-repo";
        assertAcked(
            client().admin()
                .cluster()
                .preparePutRepository(repoName)
                .setType("mock")
                .setSettings(Settings.builder().put("location", randomRepoPath()))
        );

        // Take between one and five snapshots, each of which has to complete successfully.
        final int snapshotCount = randomIntBetween(1, 5);
        final Collection<String> snapshotNames = new ArrayList<>();
        for (int ordinal = 0; ordinal < snapshotCount; ordinal++) {
            final String currentName = "snapshot-" + ordinal;
            snapshotNames.add(currentName);
            final SnapshotState finalState = client().admin()
                .cluster()
                .prepareCreateSnapshot(repoName, currentName)
                .setWaitForCompletion(true)
                .get()
                .getSnapshotInfo()
                .state();
            assertEquals(SnapshotState.SUCCESS, finalState);
        }

        // Block the master's repository before sending the request so the status action cannot finish.
        final MockRepository repository = AbstractSnapshotIntegTestCase.getRepositoryOnMaster(repoName);
        repository.setBlockOnAnyFiles();

        // Ask for the status of a random, non-empty subset of the snapshots created above.
        final String requestedSnapshots = String.join(",", randomSubsetOf(randomIntBetween(1, snapshotCount), snapshotNames));
        final String endpoint = "/_snapshot/" + repoName + "/" + requestedSnapshots + "/_status";
        final Request statusRequest = new Request(HttpGet.METHOD_NAME, endpoint);

        final PlainActionFuture<Void> outcome = new PlainActionFuture<>();
        final Cancellable inFlight = getRestClient().performRequestAsync(statusRequest, forwardingTo(outcome));

        assertFalse(outcome.isDone());
        awaitTaskWithPrefix(SnapshotsStatusAction.NAME);
        assertBusy(() -> assertTrue(repository.blocked()), 30L, TimeUnit.SECONDS);

        // Cancel on the client side while the action is stuck, then release the repository.
        inFlight.cancel();
        assertAllCancellableTasksAreCancelled(SnapshotsStatusAction.NAME);
        repository.unblock();
        expectThrows(CancellationException.class, () -> outcome.actionGet());

        assertAllTasksHaveFinished(SnapshotsStatusAction.NAME);
    }

    /**
     * Builds a {@link ResponseListener} that completes the given future with {@code null} on success
     * and fails it with the reported exception otherwise.
     */
    private static ResponseListener forwardingTo(final PlainActionFuture<Void> target) {
        return new ResponseListener() {
            @Override
            public void onSuccess(Response response) {
                target.onResponse(null);
            }

            @Override
            public void onFailure(Exception exception) {
                target.onFailure(exception);
            }
        };
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;

import java.io.IOException;
import java.util.Map;

import static org.elasticsearch.action.ValidateActions.addValidationError;

Expand Down Expand Up @@ -79,6 +83,11 @@ public ActionRequestValidationException validate() {
return validationException;
}

/**
 * Creates the task representing this request as a {@link CancellableTask}, using
 * {@link #getDescription()} as the task description.
 */
@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
    final String description = getDescription();
    return new CancellableTask(id, type, action, description, parentTaskId, headers);
}

/**
* Sets repository name
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
import org.elasticsearch.snapshots.SnapshotShardsService;
import org.elasticsearch.snapshots.SnapshotState;
import org.elasticsearch.snapshots.SnapshotsService;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;

Expand Down Expand Up @@ -84,14 +87,24 @@ protected ClusterBlockException checkBlock(SnapshotsStatusRequest request, Clust
}

@Override
protected void masterOperation(final SnapshotsStatusRequest request,
protected void masterOperation(SnapshotsStatusRequest request, ClusterState state,
ActionListener<SnapshotsStatusResponse> listener) throws Exception {
throw new UnsupportedOperationException("The task parameter is required");
}

@Override
protected void masterOperation(final Task task,
final SnapshotsStatusRequest request,
final ClusterState state,
final ActionListener<SnapshotsStatusResponse> listener) throws Exception {
assert task instanceof CancellableTask : task + " not cancellable";
final CancellableTask cancellableTask = (CancellableTask) task;

final SnapshotsInProgress snapshotsInProgress = state.custom(SnapshotsInProgress.TYPE, SnapshotsInProgress.EMPTY);
List<SnapshotsInProgress.Entry> currentSnapshots =
SnapshotsService.currentSnapshots(snapshotsInProgress, request.repository(), Arrays.asList(request.snapshots()));
if (currentSnapshots.isEmpty()) {
buildResponse(snapshotsInProgress, request, currentSnapshots, null, listener);
buildResponse(snapshotsInProgress, request, currentSnapshots, null, cancellableTask, listener);
return;
}

Expand All @@ -115,18 +128,20 @@ protected void masterOperation(final SnapshotsStatusRequest request,
.snapshots(snapshots).timeout(request.masterNodeTimeout()),
ActionListener.wrap(nodeSnapshotStatuses -> threadPool.generic().execute(
ActionRunnable.wrap(listener,
l -> buildResponse(snapshotsInProgress, request, currentSnapshots, nodeSnapshotStatuses, l))
l -> buildResponse(snapshotsInProgress, request, currentSnapshots, nodeSnapshotStatuses, cancellableTask, l))
), listener::onFailure));
} else {
// We don't have any in-progress shards, just return current stats
buildResponse(snapshotsInProgress, request, currentSnapshots, null, listener);
buildResponse(snapshotsInProgress, request, currentSnapshots, null, cancellableTask, listener);
}

}

private void buildResponse(SnapshotsInProgress snapshotsInProgress, SnapshotsStatusRequest request,
private void buildResponse(SnapshotsInProgress snapshotsInProgress,
SnapshotsStatusRequest request,
List<SnapshotsInProgress.Entry> currentSnapshotEntries,
TransportNodesSnapshotsStatus.NodesSnapshotStatus nodeSnapshotStatuses,
CancellableTask task,
ActionListener<SnapshotsStatusResponse> listener) {
// First process snapshot that are currently processed
List<SnapshotStatus> builder = new ArrayList<>();
Expand Down Expand Up @@ -212,19 +227,24 @@ private void buildResponse(SnapshotsInProgress snapshotsInProgress, SnapshotsSta
// Now add snapshots on disk that are not currently running
final String repositoryName = request.repository();
if (Strings.hasText(repositoryName) && CollectionUtils.isEmpty(request.snapshots()) == false) {
loadRepositoryData(snapshotsInProgress, request, builder, currentSnapshotNames, repositoryName, listener);
loadRepositoryData(snapshotsInProgress, request, builder, currentSnapshotNames, repositoryName, task, listener);
} else {
listener.onResponse(new SnapshotsStatusResponse(Collections.unmodifiableList(builder)));
}
}

private void loadRepositoryData(SnapshotsInProgress snapshotsInProgress, SnapshotsStatusRequest request,
List<SnapshotStatus> builder, Set<String> currentSnapshotNames, String repositoryName,
private void loadRepositoryData(SnapshotsInProgress snapshotsInProgress,
SnapshotsStatusRequest request,
List<SnapshotStatus> builder,
Set<String> currentSnapshotNames,
String repositoryName,
CancellableTask task,
ActionListener<SnapshotsStatusResponse> listener) {
final Set<String> requestedSnapshotNames = Sets.newHashSet(request.snapshots());
final ListenableFuture<RepositoryData> repositoryDataListener = new ListenableFuture<>();
repositoriesService.getRepositoryData(repositoryName, repositoryDataListener);
repositoryDataListener.addListener(ActionListener.wrap(repositoryData -> {
ensureNotCancelled(task);
final Map<String, SnapshotId> matchedSnapshotIds = repositoryData.getSnapshotIds().stream()
.filter(s -> requestedSnapshotNames.contains(s.getName()))
.collect(Collectors.toMap(SnapshotId::getName, Function.identity()));
Expand All @@ -248,7 +268,8 @@ private void loadRepositoryData(SnapshotsInProgress snapshotsInProgress, Snapsho
SnapshotInfo snapshotInfo = snapshot(snapshotsInProgress, repositoryName, snapshotId);
List<SnapshotIndexShardStatus> shardStatusBuilder = new ArrayList<>();
if (snapshotInfo.state().completed()) {
Map<ShardId, IndexShardSnapshotStatus> shardStatuses = snapshotShards(repositoryName, repositoryData, snapshotInfo);
Map<ShardId, IndexShardSnapshotStatus> shardStatuses =
snapshotShards(repositoryName, repositoryData, task, snapshotInfo);
for (Map.Entry<ShardId, IndexShardSnapshotStatus> shardStatus : shardStatuses.entrySet()) {
IndexShardSnapshotStatus.Copy lastSnapshotStatus = shardStatus.getValue().asCopy();
shardStatusBuilder.add(new SnapshotIndexShardStatus(shardStatus.getKey(), lastSnapshotStatus));
Expand Down Expand Up @@ -313,12 +334,14 @@ private SnapshotInfo snapshot(SnapshotsInProgress snapshotsInProgress, String re
* @return map of shard id to snapshot status
*/
private Map<ShardId, IndexShardSnapshotStatus> snapshotShards(final String repositoryName,
final RepositoryData repositoryData,
final SnapshotInfo snapshotInfo) throws IOException {
final RepositoryData repositoryData,
final CancellableTask task,
final SnapshotInfo snapshotInfo) throws IOException {
final Repository repository = repositoriesService.repository(repositoryName);
final Map<ShardId, IndexShardSnapshotStatus> shardStatus = new HashMap<>();
for (String index : snapshotInfo.indices()) {
IndexId indexId = repositoryData.resolveIndexId(index);
ensureNotCancelled(task);
IndexMetadata indexMetadata = repository.getSnapshotIndexMetaData(repositoryData, snapshotInfo.snapshotId(), indexId);
if (indexMetadata != null) {
int numberOfShards = indexMetadata.getNumberOfShards();
Expand All @@ -339,6 +362,7 @@ private Map<ShardId, IndexShardSnapshotStatus> snapshotShards(final String repos
// could not be taken due to partial being set to false.
shardSnapshotStatus = IndexShardSnapshotStatus.newFailed("skipped");
} else {
ensureNotCancelled(task);
shardSnapshotStatus = repository.getShardSnapshotStatus(
snapshotInfo.snapshotId(),
indexId,
Expand All @@ -352,6 +376,12 @@ private Map<ShardId, IndexShardSnapshotStatus> snapshotShards(final String repos
return unmodifiableMap(shardStatus);
}

/**
 * Aborts the current operation if the given task has been cancelled.
 *
 * @param task the task whose cancellation flag is checked
 * @throws TaskCancelledException if {@code task.isCancelled()} returns {@code true}
 */
private static void ensureNotCancelled(CancellableTask task) {
    if (task.isCancelled() == false) {
        return;
    }
    throw new TaskCancelledException("task cancelled");
}

private static SnapshotShardFailure findShardFailure(List<SnapshotShardFailure> shardFailures, ShardId shardId) {
for (SnapshotShardFailure shardFailure : shardFailures) {
if (shardId.getIndexName().equals(shardFailure.index()) && shardId.getId() == shardFailure.shardId()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.rest.action.RestToXContentListener;

import java.io.IOException;
Expand Down Expand Up @@ -52,6 +53,7 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
snapshotsStatusRequest.ignoreUnavailable(request.paramAsBoolean("ignore_unavailable", snapshotsStatusRequest.ignoreUnavailable()));

snapshotsStatusRequest.masterNodeTimeout(request.paramAsTime("master_timeout", snapshotsStatusRequest.masterNodeTimeout()));
return channel -> client.admin().cluster().snapshotsStatus(snapshotsStatusRequest, new RestToXContentListener<>(channel));
return channel -> new RestCancellableNodeClient(client, request.getHttpChannel())
.admin().cluster().snapshotsStatus(snapshotsStatusRequest, new RestToXContentListener<>(channel));
}
}

0 comments on commit 44f762f

Please sign in to comment.