Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use system context for cluster state update tasks #31241

Merged
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ default boolean runOnlyOnMaster() {
/**
* Callback invoked after new cluster state is published. Note that
* this method is not invoked if the cluster state was not updated.
*
* Note that this method will be executed using system context.
*
* @param clusterChangedEvent the change event for this cluster state change, containing
* both old and new states
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.FutureUtils;
import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.discovery.Discovery;
import org.elasticsearch.threadpool.ThreadPool;

Expand All @@ -60,6 +61,7 @@
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static org.elasticsearch.cluster.service.ClusterService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING;
Expand Down Expand Up @@ -444,26 +446,28 @@ public TimeValue getMaxTaskWaitTime() {
return threadPoolExecutor.getMaxTaskWaitTime();
}

private SafeClusterStateTaskListener safe(ClusterStateTaskListener listener) {
private SafeClusterStateTaskListener safe(ClusterStateTaskListener listener, Supplier<ThreadContext.StoredContext> contextSupplier) {
if (listener instanceof AckedClusterStateTaskListener) {
return new SafeAckedClusterStateTaskListener((AckedClusterStateTaskListener) listener, logger);
return new SafeAckedClusterStateTaskListener((AckedClusterStateTaskListener) listener, contextSupplier, logger);
} else {
return new SafeClusterStateTaskListener(listener, logger);
return new SafeClusterStateTaskListener(listener, contextSupplier, logger);
}
}

private static class SafeClusterStateTaskListener implements ClusterStateTaskListener {
private final ClusterStateTaskListener listener;
protected final Supplier<ThreadContext.StoredContext> context;
private final Logger logger;

SafeClusterStateTaskListener(ClusterStateTaskListener listener, Logger logger) {
SafeClusterStateTaskListener(ClusterStateTaskListener listener, Supplier<ThreadContext.StoredContext> context, Logger logger) {
this.listener = listener;
this.context = context;
this.logger = logger;
}

@Override
public void onFailure(String source, Exception e) {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.onFailure(source, e);
} catch (Exception inner) {
inner.addSuppressed(e);
Expand All @@ -474,7 +478,7 @@ public void onFailure(String source, Exception e) {

@Override
public void onNoLongerMaster(String source) {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.onNoLongerMaster(source);
} catch (Exception e) {
logger.error(() -> new ParameterizedMessage(
Expand All @@ -484,7 +488,7 @@ public void onNoLongerMaster(String source) {

@Override
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.clusterStateProcessed(source, oldState, newState);
} catch (Exception e) {
logger.error(() -> new ParameterizedMessage(
Expand All @@ -498,8 +502,9 @@ private static class SafeAckedClusterStateTaskListener extends SafeClusterStateT
private final AckedClusterStateTaskListener listener;
private final Logger logger;

SafeAckedClusterStateTaskListener(AckedClusterStateTaskListener listener, Logger logger) {
super(listener, logger);
SafeAckedClusterStateTaskListener(AckedClusterStateTaskListener listener, Supplier<ThreadContext.StoredContext> context,
Logger logger) {
super(listener, context, logger);
this.listener = listener;
this.logger = logger;
}
Expand All @@ -511,7 +516,7 @@ public boolean mustAck(DiscoveryNode discoveryNode) {

@Override
public void onAllNodesAcked(@Nullable Exception e) {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.onAllNodesAcked(e);
} catch (Exception inner) {
inner.addSuppressed(e);
Expand All @@ -521,7 +526,7 @@ public void onAllNodesAcked(@Nullable Exception e) {

@Override
public void onAckTimeout() {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.onAckTimeout();
} catch (Exception e) {
logger.error("exception thrown by listener while notifying on ack timeout", e);
Expand Down Expand Up @@ -710,9 +715,13 @@ public <T> void submitStateUpdateTasks(final String source,
if (!lifecycle.started()) {
return;
}
try {
final ThreadContext threadContext = threadPool.getThreadContext();
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(false);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of doing this, I wonder if we should invest in a SystemThreadContext that we give to the executor, so that it does its work under the system context. Alternatively, we can have a "no context" executor that simply doesn't set any context. @tvernum do you see any other usage for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I understand this correctly. Would you have different executors based on whether they switch to the system context or stay in default context? Looking at the other usages of threadContext.markAsSystemContext();, all of them seem to use executors that cannot be categorized as system or non-system (i.e. generic threadpool, ...).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My suggestion referred to the entire pattern of generating a new context and then marking it as system, and to dealing with it at the executor level. This has two potential upsides:

  1. You don't need to remember to do this everywhere you submit a task (not so important here as there's just one place).
  2. We don't do all the "clean & capture and then restore in the thread pool" song and dance.

Since this is not a high throughput usage and we only have one place we submit tasks it doesn't have such a big impact on MasterService but might be useful for other things that use an internal thread pool executor.


List<Batcher.UpdateTask> safeTasks = tasks.entrySet().stream()
.map(e -> taskBatcher.new UpdateTask(config.priority(), source, e.getKey(), safe(e.getValue()), executor))
.map(e -> taskBatcher.new UpdateTask(config.priority(), source, e.getKey(), safe(e.getValue(), supplier), executor))
.collect(Collectors.toList());
taskBatcher.submitTasks(safeTasks, config.timeout());
} catch (EsRejectedExecutionException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.AckedClusterStateUpdateTask;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
Expand All @@ -33,12 +34,14 @@
import org.elasticsearch.cluster.block.ClusterBlocks;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.BaseFuture;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.MockLogAppender;
import org.elasticsearch.test.junit.annotations.TestLogging;
Expand All @@ -50,6 +53,7 @@
import org.junit.BeforeClass;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -165,6 +169,92 @@ public void onFailure(String source, Exception e) {
nonMaster.close();
}

/**
 * Checks which thread context each callback of a submitted cluster state update task observes.
 *
 * <p>The task is submitted while a header {@code test=test} is set on the thread context. The
 * assertions inside the task expect:
 * <ul>
 *   <li>{@code execute} and {@code clusterStatePublished} to see a system context with no headers;</li>
 *   <li>{@code onFailure}, {@code clusterStateProcessed}, {@code onAllNodesAcked} and
 *       {@code onAckTimeout} to see a non-system context carrying the submitter's headers.</li>
 * </ul>
 * The submitting thread's own context is asserted to be unchanged after submission.
 */
public void testThreadContext() throws InterruptedException {
    final TimedMasterService master = createTimedMasterService(true);
    // Count of one: the first callback that fires releases the test thread.
    final CountDownLatch latch = new CountDownLatch(1);

    try (ThreadContext.StoredContext ignored = threadPool.getThreadContext().stashContext()) {
        final Map<String, String> expectedHeaders = Collections.singletonMap("test", "test");
        threadPool.getThreadContext().putHeader(expectedHeaders);

        // Randomize both timeouts between zero and an arbitrary value of up to ten seconds.
        final TimeValue ackTimeout = randomBoolean() ? TimeValue.ZERO : TimeValue.timeValueMillis(randomInt(10000));
        final TimeValue masterTimeout = randomBoolean() ? TimeValue.ZERO : TimeValue.timeValueMillis(randomInt(10000));

        master.submitStateUpdateTask("test", new AckedClusterStateUpdateTask<Void>(null, null) {
            @Override
            public ClusterState execute(ClusterState currentState) {
                assertTrue(threadPool.getThreadContext().isSystemContext());
                assertEquals(Collections.emptyMap(), threadPool.getThreadContext().getHeaders());

                // Randomly take one of three paths: a rebuilt state, the same state instance, or a failure.
                if (randomBoolean()) {
                    return ClusterState.builder(currentState).build();
                } else if (randomBoolean()) {
                    return currentState;
                } else {
                    throw new IllegalArgumentException("mock failure");
                }
            }

            @Override
            public void onFailure(String source, Exception e) {
                assertFalse(threadPool.getThreadContext().isSystemContext());
                assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
                latch.countDown();
            }

            @Override
            public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
                assertFalse(threadPool.getThreadContext().isSystemContext());
                assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
                latch.countDown();
            }

            @Override
            public void clusterStatePublished(ClusterChangedEvent clusterChangedEvent) {
                assertTrue(threadPool.getThreadContext().isSystemContext());
                assertEquals(Collections.emptyMap(), threadPool.getThreadContext().getHeaders());
                latch.countDown();
            }

            @Override
            protected Void newResponse(boolean acknowledged) {
                return null;
            }

            @Override
            public TimeValue ackTimeout() {
                return ackTimeout;
            }

            @Override
            public TimeValue timeout() {
                return masterTimeout;
            }

            @Override
            public void onAllNodesAcked(@Nullable Exception e) {
                assertFalse(threadPool.getThreadContext().isSystemContext());
                assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
                latch.countDown();
            }

            @Override
            public void onAckTimeout() {
                assertFalse(threadPool.getThreadContext().isSystemContext());
                assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
                latch.countDown();
            }

        });

        // Submitting the task must leave the caller's own context untouched.
        assertFalse(threadPool.getThreadContext().isSystemContext());
        assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
    }

    latch.await();

    master.close();
}

/*
* test that a listener throwing an exception while handling a
* notification does not prevent publication notification to the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ public void onFailure(String source, Exception e) {
}

@Override
public void clusterStatePublished(ClusterChangedEvent clusterChangedEvent) {
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
logger.debug("Job [" + jobId + "] is successfully marked as deleted");
listener.onResponse(true);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,8 @@ public void onFailure(String source, Exception e) {
}

@Override
public void clusterStatePublished(ClusterChangedEvent clusterChangedEvent) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if ClusterStateUpdateTask should have a final empty method for this? I think it's always the wrong one in that context.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this idea, I've pushed 1b7eeb4 to implement this.

afterClusterStateUpdate(clusterChangedEvent.state(), request);
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
afterClusterStateUpdate(newState, request);
actionListener.onResponse(new PutJobAction.Response(updatedJob.get()));
}
});
Expand Down