Skip to content

Commit

Permalink
Make GetSnapshotsAction Cancellable (elastic#72644) (elastic#73820)
Browse files Browse the repository at this point in the history
If this runs needlessly for large repositories (especially in timeout/retry situations)
it's a significant memory+cpu hit => made it cancellable like we recently did for many
other endpoints.
  • Loading branch information
original-brownbear authored Jun 7, 2021
1 parent feae8e9 commit 30da196
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.http;

import org.apache.http.client.methods.HttpGet;
import org.elasticsearch.action.admin.cluster.snapshots.get.GetSnapshotsAction;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Cancellable;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseListener;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.snapshots.AbstractSnapshotIntegTestCase;
import org.elasticsearch.snapshots.SnapshotState;
import org.elasticsearch.snapshots.mockstore.MockRepository;
import org.elasticsearch.test.ESIntegTestCase;

import java.util.Collection;
import java.util.concurrent.CancellationException;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.test.TaskAssertions.assertAllCancellableTasksAreCancelled;
import static org.elasticsearch.test.TaskAssertions.assertAllTasksHaveFinished;
import static org.elasticsearch.test.TaskAssertions.awaitTaskWithPrefix;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.core.IsEqual.equalTo;

@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0, numClientNodes = 0)
public class RestGetSnapshotsCancellationIT extends HttpSmokeTestCase {

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return CollectionUtils.appendToCopy(super.nodePlugins(), MockRepository.Plugin.class);
}

