Skip to content

Commit

Permalink
Make GetSnapshotsAction Cancellable (#72644)
Browse files Browse the repository at this point in the history
If this runs needlessly for large repositories (especially in timeout/retry situations)
it's a significant memory+cpu hit => made it cancellable like we recently did for many
other endpoints.
  • Loading branch information
original-brownbear authored May 4, 2021
1 parent 98db349 commit 70f1e8c
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.http;

import org.apache.http.client.methods.HttpGet;
import org.elasticsearch.action.admin.cluster.snapshots.get.GetSnapshotsAction;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Cancellable;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseListener;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.snapshots.AbstractSnapshotIntegTestCase;
import org.elasticsearch.snapshots.SnapshotState;
import org.elasticsearch.snapshots.mockstore.MockRepository;
import org.elasticsearch.test.ESIntegTestCase;

import java.util.Collection;
import java.util.concurrent.CancellationException;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.test.TaskAssertions.assertAllCancellableTasksAreCancelled;
import static org.elasticsearch.test.TaskAssertions.assertAllTasksHaveFinished;
import static org.elasticsearch.test.TaskAssertions.awaitTaskWithPrefix;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.core.IsEqual.equalTo;

@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0, numClientNodes = 0)
public class RestGetSnapshotsCancellationIT extends HttpSmokeTestCase {

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return CollectionUtils.appendToCopy(super.nodePlugins(), MockRepository.Plugin.class);
}

