diff --git a/docs/reference/modules/threadpool.asciidoc b/docs/reference/modules/threadpool.asciidoc
index 48280029311a0..0306d2dbb17e7 100644
--- a/docs/reference/modules/threadpool.asciidoc
+++ b/docs/reference/modules/threadpool.asciidoc
@@ -23,6 +23,11 @@ There are several thread pools, but the important ones include:
     Thread pool type is `fixed_auto_queue_size` with a size of `1`, and initial
     queue_size of `100`.
 
+`search_coordination`::
+    For lightweight search-related coordination operations. Thread pool type is
+    `fixed` with a max size of `min(5, (`<<node.processors,
+`# of allocated processors`>>`) / 2)`, and queue_size of `1000`.
+
 `get`::
     For get operations. Thread pool type is `fixed`
     with a size of <<node.processors, `# of allocated processors`>>,
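For illustration of the `search_coordination` default above: a node with 8 allocated processors gets min(5, 8 / 2) = 4 coordination threads, and any node with 10 or more allocated processors is capped at 5.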
diff --git a/qa/ccs-rolling-upgrade-remote-cluster/src/test/java/org/elasticsearch/upgrades/SearchStatesIT.java b/qa/ccs-rolling-upgrade-remote-cluster/src/test/java/org/elasticsearch/upgrades/SearchStatesIT.java
index 9ea49f35b20d5..30c0706d5e621 100644
--- a/qa/ccs-rolling-upgrade-remote-cluster/src/test/java/org/elasticsearch/upgrades/SearchStatesIT.java
+++ b/qa/ccs-rolling-upgrade-remote-cluster/src/test/java/org/elasticsearch/upgrades/SearchStatesIT.java
@@ -89,7 +89,7 @@ static int indexDocs(RestHighLevelClient client, String index, int numDocs) thro
         return numDocs;
     }
 
-    void verifySearch(String localIndex, int localNumDocs, String remoteIndex, int remoteNumDocs) {
+    void verifySearch(String localIndex, int localNumDocs, String remoteIndex, int remoteNumDocs, Integer preFilterShardSize) {
         try (RestHighLevelClient localClient = newLocalClient(LOGGER)) {
             Request request = new Request("POST", "/_search");
             final int expectedDocs;
@@ -103,6 +103,12 @@ void verifySearch(String localIndex, int localNumDocs, String remoteIndex, int r
             if (UPGRADE_FROM_VERSION.onOrAfter(Version.V_7_0_0)) {
                 request.addParameter("ccs_minimize_roundtrips", Boolean.toString(randomBoolean()));
             }
+            if (preFilterShardSize == null && randomBoolean()) {
+                preFilterShardSize = randomIntBetween(1, 100);
+            }
+            if (preFilterShardSize != null) {
+                request.addParameter("pre_filter_shard_size", Integer.toString(preFilterShardSize));
+            }
             int size = between(1, 100);
             request.setJsonEntity("{\"sort\": \"f\", \"size\": " + size + "}");
             Response response = localClient.getLowLevelClient().performRequest(request);
@@ -142,7 +148,32 @@ public void testBWCSearchStates() throws Exception {
             configureRemoteClusters(remoteNodes, CLUSTER_ALIAS, UPGRADE_FROM_VERSION, LOGGER);
             int iterations = between(1, 20);
             for (int i = 0; i < iterations; i++) {
-                verifySearch(localIndex, localNumDocs, CLUSTER_ALIAS + ":" + remoteIndex, remoteNumDocs);
+                verifySearch(localIndex, localNumDocs, CLUSTER_ALIAS + ":" + remoteIndex, remoteNumDocs, null);
+            }
+            localClient.indices().delete(new DeleteIndexRequest(localIndex), RequestOptions.DEFAULT);
+            remoteClient.indices().delete(new DeleteIndexRequest(remoteIndex), RequestOptions.DEFAULT);
+        }
+    }
+
+    public void testCanMatch() throws Exception {
+        String localIndex = "test_can_match_local_index";
+        String remoteIndex = "test_can_match_remote_index";
+        try (RestHighLevelClient localClient = newLocalClient(LOGGER);
+             RestHighLevelClient remoteClient = newRemoteClient()) {
+            localClient.indices().create(new CreateIndexRequest(localIndex)
+                    .settings(Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(5, 20))),
+                RequestOptions.DEFAULT);
+            int localNumDocs = indexDocs(localClient, localIndex, between(10, 100));
+
+            remoteClient.indices().create(new CreateIndexRequest(remoteIndex)
+                    .settings(Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(5, 20))),
+                RequestOptions.DEFAULT);
+            int remoteNumDocs = indexDocs(remoteClient, remoteIndex, between(10, 100));
+
+            configureRemoteClusters(getNodes(remoteClient.getLowLevelClient()), CLUSTER_ALIAS, UPGRADE_FROM_VERSION, LOGGER);
+            int iterations = between(1, 10);
+            for (int i = 0; i < iterations; i++) {
+                verifySearch(localIndex, localNumDocs, CLUSTER_ALIAS + ":" + remoteIndex, remoteNumDocs, between(1, 10));
             }
             localClient.indices().delete(new DeleteIndexRequest(localIndex), RequestOptions.DEFAULT);
             remoteClient.indices().delete(new DeleteIndexRequest(remoteIndex), RequestOptions.DEFAULT);
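For readers unfamiliar with the parameter exercised above: `pre_filter_shard_size` controls when the can-match pre-filter phase runs. A minimal standalone sketch of setting it through the low-level REST client (host, port and index names are assumptions, not taken from this test):

    import org.apache.http.HttpHost;
    import org.elasticsearch.client.Request;
    import org.elasticsearch.client.Response;
    import org.elasticsearch.client.RestClient;

    public class PreFilterShardSizeExample {
        public static void main(String[] args) throws Exception {
            try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build()) {
                // Lowering pre_filter_shard_size to 1 makes the can-match phase run even
                // when the search targets only a handful of shards.
                Request request = new Request("POST", "/local_index,my_remote_cluster:remote_index/_search");
                request.addParameter("pre_filter_shard_size", "1");
                request.setJsonEntity("{\"sort\": \"f\", \"size\": 10}");
                Response response = client.performRequest(request);
                System.out.println(response.getStatusLine());
            }
        }
    }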
diff --git a/qa/multi-cluster-search/src/test/java/org/elasticsearch/search/CCSDuelIT.java b/qa/multi-cluster-search/src/test/java/org/elasticsearch/search/CCSDuelIT.java
index c2b9dbdb44462..3eb698d3d0c0e 100644
--- a/qa/multi-cluster-search/src/test/java/org/elasticsearch/search/CCSDuelIT.java
+++ b/qa/multi-cluster-search/src/test/java/org/elasticsearch/search/CCSDuelIT.java
@@ -724,7 +724,11 @@ private static void assumeMultiClusterSetup() {
     private static SearchRequest initSearchRequest() {
         List<String> indices = Arrays.asList(INDEX_NAME, "my_remote_cluster:" + INDEX_NAME);
         Collections.shuffle(indices, random());
-        return new SearchRequest(indices.toArray(new String[0]));
+        final SearchRequest request = new SearchRequest(indices.toArray(new String[0]));
+        if (randomBoolean()) {
+            request.setPreFilterShardSize(between(1, 20));
+        }
+        return request;
     }
 
     private static void duelSearch(SearchRequest searchRequest, Consumer<SearchResponse> responseChecker) throws Exception {
diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java
index 3d785c1baf2f3..0f3e0296f96e1 100644
--- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java
+++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java
@@ -23,12 +23,10 @@
 import org.elasticsearch.action.support.TransportActions;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.routing.GroupShardsIterator;
-import org.elasticsearch.core.Releasable;
-import org.elasticsearch.core.Releasables;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
-import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.index.seqno.SequenceNumbers;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.search.SearchContextMissingException;
 import org.elasticsearch.search.SearchPhaseResult;
@@ -65,7 +63,6 @@
  */
 abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> extends SearchPhase implements SearchPhaseContext {
     private static final float DEFAULT_INDEX_BOOST = 1.0f;
-    private static final long[] EMPTY_LONG_ARRAY = new long[0];
     private final Logger logger;
     private final SearchTransportService searchTransportService;
     private final Executor executor;
@@ -736,21 +733,9 @@ public final ShardSearchRequest buildShardSearchRequest(SearchShardIterator shar
         AliasFilter filter = aliasFilter.get(shardIt.shardId().getIndex().getUUID());
         assert filter != null;
         float indexBoost = concreteIndexBoosts.getOrDefault(shardIt.shardId().getIndex().getUUID(), DEFAULT_INDEX_BOOST);
-        final Map<String, long[]> indexToWaitForCheckpoints = request.getWaitForCheckpoints();
-        final TimeValue waitForCheckpointsTimeout = request.getWaitForCheckpointsTimeout();
-        final long[] waitForCheckpoints = indexToWaitForCheckpoints.getOrDefault(shardIt.shardId().getIndex().getName(), EMPTY_LONG_ARRAY);
-
-        long waitForCheckpoint;
-        if (waitForCheckpoints.length == 0) {
-            waitForCheckpoint = SequenceNumbers.UNASSIGNED_SEQ_NO;
-        } else {
-            assert waitForCheckpoints.length > shardIndex;
-            waitForCheckpoint = waitForCheckpoints[shardIndex];
-        }
         ShardSearchRequest shardRequest = new ShardSearchRequest(shardIt.getOriginalIndices(), request,
             shardIt.shardId(), shardIndex, getNumShards(), filter, indexBoost, timeProvider.getAbsoluteStartMillis(),
-            shardIt.getClusterAlias(), shardIt.getSearchContextId(), shardIt.getSearchContextKeepAlive(), waitForCheckpoint,
-            waitForCheckpointsTimeout);
+            shardIt.getClusterAlias(), shardIt.getSearchContextId(), shardIt.getSearchContextKeepAlive());
         // if we already received a search result we can inform the shard that it
         // can return a null response if the request rewrites to match none rather
         // than creating an empty response in the search thread pool.
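The wait-for-checkpoint resolution deleted above is referenced later in this change as `ShardSearchRequest.computeWaitForCheckpoint(...)`. A sketch of that helper, reconstructed from the removed lines (signature inferred from the call site; the actual implementation may differ in detail):

    // Reconstructed from the logic removed above. Returns the sequence number this
    // shard-level request should wait for, or UNASSIGNED_SEQ_NO if none was requested.
    public static long computeWaitForCheckpoint(Map<String, long[]> indexToWaitForCheckpoints, ShardId shardId, int shardIndex) {
        long[] waitForCheckpoints = indexToWaitForCheckpoints.getOrDefault(shardId.getIndexName(), new long[0]);
        if (waitForCheckpoints.length == 0) {
            return SequenceNumbers.UNASSIGNED_SEQ_NO;
        }
        assert waitForCheckpoints.length > shardIndex;
        return waitForCheckpoints[shardIndex];
    }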
diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchNodeRequest.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchNodeRequest.java
new file mode 100644
index 0000000000000..f89acaec64c7d
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchNodeRequest.java
@@ -0,0 +1,223 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.search;
+
+import org.elasticsearch.action.IndicesRequest;
+import org.elasticsearch.action.OriginalIndices;
+import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.search.Scroll;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.search.internal.AliasFilter;
+import org.elasticsearch.search.internal.ShardSearchContextId;
+import org.elasticsearch.search.internal.ShardSearchRequest;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskId;
+import org.elasticsearch.transport.TransportRequest;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * Node-level request used during can-match phase
+ */
+public class CanMatchNodeRequest extends TransportRequest implements IndicesRequest {
+
+    private final SearchSourceBuilder source;
+    private final List<Shard> shards;
+    private final SearchType searchType;
+    private final String[] types;
+    private final Boolean requestCache;
+    private final boolean allowPartialSearchResults;
+    private final Scroll scroll;
+    private final int numberOfShards;
+    private final long nowInMillis;
+    @Nullable
+    private final String clusterAlias;
+    private final String[] indices;
+    private final IndicesOptions indicesOptions;
+    private final TimeValue waitForCheckpointsTimeout;
+
+    public static class Shard implements Writeable {
+        private final String[] indices;
+        private final ShardId shardId;
+        private final int shardRequestIndex;
+        private final AliasFilter aliasFilter;
+        private final float indexBoost;
+        private final ShardSearchContextId readerId;
+        private final TimeValue keepAlive;
+        private final long waitForCheckpoint;
+
+        public Shard(String[] indices,
+                     ShardId shardId,
+                     int shardRequestIndex,
+                     AliasFilter aliasFilter,
+                     float indexBoost,
+                     ShardSearchContextId readerId,
+                     TimeValue keepAlive,
+                     long waitForCheckpoint) {
+            this.indices = indices;
+            this.shardId = shardId;
+            this.shardRequestIndex = shardRequestIndex;
+            this.aliasFilter = aliasFilter;
+            this.indexBoost = indexBoost;
+            this.readerId = readerId;
+            this.keepAlive = keepAlive;
+            this.waitForCheckpoint = waitForCheckpoint;
+            assert keepAlive == null || readerId != null : "readerId: " + readerId + " keepAlive: " + keepAlive;
+        }
+
+        public Shard(StreamInput in) throws IOException {
+            indices = in.readStringArray();
+            shardId = new ShardId(in);
+            shardRequestIndex = in.readVInt();
+            aliasFilter = new AliasFilter(in);
+            indexBoost = in.readFloat();
+            readerId = in.readOptionalWriteable(ShardSearchContextId::new);
+            keepAlive = in.readOptionalTimeValue();
+            waitForCheckpoint = in.readLong();
+            assert keepAlive == null || readerId != null : "readerId: " + readerId + " keepAlive: " + keepAlive;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeStringArray(indices);
+            shardId.writeTo(out);
+            out.writeVInt(shardRequestIndex);
+            aliasFilter.writeTo(out);
+            out.writeFloat(indexBoost);
+            out.writeOptionalWriteable(readerId);
+            out.writeOptionalTimeValue(keepAlive);
+            out.writeLong(waitForCheckpoint);
+        }
+
+        public int getShardRequestIndex() {
+            return shardRequestIndex;
+        }
+
+        public String[] getOriginalIndices() {
+            return indices;
+        }
+
+        public ShardId shardId() {
+            return shardId;
+        }
+    }
+
+    public CanMatchNodeRequest(
+        SearchRequest searchRequest,
+        IndicesOptions indicesOptions,
+        List<Shard> shards,
+        int numberOfShards,
+        long nowInMillis,
+        @Nullable String clusterAlias
+        ) {
+        this.source = searchRequest.source();
+        this.indicesOptions = indicesOptions;
+        this.shards = new ArrayList<>(shards);
+        this.searchType = searchRequest.searchType();
+        this.types = searchRequest.types();
+        this.requestCache = searchRequest.requestCache();
+        // If allowPartialSearchResults is unset (i.e. null), the cluster-level default should have been substituted
+        // at this stage. An NPE when unboxing it below therefore indicates an error in request preparation logic.
+        assert searchRequest.allowPartialSearchResults() != null;
+        this.allowPartialSearchResults = searchRequest.allowPartialSearchResults();
+        this.scroll = searchRequest.scroll();
+        this.numberOfShards = numberOfShards;
+        this.nowInMillis = nowInMillis;
+        this.clusterAlias = clusterAlias;
+        this.waitForCheckpointsTimeout = searchRequest.getWaitForCheckpointsTimeout();
+        indices = shards.stream().map(Shard::getOriginalIndices).flatMap(Arrays::stream).distinct()
+            .toArray(String[]::new);
+    }
+
+    public CanMatchNodeRequest(StreamInput in) throws IOException {
+        super(in);
+        source = in.readOptionalWriteable(SearchSourceBuilder::new);
+        indicesOptions = IndicesOptions.readIndicesOptions(in);
+        searchType = SearchType.fromId(in.readByte());
+        types = in.readStringArray();
+        scroll = in.readOptionalWriteable(Scroll::new);
+        requestCache = in.readOptionalBoolean();
+        allowPartialSearchResults = in.readBoolean();
+        numberOfShards = in.readVInt();
+        nowInMillis = in.readVLong();
+        clusterAlias = in.readOptionalString();
+        waitForCheckpointsTimeout = in.readTimeValue();
+        shards = in.readList(Shard::new);
+        indices = shards.stream().map(Shard::getOriginalIndices).flatMap(Arrays::stream).distinct()
+            .toArray(String[]::new);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+        out.writeOptionalWriteable(source);
+        indicesOptions.writeIndicesOptions(out);
+        out.writeByte(searchType.id());
+        out.writeStringArray(types);
+        out.writeOptionalWriteable(scroll);
+        out.writeOptionalBoolean(requestCache);
+        out.writeBoolean(allowPartialSearchResults);
+        out.writeVInt(numberOfShards);
+        out.writeVLong(nowInMillis);
+        out.writeOptionalString(clusterAlias);
+        out.writeTimeValue(waitForCheckpointsTimeout);
+        out.writeList(shards);
+    }
+
+    public List<Shard> getShardLevelRequests() {
+        return shards;
+    }
+
+    public List<ShardSearchRequest> createShardSearchRequests() {
+        return shards.stream().map(this::createShardSearchRequest).collect(Collectors.toList());
+    }
+
+    public ShardSearchRequest createShardSearchRequest(Shard r) {
+        ShardSearchRequest shardSearchRequest = new ShardSearchRequest(
+            new OriginalIndices(r.indices, indicesOptions), r.shardId, r.shardRequestIndex, numberOfShards, searchType,
+            source, types, requestCache, r.aliasFilter, r.indexBoost, allowPartialSearchResults, scroll,
+            nowInMillis, clusterAlias, r.readerId, r.keepAlive, r.waitForCheckpoint, waitForCheckpointsTimeout
+        );
+        shardSearchRequest.setParentTask(getParentTask());
+        return shardSearchRequest;
+    }
+
+    @Override
+    public String[] indices() {
+        return indices;
+    }
+
+    @Override
+    public IndicesOptions indicesOptions() {
+        return indicesOptions;
+    }
+
+    @Override
+    public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
+        return new SearchShardTask(id, type, action, getDescription(), parentTaskId, headers);
+    }
+
+    @Override
+    public String getDescription() {
+        // Shard id is enough here, the request itself can be found by looking at the parent task description
+        return "shardIds[" + shards.stream().map(slr -> slr.shardId).collect(Collectors.toList()) + "]";
+    }
+
+}
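A rough usage sketch of the new request type (imports omitted; index names, UUIDs and option values are made up for illustration, the real callers are the can-match phase and the transport handler elsewhere in this change):

    SearchRequest searchRequest = new SearchRequest("logs-*");
    searchRequest.allowPartialSearchResults(true); // the constructor asserts this is non-null

    // Two shards of the same index, both hosted on the node this request will be sent to.
    List<CanMatchNodeRequest.Shard> shards = Arrays.asList(
        new CanMatchNodeRequest.Shard(new String[] { "logs-*" }, new ShardId("logs-2021", "index-uuid", 0),
            0, new AliasFilter(null, Strings.EMPTY_ARRAY), 1.0f, null, null, SequenceNumbers.UNASSIGNED_SEQ_NO),
        new CanMatchNodeRequest.Shard(new String[] { "logs-*" }, new ShardId("logs-2021", "index-uuid", 1),
            1, new AliasFilter(null, Strings.EMPTY_ARRAY), 1.0f, null, null, SequenceNumbers.UNASSIGNED_SEQ_NO));

    CanMatchNodeRequest nodeRequest = new CanMatchNodeRequest(searchRequest,
        IndicesOptions.strictExpandOpenAndForbidClosed(), shards,
        /* numberOfShards */ 2, System.currentTimeMillis(), /* clusterAlias */ null);

    // On the receiving node, full per-shard requests are rebuilt from the shared
    // top-level state plus the compact per-shard entries.
    List<ShardSearchRequest> perShard = nodeRequest.createShardSearchRequests();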
diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchNodeResponse.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchNodeResponse.java
new file mode 100644
index 0000000000000..05aaaa56583ed
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchNodeResponse.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.search;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.search.CanMatchShardResponse;
+import org.elasticsearch.transport.TransportResponse;
+
+import java.io.IOException;
+import java.util.List;
+
+public class CanMatchNodeResponse extends TransportResponse {
+
+    private final List<ResponseOrFailure> responses;
+
+    public CanMatchNodeResponse(StreamInput in) throws IOException {
+        super(in);
+        responses = in.readList(ResponseOrFailure::new);
+    }
+
+    public CanMatchNodeResponse(List<ResponseOrFailure> responses) {
+        this.responses = responses;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeList(responses);
+    }
+
+    public List<ResponseOrFailure> getResponses() {
+        return responses;
+    }
+
+    public static class ResponseOrFailure implements Writeable {
+
+        public ResponseOrFailure(CanMatchShardResponse response) {
+            this.response = response;
+            this.exception = null;
+        }
+
+        public ResponseOrFailure(Exception exception) {
+            this.exception = exception;
+            this.response = null;
+        }
+
+        @Nullable
+        public CanMatchShardResponse getResponse() {
+            return response;
+        }
+
+        @Nullable
+        public Exception getException() {
+            return exception;
+        }
+
+        private final CanMatchShardResponse response;
+        private final Exception exception;
+
+        public ResponseOrFailure(StreamInput in) throws IOException {
+            if (in.readBoolean()) {
+                response = new CanMatchShardResponse(in);
+                exception = null;
+            } else {
+                exception = in.readException();
+                response = null;
+            }
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            final boolean hasResponse = response != null;
+            out.writeBoolean(hasResponse);
+            if (hasResponse) {
+                response.writeTo(out);
+            } else {
+                out.writeException(exception);
+            }
+        }
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java
index 97726626abe55..06ff76db421d9 100644
--- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java
+++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java
@@ -5,31 +5,45 @@
  * in compliance with, at your election, the Elastic License 2.0 or the Server
  * Side Public License, v 1.
  */
+
 package org.elasticsearch.action.search;
 
 import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.apache.lucene.util.CollectionUtil;
 import org.apache.lucene.util.FixedBitSet;
+import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.routing.GroupShardsIterator;
-import org.elasticsearch.core.Releasable;
+import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.common.util.concurrent.CountDown;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.index.query.CoordinatorRewriteContext;
 import org.elasticsearch.index.query.CoordinatorRewriteContextProvider;
+import org.elasticsearch.search.CanMatchShardResponse;
 import org.elasticsearch.search.SearchService;
-import org.elasticsearch.search.SearchService.CanMatchResponse;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.internal.AliasFilter;
+import org.elasticsearch.search.internal.InternalSearchResponse;
+import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.internal.ShardSearchRequest;
 import org.elasticsearch.search.sort.FieldSortBuilder;
 import org.elasticsearch.search.sort.MinAndMax;
 import org.elasticsearch.search.sort.SortOrder;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.Transport;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.concurrent.Executor;
+import java.util.concurrent.atomic.AtomicReferenceArray;
 import java.util.function.BiFunction;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -49,152 +63,409 @@
  * sort them according to the provided order. This can be useful for instance to ensure that shards that contain recent
  * data are executed first when sorting by descending timestamp.
  */
-final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction<CanMatchResponse> {
+final class CanMatchPreFilterSearchPhase extends SearchPhase {
 
-    private final Function<GroupShardsIterator<SearchShardIterator>, SearchPhase> phaseFactory;
+    private final Logger logger;
+    private final SearchRequest request;
     private final GroupShardsIterator<SearchShardIterator> shardsIts;
+    private final ActionListener<SearchResponse> listener;
+    private final SearchResponse.Clusters clusters;
+    private final TransportSearchAction.SearchTimeProvider timeProvider;
+    private final BiFunction<String, String, Transport.Connection> nodeIdToConnection;
+    private final SearchTransportService searchTransportService;
+    private final Map<SearchShardIterator, Integer> shardItIndexMap;
+    private final Map<String, Float> concreteIndexBoosts;
+    private final Map<String, AliasFilter> aliasFilter;
+    private final SearchTask task;
+    private final Function<GroupShardsIterator<SearchShardIterator>, SearchPhase> phaseFactory;
+    private final Executor executor;
+
+    private final CanMatchSearchPhaseResults results;
     private final CoordinatorRewriteContextProvider coordinatorRewriteContextProvider;
 
+
     CanMatchPreFilterSearchPhase(Logger logger, SearchTransportService searchTransportService,
-                                 BiFunction<String, String, Transport.Connection> nodeIdToConnection,
-                                 Map<String, AliasFilter> aliasFilter, Map<String, Float> concreteIndexBoosts,
-                                 Executor executor, SearchRequest request,
-                                 ActionListener<SearchResponse> listener, GroupShardsIterator<SearchShardIterator> shardsIts,
-                                 TransportSearchAction.SearchTimeProvider timeProvider, ClusterState clusterState,
-                                 SearchTask task, Function<GroupShardsIterator<SearchShardIterator>, SearchPhase> phaseFactory,
-                                 SearchResponse.Clusters clusters, CoordinatorRewriteContextProvider coordinatorRewriteContextProvider) {
-        //We set max concurrent shard requests to the number of shards so no throttling happens for can_match requests
-        super("can_match", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts,
-                executor, request, listener, shardsIts, timeProvider, clusterState, task,
-                new CanMatchSearchPhaseResults(shardsIts.size()), shardsIts.size(), clusters);
-        this.phaseFactory = phaseFactory;
+                                        BiFunction<String, String, Transport.Connection> nodeIdToConnection,
+                                        Map<String, AliasFilter> aliasFilter, Map<String, Float> concreteIndexBoosts,
+                                        Executor executor, SearchRequest request,
+                                        ActionListener<SearchResponse> listener, GroupShardsIterator<SearchShardIterator> shardsIts,
+                                        TransportSearchAction.SearchTimeProvider timeProvider,
+                                        SearchTask task, Function<GroupShardsIterator<SearchShardIterator>, SearchPhase> phaseFactory,
+                                        SearchResponse.Clusters clusters,
+                                        CoordinatorRewriteContextProvider coordinatorRewriteContextProvider) {
+        super("can_match");
+        this.logger = logger;
+        this.searchTransportService = searchTransportService;
+        this.nodeIdToConnection = nodeIdToConnection;
+        this.request = request;
+        this.listener = listener;
         this.shardsIts = shardsIts;
+        this.clusters = clusters;
+        this.timeProvider = timeProvider;
+        this.concreteIndexBoosts = concreteIndexBoosts;
+        this.aliasFilter = aliasFilter;
+        this.task = task;
+        this.phaseFactory = phaseFactory;
         this.coordinatorRewriteContextProvider = coordinatorRewriteContextProvider;
+        this.executor = executor;
+        this.shardItIndexMap = new HashMap<>();
+        results = new CanMatchSearchPhaseResults(shardsIts.size());
+
+        // we compute the shard index based on the natural order of the shards
+        // that participate in the search request. This means that this number is
+        // consistent between two requests that target the same shards.
+        List<SearchShardIterator> naturalOrder = new ArrayList<>();
+        shardsIts.iterator().forEachRemaining(naturalOrder::add);
+        CollectionUtil.timSort(naturalOrder);
+        for (int i = 0; i < naturalOrder.size(); i++) {
+            shardItIndexMap.put(naturalOrder.get(i), i);
+        }
     }
 
-    @Override
-    public void addReleasable(Releasable releasable) {
-        throw new RuntimeException("cannot add releasable in " + getName() + " phase");
+    private static boolean assertSearchCoordinationThread() {
+        assert Thread.currentThread().getName().contains(ThreadPool.Names.SEARCH_COORDINATION) :
+                "not called from the right thread " + Thread.currentThread().getName();
+        return true;
     }
 
     @Override
-    protected void executePhaseOnShard(SearchShardIterator shardIt, SearchShardTarget shard,
-                                       SearchActionListener<CanMatchResponse> listener) {
-        getSearchTransport().sendCanMatch(getConnection(shard.getClusterAlias(), shard.getNodeId()),
-            buildShardSearchRequest(shardIt, listener.requestIndex), getTask(), listener);
+    public void run() throws IOException {
+        assert assertSearchCoordinationThread();
+        checkNoMissingShards();
+        Version version = request.minCompatibleShardNode();
+        if (version != null && Version.CURRENT.minimumCompatibilityVersion().equals(version) == false) {
+            if (checkMinimumVersion(shardsIts) == false) {
+                throw new VersionMismatchException("One of the shards is incompatible with the required minimum version [{}]",
+                    request.minCompatibleShardNode());
+            }
+        }
+
+        runCoordinatorRewritePhase();
     }
 
-    @Override
-    protected SearchPhase getNextPhase(SearchPhaseResults<CanMatchResponse> results, SearchPhaseContext context) {
+    // tries to pre-filter shards based on information that's available to the coordinator
+    // without having to reach out to the actual shards
+    private void runCoordinatorRewritePhase() {
+        assert assertSearchCoordinationThread();
+        final List<SearchShardIterator> matchedShardLevelRequests = new ArrayList<>();
+        for (SearchShardIterator searchShardIterator : shardsIts) {
+            final CanMatchNodeRequest canMatchNodeRequest =
+                new CanMatchNodeRequest(request, searchShardIterator.getOriginalIndices().indicesOptions(),
+                Collections.emptyList(), getNumShards(), timeProvider.getAbsoluteStartMillis(), searchShardIterator.getClusterAlias());
+            final ShardSearchRequest request = canMatchNodeRequest.createShardSearchRequest(buildShardLevelRequest(searchShardIterator));
+            boolean canMatch = true;
+            CoordinatorRewriteContext coordinatorRewriteContext =
+                coordinatorRewriteContextProvider.getCoordinatorRewriteContext(request.shardId().getIndex());
+            if (coordinatorRewriteContext != null) {
+                try {
+                    canMatch = SearchService.queryStillMatchesAfterRewrite(request, coordinatorRewriteContext);
+                } catch (Exception e) {
+                    // treat as if shard is still a potential match
+                }
+            }
+            if (canMatch) {
+                matchedShardLevelRequests.add(searchShardIterator);
+            } else {
+                CanMatchShardResponse result = new CanMatchShardResponse(canMatch, null);
+                result.setShardIndex(request.shardRequestIndex());
+                results.consumeResult(result, () -> {});
+            }
+        }
 
-        return phaseFactory.apply(getIterator((CanMatchSearchPhaseResults) results, shardsIts));
+        if (matchedShardLevelRequests.isEmpty() == false) {
+            new Round(new GroupShardsIterator<>(matchedShardLevelRequests)).run();
+        } else {
+            finishPhase();
+        }
     }
 
-    private GroupShardsIterator<SearchShardIterator> getIterator(CanMatchSearchPhaseResults results,
-                                                                 GroupShardsIterator<SearchShardIterator> shardsIts) {
-        int cardinality = results.getNumPossibleMatches();
-        FixedBitSet possibleMatches = results.getPossibleMatches();
-        if (cardinality == 0) {
-            // this is a special case where we have no hit but we need to get at least one search response in order
-            // to produce a valid search result with all the aggs etc.
-            // Since it's possible that some of the shards that we're skipping are
-            // unavailable, we would try to query the node that at least has some
-            // shards available in order to produce a valid search result.
-            int shardIndexToQuery = 0;
-            for (int i = 0; i < shardsIts.size(); i++) {
-                if (shardsIts.get(i).size() > 0) {
-                    shardIndexToQuery = i;
-                    break;
+    private void checkNoMissingShards() {
+        assert assertSearchCoordinationThread();
+        assert request.allowPartialSearchResults() != null : "SearchRequest missing setting for allowPartialSearchResults";
+        if (request.allowPartialSearchResults() == false) {
+            final StringBuilder missingShards = new StringBuilder();
+            // Fail-fast verification of all shards being available
+            for (int index = 0; index < shardsIts.size(); index++) {
+                final SearchShardIterator shardRoutings = shardsIts.get(index);
+                if (shardRoutings.size() == 0) {
+                    if (missingShards.length() > 0) {
+                        missingShards.append(", ");
+                    }
+                    missingShards.append(shardRoutings.shardId());
                 }
             }
-            possibleMatches.set(shardIndexToQuery);
+            if (missingShards.length() > 0) {
+                //Status red - shard is missing all copies and would produce partial results for an index search
+                final String msg = "Search rejected due to missing shards ["+ missingShards +
+                    "]. Consider using `allow_partial_search_results` setting to bypass this error.";
+                throw new SearchPhaseExecutionException(getName(), msg, null, ShardSearchFailure.EMPTY_ARRAY);
+            }
         }
-        SearchSourceBuilder source = getRequest().source();
-        int i = 0;
-        for (SearchShardIterator iter : shardsIts) {
-            if (possibleMatches.get(i++)) {
-                iter.reset();
+    }
+
+    private Map<SendingTarget, List<SearchShardIterator>> groupByNode(GroupShardsIterator<SearchShardIterator> shards) {
+        Map<SendingTarget, List<SearchShardIterator>> requests = new HashMap<>();
+        for (int i = 0; i < shards.size(); i++) {
+            final SearchShardIterator shardRoutings = shards.get(i);
+            assert shardRoutings.skip() == false;
+            assert shardItIndexMap.containsKey(shardRoutings);
+            SearchShardTarget target = shardRoutings.nextOrNull();
+            if (target != null) {
+                requests.computeIfAbsent(new SendingTarget(target.getClusterAlias(), target.getNodeId()),
+                    t -> new ArrayList<>()).add(shardRoutings);
             } else {
-                iter.resetAndSkip();
+                requests.computeIfAbsent(new SendingTarget(null, null),
+                    t -> new ArrayList<>()).add(shardRoutings);
             }
         }
-        if (shouldSortShards(results.minAndMaxes) == false) {
-            return shardsIts;
-        }
-        FieldSortBuilder fieldSort = FieldSortBuilder.getPrimaryFieldSortOrNull(source);
-        return new GroupShardsIterator<>(sortShards(shardsIts, results.minAndMaxes, fieldSort.order()));
+        return requests;
     }
 
-    @Override
-    protected void performPhaseOnShard(int shardIndex, SearchShardIterator shardIt, SearchShardTarget shard) {
-        CoordinatorRewriteContext coordinatorRewriteContext =
-            coordinatorRewriteContextProvider.getCoordinatorRewriteContext(shardIt.shardId().getIndex());
+    /**
+     * Sending can-match requests is round-based and grouped per target node.
+     * If there are failures during a round, there will be a follow-up round
+     * to retry on other available shard copies.
+     */
+    class Round extends AbstractRunnable {
+        private final GroupShardsIterator<SearchShardIterator> shards;
+        private final CountDown countDown;
+        private final AtomicReferenceArray<Exception> failedResponses;
 
-        if (coordinatorRewriteContext == null) {
-            super.performPhaseOnShard(shardIndex, shardIt, shard);
-            return;
+        Round(GroupShardsIterator<SearchShardIterator> shards) {
+            this.shards = shards;
+            this.countDown = new CountDown(shards.size());
+            this.failedResponses = new AtomicReferenceArray<>(shardsIts.size());
         }
 
-        try {
-            ShardSearchRequest request = buildShardSearchRequest(shardIt, shardIndex);
-            boolean canMatch = SearchService.queryStillMatchesAfterRewrite(request, coordinatorRewriteContext);
+        @Override
+        protected void doRun() {
+            assert assertSearchCoordinationThread();
+            final Map<SendingTarget, List<SearchShardIterator>> requests = groupByNode(shards);
 
-            // Trigger the query as there's still a chance that we can skip
-            // this shard given other query filters that we cannot apply
-            // in the coordinator
-            if (canMatch) {
-                super.performPhaseOnShard(shardIndex, shardIt, shard);
-                return;
+            for (Map.Entry<SendingTarget, List<SearchShardIterator>> entry : requests.entrySet()) {
+                CanMatchNodeRequest canMatchNodeRequest = createCanMatchRequest(entry);
+                List<CanMatchNodeRequest.Shard> shardLevelRequests = canMatchNodeRequest.getShardLevelRequests();
+
+                if (entry.getKey().nodeId == null) {
+                    // no target node: just mark the requests as failed
+                    for (CanMatchNodeRequest.Shard shard : shardLevelRequests) {
+                        onOperationFailed(shard.getShardRequestIndex(), null);
+                    }
+                    continue;
+                }
+
+                try {
+                    searchTransportService.sendCanMatch(getConnection(entry.getKey()), canMatchNodeRequest,
+                        task, new ActionListener<CanMatchNodeResponse>() {
+                            @Override
+                            public void onResponse(CanMatchNodeResponse canMatchNodeResponse) {
+                                assert canMatchNodeResponse.getResponses().size() == canMatchNodeRequest.getShardLevelRequests().size();
+                                for (int i = 0; i < canMatchNodeResponse.getResponses().size(); i++) {
+                                    CanMatchNodeResponse.ResponseOrFailure response = canMatchNodeResponse.getResponses().get(i);
+                                    if (response.getResponse() != null) {
+                                        CanMatchShardResponse shardResponse = response.getResponse();
+                                        shardResponse.setShardIndex(shardLevelRequests.get(i).getShardRequestIndex());
+                                        onOperation(shardResponse.getShardIndex(), shardResponse);
+                                    } else {
+                                        Exception failure = response.getException();
+                                        assert failure != null;
+                                        onOperationFailed(shardLevelRequests.get(i).getShardRequestIndex(), failure);
+                                    }
+                                }
+                            }
+
+                            @Override
+                            public void onFailure(Exception e) {
+                                for (CanMatchNodeRequest.Shard shard : shardLevelRequests) {
+                                    onOperationFailed(shard.getShardRequestIndex(), e);
+                                }
+                            }
+                        }
+                    );
+                } catch (Exception e) {
+                    for (CanMatchNodeRequest.Shard shard : shardLevelRequests) {
+                        onOperationFailed(shard.getShardRequestIndex(), e);
+                    }
+                }
+            }
+        }
+
+        private void onOperation(int idx, CanMatchShardResponse response) {
+            failedResponses.set(idx, null);
+            results.consumeResult(response, () -> {
+                if (countDown.countDown()) {
+                    finishRound();
+                }
+            });
+        }
+
+        private void onOperationFailed(int idx, Exception e) {
+            failedResponses.set(idx, e);
+            results.consumeShardFailure(idx);
+            if (countDown.countDown()) {
+                finishRound();
+            }
+        }
+
+        private void finishRound() {
+            List<SearchShardIterator> remainingShards = new ArrayList<>();
+            for (SearchShardIterator ssi : shards) {
+                int shardIndex = shardItIndexMap.get(ssi);
+                Exception failedResponse = failedResponses.get(shardIndex);
+                if (failedResponse != null) {
+                    remainingShards.add(ssi);
+                }
+            }
+            if (remainingShards.isEmpty()) {
+                finishPhase();
+            } else {
+                // trigger another round, forcing execution
+                executor.execute(new Round(new GroupShardsIterator<>(remainingShards)) {
+                    @Override
+                    public boolean isForceExecution() {
+                        return true;
+                    }
+                });
             }
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            if (logger.isDebugEnabled()) {
+                logger.debug(new ParameterizedMessage("Failed to execute [{}] while running [{}] phase", request, getName()), e);
+            }
+            onPhaseFailure("round", e);
+        }
+    }
 
-            CanMatchResponse result = new CanMatchResponse(canMatch, null);
-            result.setSearchShardTarget(shard == null ? new SearchShardTarget(null, shardIt.shardId(), shardIt.getClusterAlias()) : shard);
-            result.setShardIndex(shardIndex);
-            fork(() -> onShardResult(result, shardIt));
+    private static class SendingTarget {
+        @Nullable
+        private final String clusterAlias;
+        @Nullable
+        private final String nodeId;
+
+        SendingTarget(@Nullable String clusterAlias, @Nullable String nodeId) {
+            this.clusterAlias = clusterAlias;
+            this.nodeId = nodeId;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            SendingTarget that = (SendingTarget) o;
+            return Objects.equals(clusterAlias, that.clusterAlias) &&
+                Objects.equals(nodeId, that.nodeId);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(clusterAlias, nodeId);
+        }
+    }
+
+    private CanMatchNodeRequest createCanMatchRequest(Map.Entry<SendingTarget, List<SearchShardIterator>> entry) {
+        final SearchShardIterator first = entry.getValue().get(0);
+        final List<CanMatchNodeRequest.Shard> shardLevelRequests =
+            entry.getValue().stream().map(this::buildShardLevelRequest).collect(Collectors.toCollection(ArrayList::new));
+        assert entry.getValue().stream().allMatch(Objects::nonNull);
+        assert entry.getValue().stream().allMatch(ssi -> Objects.equals(ssi.getOriginalIndices().indicesOptions(),
+            first.getOriginalIndices().indicesOptions()));
+        assert entry.getValue().stream().allMatch(ssi -> Objects.equals(ssi.getClusterAlias(), first.getClusterAlias()));
+        return new CanMatchNodeRequest(request, first.getOriginalIndices().indicesOptions(),
+            shardLevelRequests, getNumShards(), timeProvider.getAbsoluteStartMillis(), first.getClusterAlias());
+    }
+
+    private void finishPhase() {
+        try {
+            phaseFactory.apply(getIterator(results, shardsIts)).start();
         } catch (Exception e) {
-            // If we fail to rewrite it on the coordinator, just try to execute
-            // the query in the shard.
-            super.performPhaseOnShard(shardIndex, shardIt, shard);
+            if (logger.isDebugEnabled()) {
+                logger.debug(new ParameterizedMessage("Failed to execute [{}] while running [{}] phase", request, getName()), e);
+            }
+            onPhaseFailure("finish", e);
         }
     }
 
-    private static List<SearchShardIterator> sortShards(GroupShardsIterator<SearchShardIterator> shardsIts,
-                                                        MinAndMax<?>[] minAndMaxes,
-                                                        SortOrder order) {
-        return IntStream.range(0, shardsIts.size())
-            .boxed()
-            .sorted(shardComparator(shardsIts, minAndMaxes,  order))
-            .map(shardsIts::get)
-            .collect(Collectors.toList());
+    private static final float DEFAULT_INDEX_BOOST = 1.0f;
+
+    public CanMatchNodeRequest.Shard buildShardLevelRequest(SearchShardIterator shardIt) {
+        AliasFilter filter = aliasFilter.get(shardIt.shardId().getIndex().getUUID());
+        assert filter != null;
+        float indexBoost = concreteIndexBoosts.getOrDefault(shardIt.shardId().getIndex().getUUID(), DEFAULT_INDEX_BOOST);
+        int shardRequestIndex = shardItIndexMap.get(shardIt);
+        return new CanMatchNodeRequest.Shard(shardIt.getOriginalIndices().indices(), shardIt.shardId(),
+            shardRequestIndex, filter, indexBoost, shardIt.getSearchContextId(), shardIt.getSearchContextKeepAlive(),
+            ShardSearchRequest.computeWaitForCheckpoint(request.getWaitForCheckpoints(), shardIt.shardId(), shardRequestIndex));
     }
 
-    private static boolean shouldSortShards(MinAndMax<?>[] minAndMaxes) {
-        Class<?> clazz = null;
-        for (MinAndMax<?> minAndMax : minAndMaxes) {
-            if (clazz == null) {
-                clazz = minAndMax == null ? null : minAndMax.getMin().getClass();
-            } else if (minAndMax != null && clazz != minAndMax.getMin().getClass()) {
-                // we don't support sort values that mix different types (e.g.: long/double, numeric/keyword).
-                // TODO: we could fail the request because there is a high probability
-                //  that the merging of topdocs will fail later for the same reason ?
-                return false;
+    private boolean checkMinimumVersion(GroupShardsIterator<SearchShardIterator> shardsIts) {
+        for (SearchShardIterator it : shardsIts) {
+            if (it.getTargetNodeIds().isEmpty() == false) {
+                boolean isCompatible = it.getTargetNodeIds().stream().anyMatch(nodeId -> {
+                    Transport.Connection conn = getConnection(new SendingTarget(it.getClusterAlias(), nodeId));
+                    return conn == null || conn.getVersion().onOrAfter(request.minCompatibleShardNode());
+                });
+                if (isCompatible == false) {
+                    return false;
+                }
             }
         }
-        return clazz != null;
+        return true;
     }
 
-    private static Comparator<Integer> shardComparator(GroupShardsIterator<SearchShardIterator> shardsIts,
-                                                       MinAndMax<?>[] minAndMaxes,
-                                                       SortOrder order) {
-        final Comparator<Integer> comparator = Comparator.comparing(
-            index -> minAndMaxes[index],
-            forciblyCast(MinAndMax.getComparator(order))
-        );
+    @Override
+    public void start() {
+        if (getNumShards() == 0) {
+            //no search shards to search on, bail with empty response
+            //(it happens with search across _all with no indices around and consistent with broadcast operations)
+            int trackTotalHitsUpTo = request.source() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO :
+                request.source().trackTotalHitsUpTo() == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO :
+                    request.source().trackTotalHitsUpTo();
+            // total hits is null in the response if the tracking of total hits is disabled
+            boolean withTotalHits = trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_DISABLED;
+            listener.onResponse(new SearchResponse(InternalSearchResponse.empty(withTotalHits), null, 0, 0,
+                0, timeProvider.buildTookInMillis(), ShardSearchFailure.EMPTY_ARRAY, clusters, null));
+            return;
+        }
+
+        // Note that the search is failed when this task is rejected by the executor
+        executor.execute(new AbstractRunnable() {
+            @Override
+            public void onFailure(Exception e) {
+                if (logger.isDebugEnabled()) {
+                    logger.debug(new ParameterizedMessage("Failed to execute [{}] while running [{}] phase", request, getName()), e);
+                }
+                onPhaseFailure("start", e);
+            }
+
+            @Override
+            protected void doRun() throws IOException {
+                CanMatchPreFilterSearchPhase.this.run();
+            }
+        });
+    }
 
-        return comparator.thenComparing(index -> shardsIts.get(index));
+
+    public void onPhaseFailure(String msg, Exception cause) {
+        listener.onFailure(new SearchPhaseExecutionException(getName(), msg, cause, ShardSearchFailure.EMPTY_ARRAY));
+    }
+
+    public Transport.Connection getConnection(SendingTarget sendingTarget) {
+        Transport.Connection conn = nodeIdToConnection.apply(sendingTarget.clusterAlias, sendingTarget.nodeId);
+        Version minVersion = request.minCompatibleShardNode();
+        if (minVersion != null && conn != null && conn.getVersion().before(minVersion)) {
+            throw new VersionMismatchException("One of the shards is incompatible with the required minimum version [{}]", minVersion);
+        }
+        return conn;
     }
 
-    private static final class CanMatchSearchPhaseResults extends SearchPhaseResults<CanMatchResponse> {
+    private int getNumShards() {
+        return shardsIts.size();
+    }
+
+    private static final class CanMatchSearchPhaseResults extends SearchPhaseResults<CanMatchShardResponse> {
         private final FixedBitSet possibleMatches;
         private final MinAndMax<?>[] minAndMaxes;
         private int numPossibleMatches;
@@ -206,7 +477,7 @@ private static final class CanMatchSearchPhaseResults extends SearchPhaseResults
         }
 
         @Override
-        void consumeResult(CanMatchResponse result, Runnable next) {
+        void consumeResult(CanMatchShardResponse result, Runnable next) {
             try {
                 consumeResult(result.getShardIndex(), result.canMatch(), result.estimatedMinAndMax());
             } finally {
@@ -242,8 +513,81 @@ synchronized FixedBitSet getPossibleMatches() {
         }
 
         @Override
-        Stream<CanMatchResponse> getSuccessfulResults() {
+        Stream<CanMatchShardResponse> getSuccessfulResults() {
             return Stream.empty();
         }
     }
+
+    private GroupShardsIterator<SearchShardIterator> getIterator(CanMatchSearchPhaseResults results,
+                                                                 GroupShardsIterator<SearchShardIterator> shardsIts) {
+        int cardinality = results.getNumPossibleMatches();
+        FixedBitSet possibleMatches = results.getPossibleMatches();
+        if (cardinality == 0) {
+            // this is a special case where we have no hit but we need to get at least one search response in order
+            // to produce a valid search result with all the aggs etc.
+            // Since it's possible that some of the shards that we're skipping are
+            // unavailable, we would try to query the node that at least has some
+            // shards available in order to produce a valid search result.
+            int shardIndexToQuery = 0;
+            for (int i = 0; i < shardsIts.size(); i++) {
+                if (shardsIts.get(i).size() > 0) {
+                    shardIndexToQuery = i;
+                    break;
+                }
+            }
+            possibleMatches.set(shardIndexToQuery);
+        }
+        SearchSourceBuilder source = request.source();
+        int i = 0;
+        for (SearchShardIterator iter : shardsIts) {
+            if (possibleMatches.get(i++)) {
+                iter.reset();
+            } else {
+                iter.resetAndSkip();
+            }
+        }
+        if (shouldSortShards(results.minAndMaxes) == false) {
+            return shardsIts;
+        }
+        FieldSortBuilder fieldSort = FieldSortBuilder.getPrimaryFieldSortOrNull(source);
+        return new GroupShardsIterator<>(sortShards(shardsIts, results.minAndMaxes, fieldSort.order()));
+    }
+
+    private static List<SearchShardIterator> sortShards(GroupShardsIterator<SearchShardIterator> shardsIts,
+                                                        MinAndMax<?>[] minAndMaxes,
+                                                        SortOrder order) {
+        return IntStream.range(0, shardsIts.size())
+            .boxed()
+            .sorted(shardComparator(shardsIts, minAndMaxes,  order))
+            .map(shardsIts::get)
+            .collect(Collectors.toList());
+    }
+
+    private static boolean shouldSortShards(MinAndMax<?>[] minAndMaxes) {
+        Class<?> clazz = null;
+        for (MinAndMax<?> minAndMax : minAndMaxes) {
+            if (clazz == null) {
+                clazz = minAndMax == null ? null : minAndMax.getMin().getClass();
+            } else if (minAndMax != null && clazz != minAndMax.getMin().getClass()) {
+                // we don't support sort values that mix different types (e.g.: long/double, numeric/keyword).
+                // TODO: we could fail the request because there is a high probability
+                //  that the merging of topdocs will fail later for the same reason ?
+                return false;
+            }
+        }
+        return clazz != null;
+    }
+
+    private static Comparator<Integer> shardComparator(GroupShardsIterator<SearchShardIterator> shardsIts,
+                                                       MinAndMax<?>[] minAndMaxes,
+                                                       SortOrder order) {
+        final Comparator<Integer> comparator = Comparator.comparing(
+            index -> minAndMaxes[index],
+            forciblyCast(MinAndMax.getComparator(order))
+        );
+
+        return comparator.thenComparing(shardsIts::get);
+    }
+
 }
+
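For orientation on the coordinator rewrite this phase relies on (`SearchService.queryStillMatchesAfterRewrite` above): it can only rule a shard out from metadata the coordinator already holds, typically a recorded timestamp range. A hedged illustration of the kind of query that benefits (field name and bounds are invented; the rewrite rules themselves live in the query builders, not in this phase):

    // If cluster state records that every @timestamp value in an index falls outside
    // this window, the range filter rewrites to "match none" on the coordinator and
    // the shard is skipped without ever sending it a can-match request.
    QueryBuilder filter = QueryBuilders.rangeQuery("@timestamp")
        .gte("2021-10-01T00:00:00Z")
        .lt("2021-11-01T00:00:00Z");
    SearchSourceBuilder source = new SearchSourceBuilder().query(filter);
    SearchRequest request = new SearchRequest("logs-*").source(source);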
diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java
index 83f0001972e81..88da2fdfa3a9e 100644
--- a/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java
+++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java
@@ -10,6 +10,7 @@
 import org.elasticsearch.core.CheckedRunnable;
 
 import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.util.Objects;
 
 /**
@@ -28,4 +29,12 @@ protected SearchPhase(String name) {
     public String getName() {
         return name;
     }
+
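+    /**
+     * Runs the phase, rethrowing the checked {@link IOException} from {@link #run()} as an {@link UncheckedIOException}.
+     */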
+    public void start() {
+        try {
+            run();
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
+    }
 }
diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java
index 41860c52174d4..e54b983ed59da 100644
--- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java
+++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java
@@ -8,6 +8,7 @@
 
 package org.elasticsearch.action.search;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionListenerResponseHandler;
 import org.elasticsearch.action.IndicesRequest;
@@ -19,12 +20,14 @@
 import org.elasticsearch.client.OriginSettingClient;
 import org.elasticsearch.client.node.NodeClient;
 import org.elasticsearch.cluster.node.DiscoveryNode;
-import org.elasticsearch.core.Nullable;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
+import org.elasticsearch.common.util.concurrent.CountDown;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.search.CanMatchShardResponse;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.dfs.DfsSearchResult;
@@ -51,9 +54,12 @@
 import org.elasticsearch.transport.TransportService;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.concurrent.atomic.AtomicReferenceArray;
 import java.util.function.BiFunction;
 
 /**
@@ -73,6 +79,7 @@ public class SearchTransportService {
     public static final String FETCH_ID_SCROLL_ACTION_NAME = "indices:data/read/search[phase/fetch/id/scroll]";
     public static final String FETCH_ID_ACTION_NAME = "indices:data/read/search[phase/fetch/id]";
     public static final String QUERY_CAN_MATCH_NAME = "indices:data/read/search[can_match]";
+    public static final String QUERY_CAN_MATCH_NODE_NAME = "indices:data/read/search[can_match][n]";
 
     private final TransportService transportService;
     private final NodeClient client;
@@ -117,9 +124,57 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI
     }
 
     public void sendCanMatch(Transport.Connection connection, final ShardSearchRequest request, SearchTask task, final
-                            ActionListener<SearchService.CanMatchResponse> listener) {
+                            ActionListener<CanMatchShardResponse> listener) {
         transportService.sendChildRequest(connection, QUERY_CAN_MATCH_NAME, request, task,
-            TransportRequestOptions.EMPTY, new ActionListenerResponseHandler<>(listener, SearchService.CanMatchResponse::new));
+            TransportRequestOptions.EMPTY, new ActionListenerResponseHandler<>(listener, CanMatchShardResponse::new));
+    }
+
+    public void sendCanMatch(Transport.Connection connection, final CanMatchNodeRequest request, SearchTask task, final
+                             ActionListener<CanMatchNodeResponse> listener) {
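+        // nodes before 7.16 don't understand the batched node-level can-match request, so fall back to per-shard requests below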
+        if (connection.getVersion().onOrAfter(Version.V_7_16_0) &&
+            connection.getNode().getVersion().onOrAfter(Version.V_7_16_0)) {
+            transportService.sendChildRequest(connection, QUERY_CAN_MATCH_NODE_NAME, request, task,
+                TransportRequestOptions.EMPTY, new ActionListenerResponseHandler<>(listener, CanMatchNodeResponse::new));
+        } else {
+            // BWC layer: translate into shard-level requests
+            final List<ShardSearchRequest> shardSearchRequests = request.createShardSearchRequests();
+            final AtomicReferenceArray<CanMatchNodeResponse.ResponseOrFailure> results =
+                new AtomicReferenceArray<>(shardSearchRequests.size());
+            final CountDown counter = new CountDown(shardSearchRequests.size());
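+            // completes the listener exactly once, after every per-shard request has either responded or failed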
+            final Runnable maybeFinish = () -> {
+                if (counter.countDown()) {
+                    final CanMatchNodeResponse.ResponseOrFailure[] responses =
+                        new CanMatchNodeResponse.ResponseOrFailure[shardSearchRequests.size()];
+                    for (int i = 0; i < responses.length; i++) {
+                        responses[i] = results.get(i);
+                    }
+                    final CanMatchNodeResponse response = new CanMatchNodeResponse(Arrays.asList(responses));
+                    listener.onResponse(response);
+                }
+            };
+            for (int i = 0; i < shardSearchRequests.size(); i++) {
+                final ShardSearchRequest shardSearchRequest = shardSearchRequests.get(i);
+                final int finalI = i;
+                try {
+                    sendCanMatch(connection, shardSearchRequest, task, new ActionListener<CanMatchShardResponse>() {
+                        @Override
+                        public void onResponse(CanMatchShardResponse response) {
+                            results.set(finalI, new CanMatchNodeResponse.ResponseOrFailure(response));
+                            maybeFinish.run();
+                        }
+
+                        @Override
+                        public void onFailure(Exception e) {
+                            results.set(finalI, new CanMatchNodeResponse.ResponseOrFailure(e));
+                            maybeFinish.run();
+                        }
+                    });
+                } catch (Exception e) {
+                    results.set(finalI, new CanMatchNodeResponse.ResponseOrFailure(e));
+                    maybeFinish.run();
+                }
+            }
+        }
     }
 
     public void sendClearAllScrollContexts(Transport.Connection connection, final ActionListener<TransportResponse> listener) {
@@ -363,7 +418,13 @@ public static void registerRequestHandler(TransportService transportService, Sea
             (request, channel, task) -> {
                 searchService.canMatch(request, new ChannelActionListener<>(channel, QUERY_CAN_MATCH_NAME, request));
             });
-        TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NAME, true, SearchService.CanMatchResponse::new);
+        TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NAME, true, CanMatchShardResponse::new);
+
+        transportService.registerRequestHandler(QUERY_CAN_MATCH_NODE_NAME, ThreadPool.Names.SEARCH_COORDINATION, CanMatchNodeRequest::new,
+            (request, channel, task) -> {
+                searchService.canMatch(request, new ChannelActionListener<>(channel, QUERY_CAN_MATCH_NODE_NAME, request));
+            });
+        TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NODE_NAME, true, CanMatchNodeResponse::new);
     }
 
 
diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java
index 7293933b6cdf1..16284a883016b 100644
--- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java
+++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java
@@ -106,6 +106,9 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
     public static final Setting<Long> SHARD_COUNT_LIMIT_SETTING = Setting.longSetting(
             "action.search.shard_count.limit", Long.MAX_VALUE, 1L, Property.Dynamic, Property.NodeScope);
 
+    // node-scoped default for the request-level pre_filter_shard_size parameter, applied when a search request leaves it unset
+    public static final Setting<Integer> DEFAULT_PRE_FILTER_SHARD_SIZE = Setting.intSetting(
+        "action.search.pre_filter_shard_size.default", SearchRequest.DEFAULT_PRE_FILTER_SHARD_SIZE, 1, Property.NodeScope);
+
     private final ThreadPool threadPool;
     private final ClusterService clusterService;
     private final SearchTransportService searchTransportService;
@@ -116,6 +119,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
     private final NamedWriteableRegistry namedWriteableRegistry;
     private final CircuitBreaker circuitBreaker;
     private final ExecutorSelector executorSelector;
+    private final int defaultPreFilterShardSize;
 
     @Inject
     public TransportSearchAction(ThreadPool threadPool,
@@ -141,6 +145,7 @@ public TransportSearchAction(ThreadPool threadPool,
         this.indexNameExpressionResolver = indexNameExpressionResolver;
         this.namedWriteableRegistry = namedWriteableRegistry;
         this.executorSelector = executorSelector;
+        this.defaultPreFilterShardSize = DEFAULT_PRE_FILTER_SHARD_SIZE.get(clusterService.getSettings());
     }
 
     private Map<String, OriginalIndices> buildPerIndexOriginalIndices(ClusterState clusterState,
@@ -750,7 +755,7 @@ private void executeSearch(SearchTask task, SearchTimeProvider timeProvider, Sea
             nodes::get, remoteConnections, searchTransportService::getConnection);
         final Executor asyncSearchExecutor = asyncSearchExecutor(concreteLocalIndices);
         final boolean preFilterSearchShards = shouldPreFilterSearchShards(clusterState, searchRequest, concreteLocalIndices,
-            localShardIterators.size() + remoteShardIterators.size());
+            localShardIterators.size() + remoteShardIterators.size(), defaultPreFilterShardSize);
         searchAsyncActionProvider.asyncSearchAction(
             task, searchRequest, asyncSearchExecutor, shardIterators, timeProvider, connectionLookup, clusterState,
             Collections.unmodifiableMap(aliasFilter), concreteIndexBoosts, listener,
@@ -797,14 +802,15 @@ static BiFunction<String, String, Transport.Connection> buildConnectionLookup(St
     static boolean shouldPreFilterSearchShards(ClusterState clusterState,
                                                SearchRequest searchRequest,
                                                String[] indices,
-                                               int numShards) {
+                                               int numShards,
+                                               int defaultPreFilterShardSize) {
         SearchSourceBuilder source = searchRequest.source();
         Integer preFilterShardSize = searchRequest.getPreFilterShardSize();
         if (preFilterShardSize == null
                 && (hasReadOnlyIndices(indices, clusterState) || hasPrimaryFieldSort(source))) {
             preFilterShardSize = 1;
         } else if (preFilterShardSize == null) {
-            preFilterShardSize = SearchRequest.DEFAULT_PRE_FILTER_SHARD_SIZE;
+            preFilterShardSize = defaultPreFilterShardSize;
         }
         return searchRequest.searchType() == QUERY_THEN_FETCH // we can't do this for DFS it needs to fan out to all shards all the time
                     && (SearchService.canRewriteToMatchNone(source) || hasPrimaryFieldSort(source))
@@ -829,14 +835,14 @@ static GroupShardsIterator<SearchShardIterator> mergeShardsIterators(List<Search
     }
 
     interface SearchAsyncActionProvider {
-        AbstractSearchAsyncAction<? extends SearchPhaseResult> asyncSearchAction(
+        SearchPhase asyncSearchAction(
             SearchTask task, SearchRequest searchRequest, Executor executor, GroupShardsIterator<SearchShardIterator> shardIterators,
             SearchTimeProvider timeProvider, BiFunction<String, String, Transport.Connection> connectionLookup,
             ClusterState clusterState, Map<String, AliasFilter> aliasFilter, Map<String, Float> concreteIndexBoosts,
             ActionListener<SearchResponse> listener, boolean preFilter, ThreadPool threadPool, SearchResponse.Clusters clusters);
     }
 
-    private AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction(
+    private SearchPhase searchAsyncAction(
         SearchTask task,
         SearchRequest searchRequest,
         Executor executor,
@@ -852,9 +858,9 @@ private AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction
         SearchResponse.Clusters clusters) {
         if (preFilter) {
             return new CanMatchPreFilterSearchPhase(logger, searchTransportService, connectionLookup,
-                aliasFilter, concreteIndexBoosts, executor, searchRequest, listener, shardIterators,
-                timeProvider, clusterState, task, (iter) -> {
-                AbstractSearchAsyncAction<? extends SearchPhaseResult> action = searchAsyncAction(
+                aliasFilter, concreteIndexBoosts, threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), searchRequest, listener,
+                shardIterators, timeProvider, task, (iter) -> {
+                SearchPhase action = searchAsyncAction(
                     task,
                     searchRequest,
                     executor,
@@ -868,12 +874,8 @@ private AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction
                     false,
                     threadPool,
                     clusters);
-                return new SearchPhase(action.getName()) {
-                    @Override
-                    public void run() {
-                        action.start();
-                    }
-                };
+                assert action instanceof AbstractSearchAsyncAction;
+                return action;
             }, clusters, searchService.getCoordinatorRewriteContextProvider(timeProvider::getAbsoluteStartMillis));
         } else {
             final QueryPhaseResultConsumer queryResultConsumer = searchPhaseController.newSearchPhaseResults(executor,
diff --git a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java
index 8730b732f6026..8b271ef85aceb 100644
--- a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java
+++ b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java
@@ -327,6 +327,7 @@ public void apply(Settings value, Settings current, Settings previous) {
             SearchService.DEFAULT_ALLOW_PARTIAL_SEARCH_RESULTS,
             ElectMasterService.DISCOVERY_ZEN_MINIMUM_MASTER_NODES_SETTING,
             TransportSearchAction.SHARD_COUNT_LIMIT_SETTING,
+            TransportSearchAction.DEFAULT_PRE_FILTER_SHARD_SIZE,
             RemoteClusterService.REMOTE_CLUSTER_SKIP_UNAVAILABLE,
             RemoteClusterService.SEARCH_REMOTE_CLUSTER_SKIP_UNAVAILABLE,
             SniffConnectionStrategy.REMOTE_CONNECTIONS_PER_CLUSTER,
diff --git a/server/src/main/java/org/elasticsearch/search/CanMatchShardResponse.java b/server/src/main/java/org/elasticsearch/search/CanMatchShardResponse.java
new file mode 100644
index 0000000000000..0a9c98e0c9013
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/search/CanMatchShardResponse.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.search.sort.MinAndMax;
+
+import java.io.IOException;
+
+/**
+ * Shard-level response for can-match requests
+ */
+public final class CanMatchShardResponse extends SearchPhaseResult {
+    private final boolean canMatch;
+    private final MinAndMax<?> estimatedMinAndMax;
+
+    public CanMatchShardResponse(StreamInput in) throws IOException {
+        super(in);
+        this.canMatch = in.readBoolean();
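+        // the min/max sort-value estimate was only added in 7.6, so it may be missing on the wire from older nodes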
+        if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
+            estimatedMinAndMax = in.readOptionalWriteable(MinAndMax::new);
+        } else {
+            estimatedMinAndMax = null;
+        }
+    }
+
+    public CanMatchShardResponse(boolean canMatch, MinAndMax<?> estimatedMinAndMax) {
+        this.canMatch = canMatch;
+        this.estimatedMinAndMax = estimatedMinAndMax;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeBoolean(canMatch);
+        if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
+            out.writeOptionalWriteable(estimatedMinAndMax);
+        }
+    }
+
+    public boolean canMatch() {
+        return canMatch;
+    }
+
+    public MinAndMax<?> estimatedMinAndMax() {
+        return estimatedMinAndMax;
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java
index 87889bb05f03c..af1c6bc3caed1 100644
--- a/server/src/main/java/org/elasticsearch/search/SearchService.java
+++ b/server/src/main/java/org/elasticsearch/search/SearchService.java
@@ -19,6 +19,8 @@
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRunnable;
+import org.elasticsearch.action.search.CanMatchNodeRequest;
+import org.elasticsearch.action.search.CanMatchNodeResponse;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchShardTask;
 import org.elasticsearch.action.search.SearchType;
@@ -30,8 +32,6 @@
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.component.AbstractLifecycleComponent;
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Setting.Property;
@@ -121,6 +121,7 @@
 import org.elasticsearch.transport.TransportRequest;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -413,7 +414,7 @@ public void executeQueryPhase(ShardSearchRequest request, SearchShardTask task,
             // check if we can shortcut the query phase entirely.
             if (orig.canReturnNullResponseIfMatchNoDocs()) {
                 assert orig.scroll() == null;
-                final CanMatchResponse canMatchResp;
+                final CanMatchShardResponse canMatchResp;
                 try {
                     ShardSearchRequest clone = new ShardSearchRequest(orig);
                     canMatchResp = canMatch(clone, false);
@@ -421,7 +422,7 @@ public void executeQueryPhase(ShardSearchRequest request, SearchShardTask task,
                     l.onFailure(exc);
                     return;
                 }
-                if (canMatchResp.canMatch == false) {
+                if (canMatchResp.canMatch() == false) {
                     l.onResponse(QuerySearchResult.nullInstance());
                     return;
                 }
@@ -1330,7 +1331,7 @@ public AliasFilter buildAliasFilter(ClusterState state, String index, Set<String
         return indicesService.buildAliasFilter(state, index, resolvedExpressions);
     }
 
-    public void canMatch(ShardSearchRequest request, ActionListener<CanMatchResponse> listener) {
+    public void canMatch(ShardSearchRequest request, ActionListener<CanMatchShardResponse> listener) {
         try {
             listener.onResponse(canMatch(request));
         } catch (IOException e) {
@@ -1338,16 +1339,31 @@ public void canMatch(ShardSearchRequest request, ActionListener<CanMatchResponse
         }
     }
 
+    public void canMatch(CanMatchNodeRequest request, ActionListener<CanMatchNodeResponse> listener) {
+        final List<ShardSearchRequest> shardSearchRequests = request.createShardSearchRequests();
+        final List<CanMatchNodeResponse.ResponseOrFailure> responses = new ArrayList<>(shardSearchRequests.size());
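+        // evaluate each shard-level request independently; a failure on one shard is recorded and does not fail the node-level response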
+        for (ShardSearchRequest shardSearchRequest : shardSearchRequests) {
+            CanMatchShardResponse canMatchShardResponse;
+            try {
+                canMatchShardResponse = canMatch(shardSearchRequest);
+                responses.add(new CanMatchNodeResponse.ResponseOrFailure(canMatchShardResponse));
+            } catch (Exception e) {
+                responses.add(new CanMatchNodeResponse.ResponseOrFailure(e));
+            }
+        }
+        listener.onResponse(new CanMatchNodeResponse(responses));
+    }
+
     /**
      * This method uses a lightweight searcher without wrapping (i.e., not open a full reader on frozen indices) to rewrite the query
      * to check if the query can match any documents. This method can have false positives while if it returns {@code false} the query
      * won't match any documents on the current shard.
      */
-    public CanMatchResponse canMatch(ShardSearchRequest request) throws IOException {
+    public CanMatchShardResponse canMatch(ShardSearchRequest request) throws IOException {
         return canMatch(request, true);
     }
 
-    private CanMatchResponse canMatch(ShardSearchRequest request, boolean checkRefreshPending) throws IOException {
+    private CanMatchShardResponse canMatch(ShardSearchRequest request, boolean checkRefreshPending) throws IOException {
         assert request.searchType() == SearchType.QUERY_THEN_FETCH : "unexpected search type: " + request.searchType();
         Releasable releasable = null;
         try {
@@ -1400,7 +1416,7 @@ private CanMatchResponse canMatch(ShardSearchRequest request, boolean checkRefre
                 } else {
                     minMax = null;
                 }
-                return new CanMatchResponse(canMatch || hasRefreshPending, minMax);
+                return new CanMatchShardResponse(canMatch || hasRefreshPending, minMax);
             }
         } finally {
             Releasables.close(releasable);
@@ -1495,42 +1511,6 @@ private static PipelineTree requestToPipelineTree(SearchRequest request) {
         return request.source().aggregations().buildPipelineTree();
     }
 
-    public static final class CanMatchResponse extends SearchPhaseResult {
-        private final boolean canMatch;
-        private final MinAndMax<?> estimatedMinAndMax;
-
-        public CanMatchResponse(StreamInput in) throws IOException {
-            super(in);
-            this.canMatch = in.readBoolean();
-            if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
-                estimatedMinAndMax = in.readOptionalWriteable(MinAndMax::new);
-            } else {
-                estimatedMinAndMax = null;
-            }
-        }
-
-        public CanMatchResponse(boolean canMatch, MinAndMax<?> estimatedMinAndMax) {
-            this.canMatch = canMatch;
-            this.estimatedMinAndMax = estimatedMinAndMax;
-        }
-
-        @Override
-        public void writeTo(StreamOutput out) throws IOException {
-            out.writeBoolean(canMatch);
-            if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
-                out.writeOptionalWriteable(estimatedMinAndMax);
-            }
-        }
-
-        public boolean canMatch() {
-            return canMatch;
-        }
-
-        public MinAndMax<?> estimatedMinAndMax() {
-            return estimatedMinAndMax;
-        }
-    }
-
     /**
      * This helper class ensures we only execute either the success or the failure path for {@link SearchOperationListener}.
      * This is crucial for some implementations like {@link org.elasticsearch.index.search.stats.ShardSearchStats}.
diff --git a/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java b/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java
index fa8578b660502..e2b5d90861872 100644
--- a/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java
+++ b/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java
@@ -18,8 +18,6 @@
 import org.elasticsearch.cluster.metadata.AliasMetadata;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.common.CheckedBiConsumer;
-import org.elasticsearch.core.CheckedFunction;
-import org.elasticsearch.core.Nullable;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
@@ -28,14 +26,16 @@
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.CheckedFunction;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.MatchNoneQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryRewriteContext;
-import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.index.query.Rewriteable;
+import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.index.seqno.SequenceNumbers;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.indices.AliasFilterParsingException;
@@ -114,34 +114,6 @@ public ShardSearchRequest(OriginalIndices originalIndices,
                               @Nullable String clusterAlias,
                               ShardSearchContextId readerId,
                               TimeValue keepAlive) {
-        this(originalIndices,
-            searchRequest,
-            shardId,
-            shardRequestIndex,
-            numberOfShards,
-            aliasFilter,
-            indexBoost,
-            nowInMillis,
-            clusterAlias,
-            readerId,
-            keepAlive,
-            SequenceNumbers.UNASSIGNED_SEQ_NO,
-            SearchService.NO_TIMEOUT);
-    }
-
-    public ShardSearchRequest(OriginalIndices originalIndices,
-                              SearchRequest searchRequest,
-                              ShardId shardId,
-                              int shardRequestIndex,
-                              int numberOfShards,
-                              AliasFilter aliasFilter,
-                              float indexBoost,
-                              long nowInMillis,
-                              @Nullable String clusterAlias,
-                              ShardSearchContextId readerId,
-                              TimeValue keepAlive,
-                              long waitForCheckpoint,
-                              TimeValue waitForCheckpointsTimeout) {
         this(originalIndices,
             shardId,
             shardRequestIndex,
@@ -158,13 +130,28 @@ public ShardSearchRequest(OriginalIndices originalIndices,
             clusterAlias,
             readerId,
             keepAlive,
-            waitForCheckpoint,
-            waitForCheckpointsTimeout);
+            computeWaitForCheckpoint(searchRequest.getWaitForCheckpoints(), shardId, shardRequestIndex),
+            searchRequest.getWaitForCheckpointsTimeout());
         // If allowPartialSearchResults is unset (ie null), the cluster-level default should have been substituted
         // at this stage. Any NPEs in the above are therefore an error in request preparation logic.
         assert searchRequest.allowPartialSearchResults() != null;
     }
 
+    private static final long[] EMPTY_LONG_ARRAY = new long[0];
+
+    public static long computeWaitForCheckpoint(Map<String, long[]> indexToWaitForCheckpoints, ShardId shardId, int shardRequestIndex) {
+        final long[] waitForCheckpoints = indexToWaitForCheckpoints.getOrDefault(shardId.getIndex().getName(), EMPTY_LONG_ARRAY);
+
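+        // if no checkpoints were registered for this index, the shard-level request does not wait (UNASSIGNED_SEQ_NO)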
+        long waitForCheckpoint;
+        if (waitForCheckpoints.length == 0) {
+            waitForCheckpoint = SequenceNumbers.UNASSIGNED_SEQ_NO;
+        } else {
+            assert waitForCheckpoints.length > shardRequestIndex;
+            waitForCheckpoint = waitForCheckpoints[shardRequestIndex];
+        }
+        return waitForCheckpoint;
+    }
+
     public ShardSearchRequest(ShardId shardId,
                               String[] types,
                               long nowInMillis,
@@ -173,7 +160,7 @@ public ShardSearchRequest(ShardId shardId,
             aliasFilter, 1.0f, true, null, nowInMillis, null, null, null, SequenceNumbers.UNASSIGNED_SEQ_NO, SearchService.NO_TIMEOUT);
     }
 
-    private ShardSearchRequest(OriginalIndices originalIndices,
+    public ShardSearchRequest(OriginalIndices originalIndices,
                                ShardId shardId,
                                int shardRequestIndex,
                                int numberOfShards,
diff --git a/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java b/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java
index 9bdbf3249a67d..9a13566db7432 100644
--- a/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java
+++ b/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java
@@ -61,6 +61,7 @@ public static class Names {
         public static final String ANALYZE = "analyze";
         public static final String WRITE = "write";
         public static final String SEARCH = "search";
+        public static final String SEARCH_COORDINATION = "search_coordination";
         public static final String AUTO_COMPLETE = "auto_complete";
         public static final String SEARCH_THROTTLED = "search_throttled";
         public static final String MANAGEMENT = "management";
@@ -124,6 +125,7 @@ public static ThreadPoolType fromType(String type) {
         map.put(Names.ANALYZE, ThreadPoolType.FIXED);
         map.put(Names.WRITE, ThreadPoolType.FIXED);
         map.put(Names.SEARCH, ThreadPoolType.FIXED_AUTO_QUEUE_SIZE);
+        map.put(Names.SEARCH_COORDINATION, ThreadPoolType.FIXED);
         map.put(Names.MANAGEMENT, ThreadPoolType.SCALING);
         map.put(Names.FLUSH, ThreadPoolType.SCALING);
         map.put(Names.REFRESH, ThreadPoolType.SCALING);
@@ -197,6 +199,7 @@ public ThreadPool(final Settings settings, final ExecutorBuilder<?>... customBui
         builders.put(Names.ANALYZE, new FixedExecutorBuilder(settings, Names.ANALYZE, 1, 16));
         builders.put(Names.SEARCH, new AutoQueueAdjustingExecutorBuilder(settings,
                         Names.SEARCH, searchThreadPoolSize(allocatedProcessors), 1000, 1000, 1000, 2000));
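+        // fixed pool for lightweight search coordination work such as the node-level can-match phase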
+        builders.put(Names.SEARCH_COORDINATION, new FixedExecutorBuilder(settings, Names.SEARCH_COORDINATION, halfProcMaxAt5, 1000, false));
         builders.put(Names.SEARCH_THROTTLED, new AutoQueueAdjustingExecutorBuilder(settings,
             Names.SEARCH_THROTTLED, 1, 100, 100, 100, 200));
 
diff --git a/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java b/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java
index ec10eefad63c4..c9dca16824a52 100644
--- a/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java
+++ b/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java
@@ -161,12 +161,7 @@ public static void registerProxyActionWithDynamicResponseType(TransportService s
      */
     public static void registerProxyAction(TransportService service, String action, boolean cancellable,
                                            Writeable.Reader<? extends TransportResponse> reader) {
-        RequestHandlerRegistry<? extends TransportRequest> requestHandler = service.getRequestHandler(action);
-        service.registerRequestHandler(getProxyAction(action), ThreadPool.Names.SAME, true, false,
-            in -> cancellable ?
-                new CancellableProxyRequest<>(in, requestHandler::newRequest) :
-                new ProxyRequest<>(in, requestHandler::newRequest),
-            new ProxyRequestHandler<>(service, action, request -> reader));
+        registerProxyActionWithDynamicResponseType(service, action, cancellable, request -> reader);
     }
 
     private static final String PROXY_ACTION_PREFIX = "internal:transport/proxy/";
diff --git a/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java
index 1415e4176b691..901237c690020 100644
--- a/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java
+++ b/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java
@@ -5,12 +5,14 @@
  * in compliance with, at your election, the Elastic License 2.0 or the Server
  * Side Public License, v 1.
  */
+
 package org.elasticsearch.action.search;
 
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.OriginalIndices;
+import org.elasticsearch.action.search.CanMatchNodeResponse.ResponseOrFailure;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.DataStream;
@@ -22,8 +24,6 @@
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.common.util.concurrent.EsExecutors;
-import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.mapper.DateFieldMapper;
 import org.elasticsearch.index.query.AbstractQueryBuilder;
@@ -34,9 +34,7 @@
 import org.elasticsearch.index.shard.IndexLongFieldRange;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardLongFieldRange;
-import org.elasticsearch.search.SearchPhaseResult;
-import org.elasticsearch.search.SearchService;
-import org.elasticsearch.search.SearchShardTarget;
+import org.elasticsearch.search.CanMatchShardResponse;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.internal.AliasFilter;
 import org.elasticsearch.search.internal.ShardSearchRequest;
@@ -44,7 +42,10 @@
 import org.elasticsearch.search.sort.SortBuilders;
 import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.Transport;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -57,8 +58,7 @@
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiConsumer;
 import java.util.stream.Collectors;
@@ -67,12 +67,27 @@
 import static org.elasticsearch.action.search.SearchAsyncActionTests.getShardsIter;
 import static org.elasticsearch.core.Types.forciblyCast;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.lessThanOrEqualTo;
 import static org.mockito.Mockito.mock;
 
 public class CanMatchPreFilterSearchPhaseTests extends ESTestCase {
 
     private final CoordinatorRewriteContextProvider EMPTY_CONTEXT_PROVIDER = new StaticCoordinatorRewriteContextProviderBuilder().build();
 
+    private TestThreadPool threadPool;
+
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        threadPool = new TestThreadPool(getTestName());
+    }
+
+    @Override
+    public void tearDown() throws Exception {
+        terminate(threadPool);
+        super.tearDown();
+    }
+
     public void testFilterShards() throws InterruptedException {
 
         final TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider(0, System.nanoTime(),
@@ -80,18 +95,26 @@ public void testFilterShards() throws InterruptedException {
 
         Map<String, Transport.Connection> lookup = new ConcurrentHashMap<>();
         DiscoveryNode primaryNode = new DiscoveryNode("node_1", buildNewFakeTransportAddress(), Version.CURRENT);
-        DiscoveryNode replicaNode = new DiscoveryNode("node_2", buildNewFakeTransportAddress(), Version.CURRENT);
+        DiscoveryNode replicaNode = randomBoolean() ? null :
+            new DiscoveryNode("node_2", buildNewFakeTransportAddress(), Version.CURRENT);
         lookup.put("node1", new SearchAsyncActionTests.MockConnection(primaryNode));
         lookup.put("node2", new SearchAsyncActionTests.MockConnection(replicaNode));
         final boolean shard1 = randomBoolean();
         final boolean shard2 = randomBoolean();
 
+        final AtomicInteger numRequests = new AtomicInteger();
         SearchTransportService searchTransportService = new SearchTransportService(null, null, null) {
             @Override
-            public void sendCanMatch(Transport.Connection connection, ShardSearchRequest request, SearchTask task,
-                                     ActionListener<SearchService.CanMatchResponse> listener) {
-                new Thread(() -> listener.onResponse(new SearchService.CanMatchResponse(request.shardId().id() == 0 ? shard1 :
-                    shard2, null))).start();
+            public void sendCanMatch(Transport.Connection connection, CanMatchNodeRequest request, SearchTask task,
+                                     ActionListener<CanMatchNodeResponse> listener) {
+                numRequests.incrementAndGet();
+                final List<ResponseOrFailure> responses = new ArrayList<>();
+                for (CanMatchNodeRequest.Shard shard : request.getShardLevelRequests()) {
+                    responses.add(new ResponseOrFailure(new CanMatchShardResponse(shard.shardId().id() == 0 ? shard1 :
+                        shard2, null)));
+                }
+
+                new Thread(() -> listener.onResponse(new CanMatchNodeResponse(responses))).start();
             }
         };
 
@@ -107,18 +130,20 @@ public void sendCanMatch(Transport.Connection connection, ShardSearchRequest req
             searchTransportService,
             (clusterAlias, node) -> lookup.get(node),
             Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)),
-            Collections.emptyMap(), EsExecutors.DIRECT_EXECUTOR_SERVICE,
-            searchRequest, null, shardsIter, timeProvider, ClusterState.EMPTY_STATE, null,
+            Collections.emptyMap(), threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION),
+            searchRequest, null, shardsIter, timeProvider, null,
             (iter) -> new SearchPhase("test") {
-                    @Override
-                    public void run() throws IOException {
-                        result.set(iter);
-                        latch.countDown();
-                    }}, SearchResponse.Clusters.EMPTY, EMPTY_CONTEXT_PROVIDER);
+                @Override
+                public void run() throws IOException {
+                    result.set(iter);
+                    latch.countDown();
+                }}, SearchResponse.Clusters.EMPTY, EMPTY_CONTEXT_PROVIDER);
 
         canMatchPhase.start();
         latch.await();
 
+        assertThat(numRequests.get(), replicaNode == null ? equalTo(1) : lessThanOrEqualTo(2));
+
         if (shard1 && shard2) {
             for (SearchShardIterator i : result.get()) {
                 assertFalse(i.skip());
@@ -143,22 +168,32 @@ public void testFilterWithFailure() throws InterruptedException {
         lookup.put("node1", new SearchAsyncActionTests.MockConnection(primaryNode));
         lookup.put("node2", new SearchAsyncActionTests.MockConnection(replicaNode));
         final boolean shard1 = randomBoolean();
+        final boolean useReplicas = randomBoolean();
+        final boolean fullFailure = randomBoolean();
         SearchTransportService searchTransportService = new SearchTransportService(null, null, null) {
             @Override
-            public void sendCanMatch(Transport.Connection connection, ShardSearchRequest request, SearchTask task,
-                                     ActionListener<SearchService.CanMatchResponse> listener) {
-                boolean throwException = request.shardId().id() != 0;
-                if (throwException && randomBoolean()) {
+            public void sendCanMatch(Transport.Connection connection, CanMatchNodeRequest request, SearchTask task,
+                                     ActionListener<CanMatchNodeResponse> listener) {
+                if (fullFailure && randomBoolean()) {
                     throw new IllegalArgumentException("boom");
-                } else {
-                    new Thread(() -> {
-                        if (throwException == false) {
-                            listener.onResponse(new SearchService.CanMatchResponse(shard1, null));
-                        } else {
-                            listener.onFailure(new NullPointerException());
-                        }
-                    }).start();
                 }
+                final List<ResponseOrFailure> responses = new ArrayList<>();
+                for (CanMatchNodeRequest.Shard shard : request.getShardLevelRequests()) {
+                    boolean throwException = shard.shardId().id() != 0;
+                    if (throwException) {
+                        responses.add(new ResponseOrFailure(new NullPointerException()));
+                    } else {
+                        responses.add(new ResponseOrFailure(new CanMatchShardResponse(shard1, null)));
+                    }
+                }
+
+                new Thread(() -> {
+                    if (fullFailure) {
+                        listener.onFailure(new NullPointerException());
+                    } else {
+                        listener.onResponse(new CanMatchNodeResponse(responses));
+                    }
+                }).start();
             }
         };
 
@@ -166,7 +201,7 @@ public void sendCanMatch(Transport.Connection connection, ShardSearchRequest req
         CountDownLatch latch = new CountDownLatch(1);
         GroupShardsIterator<SearchShardIterator> shardsIter = getShardsIter("idx",
             new OriginalIndices(new String[]{"idx"}, SearchRequest.DEFAULT_INDICES_OPTIONS),
-            2, randomBoolean(), primaryNode, replicaNode);
+            2, useReplicas, primaryNode, replicaNode);
 
         final SearchRequest searchRequest = new SearchRequest();
         searchRequest.allowPartialSearchResults(true);
@@ -175,8 +210,8 @@ public void sendCanMatch(Transport.Connection connection, ShardSearchRequest req
             searchTransportService,
             (clusterAlias, node) -> lookup.get(node),
             Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)),
-            Collections.emptyMap(), EsExecutors.DIRECT_EXECUTOR_SERVICE,
-            searchRequest, null, shardsIter, timeProvider, ClusterState.EMPTY_STATE, null,
+            Collections.emptyMap(), threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION),
+            searchRequest, null, shardsIter, timeProvider, null,
             (iter) -> new SearchPhase("test") {
                 @Override
                 public void run() throws IOException {
@@ -189,110 +224,14 @@ public void run() throws IOException {
 
         assertEquals(0, result.get().get(0).shardId().id());
         assertEquals(1, result.get().get(1).shardId().id());
-        assertEquals(shard1, result.get().get(0).skip() == false);
+        if (fullFailure) {
+            assertFalse(result.get().get(0).skip()); // never skip the failure
+        } else {
+            assertEquals(shard1, result.get().get(0).skip() == false);
+        }
         assertFalse(result.get().get(1).skip()); // never skip the failure
     }
 
-    /*
-     * In cases that a query coordinating node held all the shards for a query, the can match phase would recurse and end in stack overflow
-     * when subjected to max concurrent search requests. This test is a test for that situation.
-     */
-    public void testLotsOfShards() throws InterruptedException {
-        final TransportSearchAction.SearchTimeProvider timeProvider =
-            new TransportSearchAction.SearchTimeProvider(0, System.nanoTime(), System::nanoTime);
-
-        final Map<String, Transport.Connection> lookup = new ConcurrentHashMap<>();
-        final DiscoveryNode primaryNode = new DiscoveryNode("node_1", buildNewFakeTransportAddress(), Version.CURRENT);
-        final DiscoveryNode replicaNode = new DiscoveryNode("node_2", buildNewFakeTransportAddress(), Version.CURRENT);
-        lookup.put("node1", new SearchAsyncActionTests.MockConnection(primaryNode));
-        lookup.put("node2", new SearchAsyncActionTests.MockConnection(replicaNode));
-
-
-        final SearchTransportService searchTransportService =
-            new SearchTransportService(null, null, null) {
-                @Override
-                public void sendCanMatch(
-                    Transport.Connection connection,
-                    ShardSearchRequest request,
-                    SearchTask task,
-                    ActionListener<SearchService.CanMatchResponse> listener) {
-                    listener.onResponse(new SearchService.CanMatchResponse(randomBoolean(), null));
-                }
-            };
-
-        final CountDownLatch latch = new CountDownLatch(1);
-        final OriginalIndices originalIndices = new OriginalIndices(new String[]{"idx"}, SearchRequest.DEFAULT_INDICES_OPTIONS);
-        final GroupShardsIterator<SearchShardIterator> shardsIter =
-            getShardsIter("idx", originalIndices, 4096, randomBoolean(), primaryNode, replicaNode);
-        final ExecutorService executor = Executors.newFixedThreadPool(randomIntBetween(1, Runtime.getRuntime().availableProcessors()));
-        final SearchRequest searchRequest = new SearchRequest();
-        searchRequest.allowPartialSearchResults(true);
-        SearchTransportService transportService = new SearchTransportService(null, null, null);
-        ActionListener<SearchResponse> responseListener = ActionListener.wrap(response -> {},
-            (e) -> { throw new AssertionError("unexpected", e);});
-        Map<String, AliasFilter> aliasFilters = Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY));
-        final CanMatchPreFilterSearchPhase canMatchPhase = new CanMatchPreFilterSearchPhase(
-            logger,
-            searchTransportService,
-            (clusterAlias, node) -> lookup.get(node),
-            Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)),
-            Collections.emptyMap(),
-            EsExecutors.DIRECT_EXECUTOR_SERVICE,
-            searchRequest,
-            null,
-            shardsIter,
-            timeProvider,
-            ClusterState.EMPTY_STATE,
-            null,
-            (iter) -> new AbstractSearchAsyncAction<SearchPhaseResult>(
-                "test",
-                logger,
-                transportService,
-                (cluster, node) -> {
-                        assert cluster == null : "cluster was not null: " + cluster;
-                        return lookup.get(node);
-                    },
-                aliasFilters,
-                Collections.emptyMap(),
-                executor,
-                searchRequest,
-                responseListener,
-                iter,
-                new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0),
-                ClusterState.EMPTY_STATE,
-                null,
-                new ArraySearchPhaseResults<>(iter.size()),
-                randomIntBetween(1, 32),
-                SearchResponse.Clusters.EMPTY) {
-
-                @Override
-                protected SearchPhase getNextPhase(SearchPhaseResults<SearchPhaseResult> results, SearchPhaseContext context) {
-                    return new SearchPhase("test") {
-                        @Override
-                        public void run() {
-                            latch.countDown();
-                        }
-                    };
-                }
-
-                @Override
-                protected void executePhaseOnShard(
-                    final SearchShardIterator shardIt,
-                    final SearchShardTarget shard,
-                    final SearchActionListener<SearchPhaseResult> listener) {
-                    if (randomBoolean()) {
-                        listener.onResponse(new SearchPhaseResult() {});
-                    } else {
-                        listener.onFailure(new Exception("failure"));
-                    }
-                }
-            }, SearchResponse.Clusters.EMPTY, EMPTY_CONTEXT_PROVIDER);
-
-        canMatchPhase.start();
-        latch.await();
-        executor.shutdown();
-    }
-
     public void testSortShards() throws InterruptedException {
         final TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider(0, System.nanoTime(),
             System::nanoTime);
@@ -310,20 +249,26 @@ public void testSortShards() throws InterruptedException {
 
             SearchTransportService searchTransportService = new SearchTransportService(null, null, null) {
                 @Override
-                public void sendCanMatch(Transport.Connection connection, ShardSearchRequest request, SearchTask task,
-                                         ActionListener<SearchService.CanMatchResponse> listener) {
-                    Long min = rarely() ? null : randomLong();
-                    Long max = min == null ? null  : randomLongBetween(min, Long.MAX_VALUE);
-                    MinAndMax<?> minMax = min == null ? null : new MinAndMax<>(min, max);
-                    boolean canMatch = frequently();
-                    synchronized (shardIds) {
-                        shardIds.add(request.shardId());
-                        minAndMaxes.add(minMax);
-                        if (canMatch == false) {
-                            shardToSkip.add(request.shardId());
+                public void sendCanMatch(Transport.Connection connection, CanMatchNodeRequest request, SearchTask task,
+                                         ActionListener<CanMatchNodeResponse> listener) {
+                    final List<ResponseOrFailure> responses = new ArrayList<>();
+                    for (CanMatchNodeRequest.Shard shard : request.getShardLevelRequests()) {
+                        Long min = rarely() ? null : randomLong();
+                        Long max = min == null ? null : randomLongBetween(min, Long.MAX_VALUE);
+                        MinAndMax<?> minMax = min == null ? null : new MinAndMax<>(min, max);
+                        boolean canMatch = frequently();
+                        synchronized (shardIds) {
+                            shardIds.add(shard.shardId());
+                            minAndMaxes.add(minMax);
+                            if (canMatch == false) {
+                                shardToSkip.add(shard.shardId());
+                            }
                         }
+
+                        responses.add(new ResponseOrFailure(new CanMatchShardResponse(canMatch, minMax)));
                     }
-                    new Thread(() -> listener.onResponse(new SearchService.CanMatchResponse(canMatch, minMax))).start();
+
+                    new Thread(() -> listener.onResponse(new CanMatchNodeResponse(responses))).start();
                 }
             };
 
@@ -340,8 +285,8 @@ public void sendCanMatch(Transport.Connection connection, ShardSearchRequest req
                 searchTransportService,
                 (clusterAlias, node) -> lookup.get(node),
                 Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)),
-                Collections.emptyMap(), EsExecutors.DIRECT_EXECUTOR_SERVICE,
-                searchRequest, null, shardsIter, timeProvider, ClusterState.EMPTY_STATE, null,
+                Collections.emptyMap(), threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION),
+                searchRequest, null, shardsIter, timeProvider, null,
                 (iter) -> new SearchPhase("test") {
                     @Override
                     public void run() {
@@ -386,24 +331,29 @@ public void testInvalidSortShards() throws InterruptedException {
 
             SearchTransportService searchTransportService = new SearchTransportService(null, null, null) {
                 @Override
-                public void sendCanMatch(Transport.Connection connection, ShardSearchRequest request, SearchTask task,
-                                         ActionListener<SearchService.CanMatchResponse> listener) {
-                    final MinAndMax<?> minMax;
-                    if (request.shardId().id() == numShards-1) {
-                        minMax = new MinAndMax<>(new BytesRef("bar"), new BytesRef("baz"));
-                    } else {
-                        Long min = randomLong();
-                        Long max = randomLongBetween(min, Long.MAX_VALUE);
-                        minMax = new MinAndMax<>(min, max);
-                    }
-                    boolean canMatch = frequently();
-                    synchronized (shardIds) {
-                        shardIds.add(request.shardId());
-                        if (canMatch == false) {
-                            shardToSkip.add(request.shardId());
+                public void sendCanMatch(Transport.Connection connection, CanMatchNodeRequest request, SearchTask task,
+                                         ActionListener<CanMatchNodeResponse> listener) {
+                    final List<ResponseOrFailure> responses = new ArrayList<>();
+                    for (CanMatchNodeRequest.Shard shard : request.getShardLevelRequests()) {
+                        final MinAndMax<?> minMax;
+                        if (shard.shardId().id() == numShards - 1) {
+                            minMax = new MinAndMax<>(new BytesRef("bar"), new BytesRef("baz"));
+                        } else {
+                            Long min = randomLong();
+                            Long max = randomLongBetween(min, Long.MAX_VALUE);
+                            minMax = new MinAndMax<>(min, max);
                         }
+                        boolean canMatch = frequently();
+                        synchronized (shardIds) {
+                            shardIds.add(shard.shardId());
+                            if (canMatch == false) {
+                                shardToSkip.add(shard.shardId());
+                            }
+                        }
+                        responses.add(new ResponseOrFailure(new CanMatchShardResponse(canMatch, minMax)));
                     }
-                    new Thread(() -> listener.onResponse(new SearchService.CanMatchResponse(canMatch, minMax))).start();
+
+                    new Thread(() -> listener.onResponse(new CanMatchNodeResponse(responses))).start();
                 }
             };
 
@@ -420,8 +370,8 @@ public void sendCanMatch(Transport.Connection connection, ShardSearchRequest req
                 searchTransportService,
                 (clusterAlias, node) -> lookup.get(node),
                 Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)),
-                Collections.emptyMap(), EsExecutors.DIRECT_EXECUTOR_SERVICE,
-                searchRequest, null, shardsIter, timeProvider, ClusterState.EMPTY_STATE, null,
+                Collections.emptyMap(), threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION),
+                searchRequest, null, shardsIter, timeProvider, null,
                 (iter) -> new SearchPhase("test") {
                     @Override
                     public void run() {
@@ -713,10 +663,15 @@ void assignShardsAndExecuteCanMatchPhase(DataStream dataStream,
         final List<ShardSearchRequest> requests = Collections.synchronizedList(new ArrayList<>());
         SearchTransportService searchTransportService = new SearchTransportService(null, null, null) {
             @Override
-            public void sendCanMatch(Transport.Connection connection, ShardSearchRequest request, SearchTask task,
-                                     ActionListener<SearchService.CanMatchResponse> listener) {
-                requests.add(request);
-                listener.onResponse(new SearchService.CanMatchResponse(true, null));
+            public void sendCanMatch(Transport.Connection connection, CanMatchNodeRequest request, SearchTask task,
+                                     ActionListener<CanMatchNodeResponse> listener) {
+                final List<ResponseOrFailure> responses = new ArrayList<>();
+                for (CanMatchNodeRequest.Shard shard : request.getShardLevelRequests()) {
+                    requests.add(request.createShardSearchRequest(shard));
+                    responses.add(new ResponseOrFailure(new CanMatchShardResponse(true, null)));
+                }
+
+                new Thread(() -> listener.onResponse(new CanMatchNodeResponse(responses))).start();
             }
         };
 
@@ -730,12 +685,11 @@ public void sendCanMatch(Transport.Connection connection, ShardSearchRequest req
             (clusterAlias, node) -> lookup.get(node),
             aliasFilters,
             Collections.emptyMap(),
-            EsExecutors.DIRECT_EXECUTOR_SERVICE,
+            threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION),
             searchRequest,
             null,
             shardsIter,
             timeProvider,
-            ClusterState.EMPTY_STATE,
             null,
             (iter) -> new SearchPhase("test") {
                 @Override
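The stubs above exercise the new node-level can-match transport call: a single CanMatchNodeRequest carries one CanMatchNodeRequest.Shard entry per shard hosted on the target node, and the listener must be answered with exactly one ResponseOrFailure per entry, in request order. A minimal sketch of such a stub, using only the types that appear in this diff (answering inline instead of on a separate thread is a simplification), could look like:

    SearchTransportService stub = new SearchTransportService(null, null, null) {
        @Override
        public void sendCanMatch(Transport.Connection connection, CanMatchNodeRequest request, SearchTask task,
                                 ActionListener<CanMatchNodeResponse> listener) {
            final List<ResponseOrFailure> responses = new ArrayList<>();
            for (CanMatchNodeRequest.Shard shard : request.getShardLevelRequests()) {
                // One response (or failure) per shard-level request, preserving request order.
                // A real handler would derive its answer from the shard, e.g. via
                // request.createShardSearchRequest(shard); "always matches" is used here for brevity.
                responses.add(new ResponseOrFailure(new CanMatchShardResponse(true, null)));
            }
            listener.onResponse(new CanMatchNodeResponse(responses));
        }
    };
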
diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java
index c76ef34298f9b..b367deae020fb 100644
--- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java
+++ b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java
@@ -91,6 +91,7 @@
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
+import static org.elasticsearch.action.search.SearchRequest.DEFAULT_PRE_FILTER_SHARD_SIZE;
 import static org.elasticsearch.test.InternalAggregationTestCase.emptyReduceContextBuilder;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
 import static org.hamcrest.CoreMatchers.containsString;
@@ -883,34 +884,34 @@ public void testShouldPreFilterSearchShards() {
         {
             SearchRequest searchRequest = new SearchRequest();
             assertFalse(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(2, 128)));
+                indices, randomIntBetween(2, 128), DEFAULT_PRE_FILTER_SHARD_SIZE));
             assertFalse(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(129, 10000)));
+                indices, randomIntBetween(129, 10000), DEFAULT_PRE_FILTER_SHARD_SIZE));
         }
         {
             SearchRequest searchRequest = new SearchRequest()
                 .source(new SearchSourceBuilder().query(QueryBuilders.rangeQuery("timestamp")));
             assertFalse(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(2, 128)));
+                indices, randomIntBetween(2, DEFAULT_PRE_FILTER_SHARD_SIZE), DEFAULT_PRE_FILTER_SHARD_SIZE));
             assertTrue(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(129, 10000)));
+                indices, randomIntBetween(DEFAULT_PRE_FILTER_SHARD_SIZE + 1, 10000), DEFAULT_PRE_FILTER_SHARD_SIZE));
         }
         {
             SearchRequest searchRequest = new SearchRequest()
                 .source(new SearchSourceBuilder().sort(SortBuilders.fieldSort("timestamp")));
             assertTrue(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(2, 127)));
+                indices, randomIntBetween(2, DEFAULT_PRE_FILTER_SHARD_SIZE - 1), DEFAULT_PRE_FILTER_SHARD_SIZE));
             assertTrue(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(127, 10000)));
+                indices, randomIntBetween(DEFAULT_PRE_FILTER_SHARD_SIZE - 1, 10000), DEFAULT_PRE_FILTER_SHARD_SIZE));
         }
         {
             SearchRequest searchRequest = new SearchRequest()
                 .source(new SearchSourceBuilder().sort(SortBuilders.fieldSort("timestamp")))
                 .scroll("5m");
             assertTrue(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(2, 128)));
+                indices, randomIntBetween(2, DEFAULT_PRE_FILTER_SHARD_SIZE), DEFAULT_PRE_FILTER_SHARD_SIZE));
             assertTrue(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(129, 10000)));
+                indices, randomIntBetween(DEFAULT_PRE_FILTER_SHARD_SIZE + 1, 10000), DEFAULT_PRE_FILTER_SHARD_SIZE));
         }
     }
 
@@ -933,35 +934,35 @@ public void testShouldPreFilterSearchShardsWithReadOnly() {
         {
             SearchRequest searchRequest = new SearchRequest();
             assertFalse(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(2, 127)));
+                indices, randomIntBetween(2, DEFAULT_PRE_FILTER_SHARD_SIZE - 1), DEFAULT_PRE_FILTER_SHARD_SIZE));
             assertFalse(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(127, 10000)));
+                indices, randomIntBetween(DEFAULT_PRE_FILTER_SHARD_SIZE - 1, 10000), DEFAULT_PRE_FILTER_SHARD_SIZE));
         }
         {
             SearchRequest searchRequest = new SearchRequest()
                 .source(new SearchSourceBuilder().query(QueryBuilders.rangeQuery("timestamp")));
             assertTrue(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(2, 127)));
+                indices, randomIntBetween(2, DEFAULT_PRE_FILTER_SHARD_SIZE - 1), DEFAULT_PRE_FILTER_SHARD_SIZE));
             assertTrue(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(127, 10000)));
+                indices, randomIntBetween(DEFAULT_PRE_FILTER_SHARD_SIZE - 1, 10000), DEFAULT_PRE_FILTER_SHARD_SIZE));
         }
         {
             SearchRequest searchRequest = new SearchRequest()
                 .source(new SearchSourceBuilder().query(QueryBuilders.rangeQuery("timestamp")));
             searchRequest.scroll("5s");
             assertTrue(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(2, 127)));
+                indices, randomIntBetween(2, DEFAULT_PRE_FILTER_SHARD_SIZE - 1), DEFAULT_PRE_FILTER_SHARD_SIZE));
             assertTrue(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(127, 10000)));
+                indices, randomIntBetween(DEFAULT_PRE_FILTER_SHARD_SIZE - 1, 10000), DEFAULT_PRE_FILTER_SHARD_SIZE));
         }
         {
             SearchRequest searchRequest = new SearchRequest()
                 .source(new SearchSourceBuilder().query(QueryBuilders.rangeQuery("timestamp")));
             searchRequest.searchType(SearchType.DFS_QUERY_THEN_FETCH);
             assertFalse(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(2, 127)));
+                indices, randomIntBetween(2, DEFAULT_PRE_FILTER_SHARD_SIZE - 1), DEFAULT_PRE_FILTER_SHARD_SIZE));
             assertFalse(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
-                indices, randomIntBetween(127, 10000)));
+                indices, randomIntBetween(DEFAULT_PRE_FILTER_SHARD_SIZE - 1, 10000), DEFAULT_PRE_FILTER_SHARD_SIZE));
         }
     }
 
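The assertions above replace the hard-coded 127/128/129 shard-count boundaries with SearchRequest.DEFAULT_PRE_FILTER_SHARD_SIZE, and shouldPreFilterSearchShards now receives that default threshold as an explicit argument. A minimal sketch of the boundary being exercised, reusing the clusterState and indices fixtures built earlier in the test and assuming only the behaviour asserted above, is:

    SearchRequest searchRequest = new SearchRequest()
        .source(new SearchSourceBuilder().query(QueryBuilders.rangeQuery("timestamp")));
    // At exactly the default threshold, pre-filtering stays off ...
    assertFalse(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
        indices, DEFAULT_PRE_FILTER_SHARD_SIZE, DEFAULT_PRE_FILTER_SHARD_SIZE));
    // ... and one shard above the default turns it on for a range-filtered request.
    assertTrue(TransportSearchAction.shouldPreFilterSearchShards(clusterState, searchRequest,
        indices, DEFAULT_PRE_FILTER_SHARD_SIZE + 1, DEFAULT_PRE_FILTER_SHARD_SIZE));
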
diff --git a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java
index 902c860a9590f..d6c7a88b61ade 100644
--- a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java
+++ b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java
@@ -8,6 +8,7 @@
 package org.elasticsearch.search;
 
 import com.carrotsearch.hppc.IntArrayList;
+
 import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.FilterDirectoryReader;
 import org.apache.lucene.index.LeafReader;
@@ -43,8 +44,6 @@
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.lucene.search.Queries;
 import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xcontent.json.JsonXContent;
 import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.Index;
@@ -98,6 +97,8 @@
 import org.elasticsearch.test.ESSingleNodeTestCase;
 import org.elasticsearch.test.VersionUtils;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.json.JsonXContent;
 import org.junit.Before;
 
 import java.io.IOException;
@@ -1409,13 +1410,15 @@ public void onFailure(Exception e) {
         client().clearScroll(clearScrollRequest);
     }
 
-    public void testWaitOnRefresh() throws Exception {
+    public void testWaitOnRefresh() {
         createIndex("index");
         final SearchService service = getInstanceFromNode(SearchService.class);
         final IndicesService indicesService = getInstanceFromNode(IndicesService.class);
         final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index"));
         final IndexShard indexShard = indexService.getShard(0);
         SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
+        searchRequest.setWaitForCheckpointsTimeout(TimeValue.timeValueSeconds(30));
+        searchRequest.setWaitForCheckpoints(Collections.singletonMap("index", new long[] {0}));
 
         final IndexResponse response = client().prepareIndex("index", "_doc").setSource("id", "1").get();
         assertEquals(RestStatus.CREATED, response.status());
@@ -1423,19 +1426,21 @@ public void testWaitOnRefresh() throws Exception {
         SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap());
         PlainActionFuture<SearchPhaseResult> future = PlainActionFuture.newFuture();
         ShardSearchRequest request = new ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1,
-            new AliasFilter(null, Strings.EMPTY_ARRAY), 1.0f, -1, null, null, null, 0, TimeValue.timeValueSeconds(30));
+            new AliasFilter(null, Strings.EMPTY_ARRAY), 1.0f, -1, null, null, null);
         service.executeQueryPhase(request, task, future);
         SearchPhaseResult searchPhaseResult = future.actionGet();
         assertEquals(1, searchPhaseResult.queryResult().getTotalHits().value);
     }
 
-    public void testWaitOnRefreshFailsWithRefreshesDisabled() throws Exception {
+    public void testWaitOnRefreshFailsWithRefreshesDisabled() {
         createIndex("index", Settings.builder().put("index.refresh_interval", "-1").build());
         final SearchService service = getInstanceFromNode(SearchService.class);
         final IndicesService indicesService = getInstanceFromNode(IndicesService.class);
         final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index"));
         final IndexShard indexShard = indexService.getShard(0);
         SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
+        searchRequest.setWaitForCheckpointsTimeout(TimeValue.timeValueSeconds(30));
+        searchRequest.setWaitForCheckpoints(Collections.singletonMap("index", new long[] {0}));
 
         final IndexResponse response = client().prepareIndex("index", "_doc").setSource("id", "1").get();
         assertEquals(RestStatus.CREATED, response.status());
@@ -1443,20 +1448,22 @@ public void testWaitOnRefreshFailsWithRefreshesDisabled() throws Exception {
         SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap());
         PlainActionFuture<SearchPhaseResult> future = PlainActionFuture.newFuture();
         ShardSearchRequest request = new ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1,
-            new AliasFilter(null, Strings.EMPTY_ARRAY), 1.0f, -1, null, null, null, 0, TimeValue.timeValueSeconds(30));
+            new AliasFilter(null, Strings.EMPTY_ARRAY), 1.0f, -1, null, null, null);
         service.executeQueryPhase(request, task, future);
         IllegalArgumentException illegalArgumentException = expectThrows(IllegalArgumentException.class, future::actionGet);
         assertThat(illegalArgumentException.getMessage(),
             containsString("Cannot use wait_for_checkpoints with [index.refresh_interval=-1]"));
     }
 
-    public void testWaitOnRefreshFailsIfCheckpointNotIndexed() throws Exception {
+    public void testWaitOnRefreshFailsIfCheckpointNotIndexed() {
         createIndex("index");
         final SearchService service = getInstanceFromNode(SearchService.class);
         final IndicesService indicesService = getInstanceFromNode(IndicesService.class);
         final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index"));
         final IndexShard indexShard = indexService.getShard(0);
         SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
+        searchRequest.setWaitForCheckpointsTimeout(TimeValue.timeValueMillis(randomIntBetween(10, 100)));
+        searchRequest.setWaitForCheckpoints(Collections.singletonMap("index", new long[] {1}));
 
         final IndexResponse response = client().prepareIndex("index", "_doc").setSource("id", "1").get();
         assertEquals(RestStatus.CREATED, response.status());
@@ -1464,8 +1471,7 @@ public void testWaitOnRefreshFailsIfCheckpointNotIndexed() throws Exception {
         SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap());
         PlainActionFuture<SearchPhaseResult> future = PlainActionFuture.newFuture();
         ShardSearchRequest request = new ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1,
-            new AliasFilter(null, Strings.EMPTY_ARRAY), 1.0f, -1, null, null, null, 1,
-            TimeValue.timeValueMillis(randomIntBetween(10, 100)));
+            new AliasFilter(null, Strings.EMPTY_ARRAY), 1.0f, -1, null, null, null);
         service.executeQueryPhase(request, task, future);
 
         IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, future::actionGet);
@@ -1473,13 +1479,15 @@ public void testWaitOnRefreshFailsIfCheckpointNotIndexed() throws Exception {
             containsString("Cannot wait for unissued seqNo checkpoint [wait_for_checkpoint=1, max_issued_seqNo=0]"));
     }
 
-    public void testWaitOnRefreshTimeout() throws Exception {
+    public void testWaitOnRefreshTimeout() {
         createIndex("index", Settings.builder().put("index.refresh_interval", "60s").build());
         final SearchService service = getInstanceFromNode(SearchService.class);
         final IndicesService indicesService = getInstanceFromNode(IndicesService.class);
         final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index"));
         final IndexShard indexShard = indexService.getShard(0);
         SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
+        searchRequest.setWaitForCheckpointsTimeout(TimeValue.timeValueMillis(randomIntBetween(10, 100)));
+        searchRequest.setWaitForCheckpoints(Collections.singletonMap("index", new long[] {0}));
 
         final IndexResponse response = client().prepareIndex("index", "_doc").setSource("id", "1").get();
         assertEquals(RestStatus.CREATED, response.status());
@@ -1487,8 +1495,7 @@ public void testWaitOnRefreshTimeout() throws Exception {
         SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap());
         PlainActionFuture<SearchPhaseResult> future = PlainActionFuture.newFuture();
         ShardSearchRequest request = new ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1,
-            new AliasFilter(null, Strings.EMPTY_ARRAY), 1.0f, -1, null, null, null, 0,
-            TimeValue.timeValueMillis(randomIntBetween(10, 100)));
+            new AliasFilter(null, Strings.EMPTY_ARRAY), 1.0f, -1, null, null, null);
         service.executeQueryPhase(request, task, future);
 
         ElasticsearchTimeoutException ex = expectThrows(ElasticsearchTimeoutException.class, future::actionGet);
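These tests no longer pass the wait-for-checkpoints values through the ShardSearchRequest constructor; they are configured on the SearchRequest itself and propagated from there. A minimal sketch of the new call pattern, mirroring the setters added above (the index name and checkpoint values are simply the ones used in these tests), is:

    SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
    // The tests above expect an ElasticsearchTimeoutException if the checkpoint is not searchable in time.
    searchRequest.setWaitForCheckpointsTimeout(TimeValue.timeValueSeconds(30));
    // Wait for sequence-number checkpoint 0 on "index" (one array entry per shard) before searching.
    searchRequest.setWaitForCheckpoints(Collections.singletonMap("index", new long[] { 0 }));
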
diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java
index 96a6c6d25e034..889cb0dbecc94 100644
--- a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java
+++ b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java
@@ -43,7 +43,9 @@
 import org.elasticsearch.action.index.IndexRequestBuilder;
 import org.elasticsearch.action.index.IndexResponse;
 import org.elasticsearch.action.search.ClearScrollResponse;
+import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.search.TransportSearchAction;
 import org.elasticsearch.action.support.DefaultShardOperationFailedException;
 import org.elasticsearch.action.support.DestructiveOperations;
 import org.elasticsearch.action.support.IndicesOptions;
@@ -1876,7 +1878,9 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
             // randomly enable enforcing a default tier_preference to make sure it does not alter results
             .put(DataTier.ENFORCE_DEFAULT_TIER_PREFERENCE_SETTING.getKey(), randomBoolean())
             .putList(DISCOVERY_SEED_HOSTS_SETTING.getKey()) // empty list disables a port scan for other nodes
-            .putList(DISCOVERY_SEED_PROVIDERS_SETTING.getKey(), "file");
+            .putList(DISCOVERY_SEED_PROVIDERS_SETTING.getKey(), "file")
+            .put(TransportSearchAction.DEFAULT_PRE_FILTER_SHARD_SIZE.getKey(), randomFrom(1, 2,
+                SearchRequest.DEFAULT_PRE_FILTER_SHARD_SIZE));
         if (rarely()) {
             // Sometimes adjust the minimum search thread pool size, causing
             // QueueResizingEsThreadPoolExecutor to be used instead of a regular