public void testGetSnapshotsCancellation() throws Exception {
internalCluster().startMasterOnlyNode();
internalCluster().startDataOnlyNode();
ensureStableCluster(2);

final String repoName = "test-repo";
assertAcked(
client().admin().cluster().preparePutRepository(repoName)
.setType("mock").setSettings(Settings.builder().put("location", randomRepoPath())));

final int snapshotCount = randomIntBetween(1, 5);
for (int i = 0; i < snapshotCount; i++) {
assertEquals(
SnapshotState.SUCCESS,
client().admin().cluster().prepareCreateSnapshot(repoName, "snapshot-" + i).setWaitForCompletion(true)
.get().getSnapshotInfo().state()
);
}

final MockRepository repository = AbstractSnapshotIntegTestCase.getRepositoryOnMaster(repoName);
repository.setBlockOnAnyFiles();

final Request request = new Request(HttpGet.METHOD_NAME, "/_snapshot/" + repoName + "/*");
final PlainActionFuture<Void> future = new PlainActionFuture<>();
final Cancellable cancellable = getRestClient().performRequestAsync(request, new ResponseListener() {
@Override
public void onSuccess(Response response) {
future.onResponse(null);
}

@Override
public void onFailure(Exception exception) {
future.onFailure(exception);
}
});

assertThat(future.isDone(), equalTo(false));
awaitTaskWithPrefix(GetSnapshotsAction.NAME);
assertBusy(() -> assertTrue(repository.blocked()), 30L, TimeUnit.SECONDS);
cancellable.cancel();
assertAllCancellableTasksAreCancelled(GetSnapshotsAction.NAME);
repository.unblock();
expectThrows(CancellationException.class, future::actionGet);

assertAllTasksHaveFinished(GetSnapshotsAction.NAME);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;

import java.io.IOException;
import java.util.Map;

import static org.elasticsearch.action.ValidateActions.addValidationError;

Expand Down Expand Up @@ -160,4 +164,9 @@ public GetSnapshotsRequest verbose(boolean verbose) {
public boolean verbose() {
return verbose;
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
import org.elasticsearch.snapshots.SnapshotInfo;
import org.elasticsearch.snapshots.SnapshotMissingException;
import org.elasticsearch.snapshots.SnapshotsService;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;

Expand Down Expand Up @@ -73,11 +76,19 @@ protected ClusterBlockException checkBlock(GetSnapshotsRequest request, ClusterS
}

@Override
protected void masterOperation(final GetSnapshotsRequest request, final ClusterState state,
protected void masterOperation(GetSnapshotsRequest request, ClusterState state,
ActionListener<GetSnapshotsResponse> listener) throws Exception {
throw new UnsupportedOperationException("The task parameter is required");
}

@Override
protected void masterOperation(final Task task, final GetSnapshotsRequest request, final ClusterState state,
final ActionListener<GetSnapshotsResponse> listener) {
final String repo = request.repository();
final String[] snapshots = request.snapshots();
final SnapshotsInProgress snapshotsInProgress = state.custom(SnapshotsInProgress.TYPE, SnapshotsInProgress.EMPTY);
assert task instanceof CancellableTask : task + " not cancellable";

final Map<String, SnapshotId> allSnapshotIds = new HashMap<>();
final List<SnapshotInfo> currentSnapshots = new ArrayList<>();
for (SnapshotInfo snapshotInfo : sortedCurrentSnapshots(snapshotsInProgress, repo)) {
Expand All @@ -94,7 +105,7 @@ protected void masterOperation(final GetSnapshotsRequest request, final ClusterS
}

repositoryDataListener.whenComplete(repositoryData -> loadSnapshotInfos(snapshotsInProgress, repo, snapshots,
request.ignoreUnavailable(), request.verbose(), allSnapshotIds, currentSnapshots, repositoryData,
request.ignoreUnavailable(), request.verbose(), allSnapshotIds, currentSnapshots, repositoryData, (CancellableTask) task,
listener.map(GetSnapshotsResponse::new)), listener::onFailure);
}

Expand All @@ -120,7 +131,12 @@ private static List<SnapshotInfo> sortedCurrentSnapshots(SnapshotsInProgress sna
private void loadSnapshotInfos(SnapshotsInProgress snapshotsInProgress, String repo, String[] snapshots,
boolean ignoreUnavailable, boolean verbose, Map<String, SnapshotId> allSnapshotIds,
List<SnapshotInfo> currentSnapshots, @Nullable RepositoryData repositoryData,
ActionListener<List<SnapshotInfo>> listener) {
CancellableTask task, ActionListener<List<SnapshotInfo>> listener) {
if (task.isCancelled()) {
listener.onFailure(new TaskCancelledException("task cancelled"));
return;
}

if (repositoryData != null) {
for (SnapshotId snapshotId : repositoryData.getSnapshotIds()) {
allSnapshotIds.put(snapshotId.getName(), snapshotId);
Expand Down Expand Up @@ -156,7 +172,7 @@ private void loadSnapshotInfos(SnapshotsInProgress snapshotsInProgress, String r

if (verbose) {
threadPool.generic().execute(ActionRunnable.supply(
listener, () -> snapshots(snapshotsInProgress, repo, new ArrayList<>(toResolve), ignoreUnavailable)));
listener, () -> snapshots(snapshotsInProgress, repo, new ArrayList<>(toResolve), ignoreUnavailable, task)));
} else {
final List<SnapshotInfo> snapshotInfos;
if (repositoryData != null) {
Expand All @@ -182,7 +198,10 @@ private void loadSnapshotInfos(SnapshotsInProgress snapshotsInProgress, String r
* @return list of snapshots
*/
private List<SnapshotInfo> snapshots(SnapshotsInProgress snapshotsInProgress, String repositoryName,
List<SnapshotId> snapshotIds, boolean ignoreUnavailable) {
List<SnapshotId> snapshotIds, boolean ignoreUnavailable, CancellableTask task) {
if (task.isCancelled()) {
throw new TaskCancelledException("task cancelled");
}
final Set<SnapshotInfo> snapshotSet = new HashSet<>();
final Set<SnapshotId> snapshotIdsToIterate = new HashSet<>(snapshotIds);
// first, look at the snapshots in progress
Expand All @@ -196,6 +215,9 @@ private List<SnapshotInfo> snapshots(SnapshotsInProgress snapshotsInProgress, St
// then, look in the repository
final Repository repository = repositoriesService.repository(repositoryName);
for (SnapshotId snapshotId : snapshotIdsToIterate) {
if (task.isCancelled()) {
throw new TaskCancelledException("task cancelled");
}
try {
snapshotSet.add(repository.getSnapshotInfo(snapshotId));
} catch (Exception ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.rest.action.RestCancellableNodeClient;

import java.io.IOException;
import java.util.Collections;
Expand Down Expand Up @@ -54,6 +55,7 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
getSnapshotsRequest.ignoreUnavailable(request.paramAsBoolean("ignore_unavailable", getSnapshotsRequest.ignoreUnavailable()));
getSnapshotsRequest.verbose(request.paramAsBoolean("verbose", getSnapshotsRequest.verbose()));
getSnapshotsRequest.masterNodeTimeout(request.paramAsTime("master_timeout", getSnapshotsRequest.masterNodeTimeout()));
return channel -> client.admin().cluster().getSnapshots(getSnapshotsRequest, new RestToXContentListener<>(channel));
return channel -> new RestCancellableNodeClient(client, request.getHttpChannel()).admin().cluster()
.getSnapshots(getSnapshotsRequest, new RestToXContentListener<>(channel));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ public static void blockMasterFromFinalizingSnapshotOnSnapFile(final String repo
}

@SuppressWarnings("unchecked")
protected static <T extends Repository> T getRepositoryOnMaster(String repositoryName) {
public static <T extends Repository> T getRepositoryOnMaster(String repositoryName) {
return ((T) internalCluster().getCurrentMasterNodeInstance(RepositoriesService.class).repository(repositoryName));
}

Expand Down

0 comments on commit 30da196

Please sign in to comment.