public void testGetSnapshotsCancellation() throws Exception {
internalCluster().startMasterOnlyNode();
internalCluster().startDataOnlyNode();
ensureStableCluster(2);

final String repoName = "test-repo";
assertAcked(
client().admin().cluster().preparePutRepository(repoName)
.setType("mock").setSettings(Settings.builder().put("location", randomRepoPath())));

final int snapshotCount = randomIntBetween(1, 5);
for (int i = 0; i < snapshotCount; i++) {
assertEquals(
SnapshotState.SUCCESS,
client().admin().cluster().prepareCreateSnapshot(repoName, "snapshot-" + i).setWaitForCompletion(true)
.get().getSnapshotInfo().state()
);
}

final MockRepository repository = AbstractSnapshotIntegTestCase.getRepositoryOnMaster(repoName);
repository.setBlockOnAnyFiles();

final Request request = new Request(HttpGet.METHOD_NAME, "/_snapshot/" + repoName + "/*");
final PlainActionFuture<Void> future = new PlainActionFuture<>();
final Cancellable cancellable = getRestClient().performRequestAsync(request, new ResponseListener() {
@Override
public void onSuccess(Response response) {
future.onResponse(null);
}

@Override
public void onFailure(Exception exception) {
future.onFailure(exception);
}
});

assertThat(future.isDone(), equalTo(false));
awaitTaskWithPrefix(GetSnapshotsAction.NAME);
assertBusy(() -> assertTrue(repository.blocked()), 30L, TimeUnit.SECONDS);
cancellable.cancel();
assertAllCancellableTasksAreCancelled(GetSnapshotsAction.NAME);
repository.unblock();
expectThrows(CancellationException.class, future::actionGet);

assertAllTasksHaveFinished(GetSnapshotsAction.NAME);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;

import java.io.IOException;
import java.util.Map;

import static org.elasticsearch.action.ValidateActions.addValidationError;

Expand Down Expand Up @@ -174,4 +178,9 @@ public GetSnapshotsRequest verbose(boolean verbose) {
public boolean verbose() {
return verbose;
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
import org.elasticsearch.snapshots.SnapshotInfo;
import org.elasticsearch.snapshots.SnapshotMissingException;
import org.elasticsearch.snapshots.SnapshotsService;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;

Expand Down Expand Up @@ -80,13 +82,25 @@ protected ClusterBlockException checkBlock(GetSnapshotsRequest request, ClusterS
@Override
protected void masterOperation(Task task, final GetSnapshotsRequest request, final ClusterState state,
final ActionListener<GetSnapshotsResponse> listener) {
getMultipleReposSnapshotInfo(state.custom(SnapshotsInProgress.TYPE, SnapshotsInProgress.EMPTY),
TransportGetRepositoriesAction.getRepositories(state, request.repositories()), request.snapshots(),
request.ignoreUnavailable(), request.verbose(), listener);
assert task instanceof CancellableTask : task + " not cancellable";

getMultipleReposSnapshotInfo(
state.custom(SnapshotsInProgress.TYPE, SnapshotsInProgress.EMPTY),
TransportGetRepositoriesAction.getRepositories(state, request.repositories()),
request.snapshots(),
request.ignoreUnavailable(),
request.verbose(),
(CancellableTask) task,
listener
);
}

private void getMultipleReposSnapshotInfo(SnapshotsInProgress snapshotsInProgress, List<RepositoryMetadata> repos,
String[] snapshots, boolean ignoreUnavailable, boolean verbose,
private void getMultipleReposSnapshotInfo(SnapshotsInProgress snapshotsInProgress,
List<RepositoryMetadata> repos,
String[] snapshots,
boolean ignoreUnavailable,
boolean verbose,
CancellableTask cancellableTask,
ActionListener<GetSnapshotsResponse> listener) {
// short-circuit if there are no repos, because we can not create GroupedActionListener of size 0
if (repos.isEmpty()) {
Expand All @@ -102,7 +116,7 @@ private void getMultipleReposSnapshotInfo(SnapshotsInProgress snapshotsInProgres

for (final RepositoryMetadata repo : repos) {
final String repoName = repo.name();
getSingleRepoSnapshotInfo(snapshotsInProgress, repoName, snapshots, ignoreUnavailable, verbose,
getSingleRepoSnapshotInfo(snapshotsInProgress, repoName, snapshots, ignoreUnavailable, verbose, cancellableTask,
groupedActionListener.delegateResponse((groupedListener, e) -> {
if (e instanceof ElasticsearchException) {
groupedListener.onResponse(GetSnapshotsResponse.Response.error(repoName, (ElasticsearchException) e));
Expand All @@ -114,7 +128,8 @@ private void getMultipleReposSnapshotInfo(SnapshotsInProgress snapshotsInProgres
}

private void getSingleRepoSnapshotInfo(SnapshotsInProgress snapshotsInProgress, String repo, String[] snapshots,
boolean ignoreUnavailable, boolean verbose, ActionListener<List<SnapshotInfo>> listener) {
boolean ignoreUnavailable, boolean verbose, CancellableTask task,
ActionListener<List<SnapshotInfo>> listener) {
final Map<String, SnapshotId> allSnapshotIds = new HashMap<>();
final List<SnapshotInfo> currentSnapshots = new ArrayList<>();
for (SnapshotInfo snapshotInfo : sortedCurrentSnapshots(snapshotsInProgress, repo)) {
Expand All @@ -131,7 +146,7 @@ private void getSingleRepoSnapshotInfo(SnapshotsInProgress snapshotsInProgress,
}

repositoryDataListener.whenComplete(repositoryData -> loadSnapshotInfos(snapshotsInProgress, repo, snapshots,
ignoreUnavailable, verbose, allSnapshotIds, currentSnapshots, repositoryData, listener), listener::onFailure);
ignoreUnavailable, verbose, allSnapshotIds, currentSnapshots, repositoryData, task, listener), listener::onFailure);
}

/**
Expand All @@ -156,7 +171,12 @@ private static List<SnapshotInfo> sortedCurrentSnapshots(SnapshotsInProgress sna
private void loadSnapshotInfos(SnapshotsInProgress snapshotsInProgress, String repo, String[] snapshots,
boolean ignoreUnavailable, boolean verbose, Map<String, SnapshotId> allSnapshotIds,
List<SnapshotInfo> currentSnapshots, @Nullable RepositoryData repositoryData,
ActionListener<List<SnapshotInfo>> listener) {
CancellableTask task, ActionListener<List<SnapshotInfo>> listener) {
if (task.isCancelled()) {
listener.onFailure(new TaskCancelledException("task cancelled"));
return;
}

if (repositoryData != null) {
for (SnapshotId snapshotId : repositoryData.getSnapshotIds()) {
allSnapshotIds.put(snapshotId.getName(), snapshotId);
Expand Down Expand Up @@ -192,7 +212,7 @@ private void loadSnapshotInfos(SnapshotsInProgress snapshotsInProgress, String r

if (verbose) {
threadPool.generic().execute(ActionRunnable.supply(
listener, () -> snapshots(snapshotsInProgress, repo, new ArrayList<>(toResolve), ignoreUnavailable)));
listener, () -> snapshots(snapshotsInProgress, repo, new ArrayList<>(toResolve), ignoreUnavailable, task)));
} else {
final List<SnapshotInfo> snapshotInfos;
if (repositoryData != null) {
Expand All @@ -218,7 +238,10 @@ private void loadSnapshotInfos(SnapshotsInProgress snapshotsInProgress, String r
* @return list of snapshots
*/
private List<SnapshotInfo> snapshots(SnapshotsInProgress snapshotsInProgress, String repositoryName,
List<SnapshotId> snapshotIds, boolean ignoreUnavailable) {
List<SnapshotId> snapshotIds, boolean ignoreUnavailable, CancellableTask task) {
if (task.isCancelled()) {
throw new TaskCancelledException("task cancelled");
}
final Set<SnapshotInfo> snapshotSet = new HashSet<>();
final Set<SnapshotId> snapshotIdsToIterate = new HashSet<>(snapshotIds);
// first, look at the snapshots in progress
Expand All @@ -232,6 +255,9 @@ private List<SnapshotInfo> snapshots(SnapshotsInProgress snapshotsInProgress, St
// then, look in the repository
final Repository repository = repositoriesService.repository(repositoryName);
for (SnapshotId snapshotId : snapshotIdsToIterate) {
if (task.isCancelled()) {
throw new TaskCancelledException("task cancelled");
}
try {
snapshotSet.add(repository.getSnapshotInfo(snapshotId));
} catch (Exception ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.rest.action.RestCancellableNodeClient;

import java.io.IOException;
import java.util.Collections;
Expand Down Expand Up @@ -53,6 +54,7 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
getSnapshotsRequest.ignoreUnavailable(request.paramAsBoolean("ignore_unavailable", getSnapshotsRequest.ignoreUnavailable()));
getSnapshotsRequest.verbose(request.paramAsBoolean("verbose", getSnapshotsRequest.verbose()));
getSnapshotsRequest.masterNodeTimeout(request.paramAsTime("master_timeout", getSnapshotsRequest.masterNodeTimeout()));
return channel -> client.admin().cluster().getSnapshots(getSnapshotsRequest, new RestToXContentListener<>(channel));
return channel -> new RestCancellableNodeClient(client, request.getHttpChannel()).admin().cluster()
.getSnapshots(getSnapshotsRequest, new RestToXContentListener<>(channel));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ public static void blockMasterFromFinalizingSnapshotOnSnapFile(final String repo
}

@SuppressWarnings("unchecked")
protected static <T extends Repository> T getRepositoryOnMaster(String repositoryName) {
public static <T extends Repository> T getRepositoryOnMaster(String repositoryName) {
return ((T) internalCluster().getCurrentMasterNodeInstance(RepositoriesService.class).repository(repositoryName));
}

Expand Down

0 comments on commit 70f1e8c

Please sign in to comment.