Skip to content

Commit

Permalink
Optimize threading in AbstractSearchAsyncAction (elastic#113230) (ela…
Browse files Browse the repository at this point in the history
…stic#115643)

Forking whenever an action completes on the current thread is a needlessly heavy-handed
way of preventing stack overflows. We also don't need locking/synchronization
to deal with a worker-count + queue-length problem. Fixing both allows for
non-trivial optimization even in the current execution model. This change also
helps with moving to a more efficient execution model, in particular by saving
needless forking to the search pool.
-> Refactored the code to never fork; stack-depth issues are instead avoided through the use
of a `SubscribableListener`.
-> Replaced our home-brewed queue-and-semaphore combination with JDK primitives, which
saves blocking synchronization on task start and completion.
  • Loading branch information
original-brownbear authored Oct 25, 2024
1 parent b22b9c7 commit 3cc5796
Showing 1 changed file with 94 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.ShardOperationFailedException;
import org.elasticsearch.action.search.TransportSearchAction.SearchTimeProvider;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.action.support.TransportActions;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.shard.ShardId;
Expand All @@ -44,17 +43,19 @@
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.transport.Transport;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static org.elasticsearch.core.Strings.format;
Expand Down Expand Up @@ -248,7 +249,12 @@ public final void run() {
assert shardRoutings.skip() == false;
assert shardIndexMap.containsKey(shardRoutings);
int shardIndex = shardIndexMap.get(shardRoutings);
performPhaseOnShard(shardIndex, shardRoutings, shardRoutings.nextOrNull());
final SearchShardTarget routing = shardRoutings.nextOrNull();
if (routing == null) {
failOnUnavailable(shardIndex, shardRoutings);
} else {
performPhaseOnShard(shardIndex, shardRoutings, routing);
}
}
}
}
Expand Down Expand Up @@ -283,7 +289,7 @@ private static boolean assertExecuteOnStartThread() {
int index = 0;
assert stackTraceElements[index++].getMethodName().equals("getStackTrace");
assert stackTraceElements[index++].getMethodName().equals("assertExecuteOnStartThread");
assert stackTraceElements[index++].getMethodName().equals("performPhaseOnShard");
assert stackTraceElements[index++].getMethodName().equals("failOnUnavailable");
if (stackTraceElements[index].getMethodName().equals("performPhaseOnShard")) {
assert stackTraceElements[index].getClassName().endsWith("CanMatchPreFilterSearchPhase");
index++;
Expand All @@ -302,65 +308,53 @@ private static boolean assertExecuteOnStartThread() {
}

/**
 * Starts the phase for a single shard copy.
 * <p>
 * NOTE(review): this listing is a diff rendered without +/- markers, so the body below appears to interleave the
 * pre-change implementation (thread capture, {@code Runnable r}, {@code executeNext}, {@code fork}) with the
 * post-change one (submission to {@code PendingExecutions}, or a direct {@code doPerformPhaseOnShard} call).
 * Confirm against the actual commit before relying on the exact statement order.
 *
 * @param shardIndex index of the shard; handed to the shard listener and to {@code onShardFailure}
 * @param shardIt    iterator over the copies of the shard
 * @param shard      the copy to execute on; the {@code shard == null} branch below fails the shard with a
 *                   {@link NoShardAvailableActionException}
 */
protected void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final SearchShardTarget shard) {
/*
* We capture the thread that this phase is starting on. When we are called back after executing the phase, we are either on the
* same thread (because we never went async, or the same thread was selected from the thread pool) or a different thread. If we
* continue on the same thread in the case that we never went async and this happens repeatedly we will end up recursing deeply and
* could stack overflow. To prevent this, we fork if we are called back on the same thread that execution started on and otherwise
* we can continue (cf. InitialSearchPhase#maybeFork).
*/
if (shard == null) {
assert assertExecuteOnStartThread();
SearchShardTarget unassignedShard = new SearchShardTarget(null, shardIt.shardId(), shardIt.getClusterAlias());
onShardFailure(shardIndex, unassignedShard, shardIt, new NoShardAvailableActionException(shardIt.shardId()));
// Throttled path: hand the work to the PendingExecutions instance for this node id, creating it on first use.
if (throttleConcurrentRequests) {
var pendingExecutions = pendingExecutionsPerNode.computeIfAbsent(
shard.getNodeId(),
n -> new PendingExecutions(maxConcurrentRequestsPerNode)
);
pendingExecutions.submit(l -> doPerformPhaseOnShard(shardIndex, shardIt, shard, l));
} else {
final PendingExecutions pendingExecutions = throttleConcurrentRequests
? pendingExecutionsPerNode.computeIfAbsent(shard.getNodeId(), n -> new PendingExecutions(maxConcurrentRequestsPerNode))
: null;
Runnable r = () -> {
final Thread thread = Thread.currentThread();
try {
executePhaseOnShard(shardIt, shard, new SearchActionListener<>(shard, shardIndex) {
@Override
public void innerOnResponse(Result result) {
try {
onShardResult(result, shardIt);
} catch (Exception exc) {
onShardFailure(shardIndex, shard, shardIt, exc);
} finally {
executeNext(pendingExecutions, thread);
}
}
// Unthrottled path: run immediately; there is no permit to give back, hence the no-op Releasable.
doPerformPhaseOnShard(shardIndex, shardIt, shard, () -> {});
}
}

@Override
public void onFailure(Exception t) {
try {
onShardFailure(shardIndex, shard, shardIt, t);
} finally {
executeNext(pendingExecutions, thread);
}
}
});
} catch (final Exception e) {
try {
/*
* It is possible to run into connection exceptions here because we are getting the connection early and might
* run into nodes that are not connected. In this case, on shard failure will move us to the next shard copy.
*/
fork(() -> onShardFailure(shardIndex, shard, shardIt, e));
} finally {
executeNext(pendingExecutions, thread);
/**
 * Calls {@code executePhaseOnShard} for one shard copy and routes the outcome to {@code onShardResult} or
 * {@code onShardFailure}. {@code releasable} is closed via try-with-resources on every path: on response, on
 * listener failure, and when {@code executePhaseOnShard} itself throws.
 *
 * @param shardIndex index of the shard; handed to the listener and to {@code onShardFailure}
 * @param shardIt    iterator over the copies of the shard
 * @param shard      the copy the phase is executed on
 * @param releasable closed once this shard-level execution has been handled
 */
private void doPerformPhaseOnShard(int shardIndex, SearchShardIterator shardIt, SearchShardTarget shard, Releasable releasable) {
try {
executePhaseOnShard(shardIt, shard, new SearchActionListener<>(shard, shardIndex) {
@Override
public void innerOnResponse(Result result) {
// try-with-resources closes the resource before a catch clause runs, so releasable is already
// closed by the time onShardFailure is invoked from the catch below.
try (releasable) {
onShardResult(result, shardIt);
} catch (Exception exc) {
onShardFailure(shardIndex, shard, shardIt, exc);
}
}
// NOTE(review): the next five lines look like leftovers of the pre-change performPhaseOnShard body that the
// marker-less diff rendering placed here; they do not appear to belong to this method. Confirm against the commit.
};
if (throttleConcurrentRequests) {
pendingExecutions.tryRun(r);
} else {
r.run();

@Override
public void onFailure(Exception e) {
// Here releasable is closed after onShardFailure returns (or throws).
try (releasable) {
onShardFailure(shardIndex, shard, shardIt, e);
}
}
});
} catch (final Exception e) {
/*
* It is possible to run into connection exceptions here because we are getting the connection early and might
* run into nodes that are not connected. In this case, on shard failure will move us to the next shard copy.
*/
try (releasable) {
onShardFailure(shardIndex, shard, shardIt, e);
}
}
}

/**
 * Fails the shard at {@code shardIndex} because no copy is available for it.
 * <p>
 * A {@link SearchShardTarget} with a {@code null} first constructor argument stands in for the missing copy, and the
 * failure is reported through {@code onShardFailure} as a {@link NoShardAvailableActionException}.
 *
 * @param shardIndex index of the shard that has no available copy
 * @param shardIt    iterator of the shard; supplies the shard id and the cluster alias
 */
private void failOnUnavailable(int shardIndex, SearchShardIterator shardIt) {
    // Must stay directly in this method: the assertion inspects the calling stack frames by method name.
    assert assertExecuteOnStartThread();
    final var placeholderTarget = new SearchShardTarget(null, shardIt.shardId(), shardIt.getClusterAlias());
    final var noCopyAvailable = new NoShardAvailableActionException(shardIt.shardId());
    onShardFailure(shardIndex, placeholderTarget, shardIt, noCopyAvailable);
}

/**
* Sends the request to the actual shard.
* @param shardIt the shards iterator
Expand All @@ -373,34 +367,6 @@ protected abstract void executePhaseOnShard(
SearchActionListener<Result> listener
);

/**
 * Hands {@code runnable} to {@code executor}, wrapped in an {@link AbstractRunnable} that requests forced execution.
 * <p>
 * If the wrapper is rejected anyway, the task is run directly from the rejection callback rather than dropped.
 * The failure callback logs the error and trips an assertion.
 *
 * @param runnable the work to run
 */
protected void fork(final Runnable runnable) {
    final AbstractRunnable forcedTask = new AbstractRunnable() {
        @Override
        protected void doRun() {
            runnable.run();
        }

        @Override
        public boolean isForceExecution() {
            // we can not allow a stuffed queue to reject execution here
            return true;
        }

        @Override
        public void onRejection(Exception e) {
            // avoid leaks during node shutdown by executing on the current thread if the executor shuts down
            assert e instanceof EsRejectedExecutionException esre && esre.isExecutorShutdown() : e;
            doRun();
        }

        @Override
        public void onFailure(Exception e) {
            logger.error(() -> "unexpected error during [" + task + "]", e);
            assert false : e;
        }
    };
    executor.execute(forcedTask);
}

@Override
public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase) {
/* This is the main search phase transition where we move to the next phase. If all shards
Expand Down Expand Up @@ -824,61 +790,63 @@ protected final ShardSearchRequest buildShardSearchRequest(SearchShardIterator s
*/
protected abstract SearchPhase getNextPhase(SearchPhaseResults<Result> results, SearchPhaseContext context);

/**
 * Marks the current execution on {@code pendingExecutions} as finished and runs whatever it hands back next.
 *
 * @param pendingExecutions the throttling state to advance, or {@code null} when there is none
 * @param originalThread    the thread the finished execution was started on
 */
private void executeNext(PendingExecutions pendingExecutions, Thread originalThread) {
    Runnable next = null;
    if (pendingExecutions != null) {
        next = pendingExecutions.finishAndGetNext();
    }
    executeNext(next, originalThread);
}

/**
 * Runs {@code runnable}, forking it when the caller is still on {@code originalThread} and running it inline otherwise.
 *
 * @param runnable       the next task, or {@code null} when there is nothing to run
 * @param originalThread the thread the previous execution was started on
 */
void executeNext(Runnable runnable, Thread originalThread) {
    if (runnable == null) {
        return;
    }
    assert throttleConcurrentRequests;
    if (Thread.currentThread() != originalThread) {
        runnable.run();
    } else {
        fork(runnable);
    }
}

/**
 * Bounds how many submitted tasks run at once and queues the rest.
 * <p>
 * NOTE(review): this listing is a diff rendered without +/- markers. The class body appears to interleave the
 * pre-change implementation ({@code permits}/{@code permitsTaken} counter, {@code ArrayDeque<Runnable>},
 * {@code finishAndGetNext}, {@code tryRun}, synchronized {@code tryQueue}) with the post-change one
 * ({@code Semaphore}, {@code LinkedTransferQueue}, {@code submit}, {@code executeAndRelease},
 * {@code pollNextTaskOrReleasePermit}). Confirm against the actual commit before relying on statement order.
 */
private static final class PendingExecutions {
private final int permits;
private int permitsTaken = 0;
private final ArrayDeque<Runnable> queue = new ArrayDeque<>();
// One permit is held for each chain of tasks currently being executed; see submit and pollNextTaskOrReleasePermit.
private final Semaphore semaphore;
// Tasks waiting for a permit. Each task receives a Releasable that it must close when it is done.
private final LinkedTransferQueue<Consumer<Releasable>> queue = new LinkedTransferQueue<>();

PendingExecutions(int permits) {
assert permits > 0 : "not enough permits: " + permits;
this.permits = permits;
semaphore = new Semaphore(permits);
}

Runnable finishAndGetNext() {
synchronized (this) {
permitsTaken--;
assert permitsTaken >= 0 : "illegal taken permits: " + permitsTaken;
// Runs the task right away when a permit can be acquired. Otherwise the task is enqueued and the acquire is
// retried once; presumably this covers a permit being released between the failed acquire and the enqueue.
// On a successful retry the head of the queue is run, which need not be the task just added.
void submit(Consumer<Releasable> task) {
if (semaphore.tryAcquire()) {
executeAndRelease(task);
} else {
queue.add(task);
if (semaphore.tryAcquire()) {
task = pollNextTaskOrReleasePermit();
if (task != null) {
executeAndRelease(task);
}
}
}
return tryQueue(null);

}

void tryRun(Runnable runnable) {
Runnable r = tryQueue(runnable);
if (r != null) {
r.run();
// Runs the given task and then keeps draining the queue under the same permit. Each task is given a Releasable
// that completes a fresh SubscribableListener. If the listener is already done when accept returns, the loop
// continues with the next queued task; otherwise the continuation is registered on the listener and this
// method returns.
private void executeAndRelease(Consumer<Releasable> task) {
while (task != null) {
final SubscribableListener<Void> onDone = new SubscribableListener<>();
task.accept(() -> onDone.onResponse(null));
if (onDone.isDone()) {
// keep going on the current thread, no need to fork
task = pollNextTaskOrReleasePermit();
} else {
onDone.addListener(new ActionListener<>() {
@Override
public void onResponse(Void unused) {
final Consumer<Releasable> nextTask = pollNextTaskOrReleasePermit();
if (nextTask != null) {
executeAndRelease(nextTask);
}
}

@Override
public void onFailure(Exception e) {
assert false : e;
}
});
return;
}
}
}

private synchronized Runnable tryQueue(Runnable runnable) {
Runnable toExecute = null;
if (permitsTaken < permits) {
permitsTaken++;
toExecute = runnable;
if (toExecute == null) { // only poll if we don't have anything to execute
toExecute = queue.poll();
}
if (toExecute == null) {
permitsTaken--;
}
} else if (runnable != null) {
queue.add(runnable);
// Returns the next queued task, or releases one semaphore permit and returns null when the queue is empty.
private Consumer<Releasable> pollNextTaskOrReleasePermit() {
var task = queue.poll();
if (task == null) {
semaphore.release();
}
return toExecute;
return task;
}
}
}

0 comments on commit 3cc5796

Please sign in to comment.