Use system context for cluster state update tasks #31241

Merged
@@ -41,6 +41,9 @@ default boolean runOnlyOnMaster() {
/**
* Callback invoked after new cluster state is published. Note that
* this method is not invoked if the cluster state was not updated.
*
* Note that this method will be executed using the system context.
*
* @param clusterChangedEvent the change event for this cluster state change, containing
* both old and new states
*/
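
As a hedged illustration of what this guarantee buys implementers (hypothetical executor and task type, not part of this change): because the batching executor invokes this hook under the system context, an implementation can run privileged follow-up work directly, without stashing or marking the context itself.

```java
import java.util.List;

import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateTaskExecutor;

class MyTask {} // hypothetical task type, for illustration only

class MyTaskExecutor implements ClusterStateTaskExecutor<MyTask> {
    @Override
    public ClusterTasksResult<MyTask> execute(ClusterState currentState, List<MyTask> tasks) throws Exception {
        // Trivially accept all tasks without changing the state.
        return ClusterTasksResult.<MyTask>builder().successes(tasks).build(currentState);
    }

    @Override
    public void clusterStatePublished(ClusterChangedEvent clusterChangedEvent) {
        // Executed in the system context: internal follow-up actions are
        // safe here and no caller headers leak in.
    }
}
```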
@@ -62,6 +62,12 @@ public String describeTasks(List<ClusterStateUpdateTask> tasks) {
*/
public abstract void onFailure(String source, Exception e);

@Override
public final void clusterStatePublished(ClusterChangedEvent clusterChangedEvent) {
// final, empty implementation here as this method should only be defined in combination
// with a batching executor as it will always be executed within the system context.
}

/**
* If the cluster state update task wasn't processed by the provided timeout, call
* {@link ClusterStateTaskListener#onFailure(String, Exception)}. May return null to indicate no timeout is needed (default).
@@ -47,6 +47,7 @@
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.FutureUtils;
import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.discovery.Discovery;
import org.elasticsearch.threadpool.ThreadPool;

@@ -59,6 +60,7 @@
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static org.elasticsearch.cluster.service.ClusterService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING;
@@ -426,26 +428,28 @@ public TimeValue getMaxTaskWaitTime() {
return threadPoolExecutor.getMaxTaskWaitTime();
}

private SafeClusterStateTaskListener safe(ClusterStateTaskListener listener) {
private SafeClusterStateTaskListener safe(ClusterStateTaskListener listener, Supplier<ThreadContext.StoredContext> contextSupplier) {
if (listener instanceof AckedClusterStateTaskListener) {
return new SafeAckedClusterStateTaskListener((AckedClusterStateTaskListener) listener, logger);
return new SafeAckedClusterStateTaskListener((AckedClusterStateTaskListener) listener, contextSupplier, logger);
} else {
return new SafeClusterStateTaskListener(listener, logger);
return new SafeClusterStateTaskListener(listener, contextSupplier, logger);
}
}

private static class SafeClusterStateTaskListener implements ClusterStateTaskListener {
private final ClusterStateTaskListener listener;
protected final Supplier<ThreadContext.StoredContext> context;
private final Logger logger;

SafeClusterStateTaskListener(ClusterStateTaskListener listener, Logger logger) {
SafeClusterStateTaskListener(ClusterStateTaskListener listener, Supplier<ThreadContext.StoredContext> context, Logger logger) {
this.listener = listener;
this.context = context;
this.logger = logger;
}

@Override
public void onFailure(String source, Exception e) {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.onFailure(source, e);
} catch (Exception inner) {
inner.addSuppressed(e);
@@ -456,7 +460,7 @@ public void onFailure(String source, Exception e) {

@Override
public void onNoLongerMaster(String source) {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.onNoLongerMaster(source);
} catch (Exception e) {
logger.error(() -> new ParameterizedMessage(
@@ -466,7 +470,7 @@ public void onNoLongerMaster(String source) {

@Override
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.clusterStateProcessed(source, oldState, newState);
} catch (Exception e) {
logger.error(() -> new ParameterizedMessage(
@@ -480,8 +484,9 @@ private static class SafeAckedClusterStateTaskListener extends SafeClusterStateT
private final AckedClusterStateTaskListener listener;
private final Logger logger;

SafeAckedClusterStateTaskListener(AckedClusterStateTaskListener listener, Logger logger) {
super(listener, logger);
SafeAckedClusterStateTaskListener(AckedClusterStateTaskListener listener, Supplier<ThreadContext.StoredContext> context,
Logger logger) {
super(listener, context, logger);
this.listener = listener;
this.logger = logger;
}
@@ -493,7 +498,7 @@ public boolean mustAck(DiscoveryNode discoveryNode) {

@Override
public void onAllNodesAcked(@Nullable Exception e) {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.onAllNodesAcked(e);
} catch (Exception inner) {
inner.addSuppressed(e);
@@ -503,7 +508,7 @@ public void onAllNodesAcked(@Nullable Exception e) {

@Override
public void onAckTimeout() {
try {
try (ThreadContext.StoredContext ignore = context.get()) {
listener.onAckTimeout();
} catch (Exception e) {
logger.error("exception thrown by listener while notifying on ack timeout", e);
@@ -724,9 +729,13 @@ public <T> void submitStateUpdateTasks(final String source,
if (!lifecycle.started()) {
return;
}
try {
final ThreadContext threadContext = threadPool.getThreadContext();
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(false);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
Contributor:

instead of doing this, I wonder if we should invest in a SystemThreadContext that we give to the executor and it does everything under the system context. Alternatively we can have a "no context" executor that simply doesn't set any context. @tvernum do you see any other usage for this?

Contributor Author:

I'm not sure I understand this correctly. Would you have different executors based on whether they switch to the system context or stay in the default context? Looking at the other usages of threadContext.markAsSystemContext(), all of them seem to use executors that cannot be categorized as system or non-system (i.e. generic threadpool, ...).

Contributor:

my suggestion referred to the entire pattern of generating a new context, marking it as system, and dealing with it at the executor level. This has two potential upsides:

  1. You don't need to remember to do this everywhere you submit a task (not so important here, as there's just one place).
  2. We don't do all the "clean & capture and then restore in the thread pool" song and dance.

Since this is not a high-throughput usage and we only have one place where we submit tasks, it doesn't have such a big impact on MasterService, but it might be useful for other things that use an internal thread pool executor.
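
For concreteness, a rough sketch of that idea (class name and wiring are assumptions, not part of this PR): a wrapper executor that runs every submitted task under the system context, so call sites no longer stash, mark, and restore the context themselves.

```java
import java.util.concurrent.Executor;

import org.elasticsearch.common.util.concurrent.ThreadContext;

final class SystemContextExecutor implements Executor {
    private final Executor delegate;
    private final ThreadContext threadContext;

    SystemContextExecutor(Executor delegate, ThreadContext threadContext) {
        this.delegate = delegate;
        this.threadContext = threadContext;
    }

    @Override
    public void execute(Runnable command) {
        delegate.execute(() -> {
            // Stash whatever context the worker thread currently carries,
            // switch to the system context, and restore the original
            // context when the task completes.
            try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
                threadContext.markAsSystemContext();
                command.run();
            }
        });
    }
}
```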


List<Batcher.UpdateTask> safeTasks = tasks.entrySet().stream()
.map(e -> taskBatcher.new UpdateTask(config.priority(), source, e.getKey(), safe(e.getValue()), executor))
.map(e -> taskBatcher.new UpdateTask(config.priority(), source, e.getKey(), safe(e.getValue(), supplier), executor))
.collect(Collectors.toList());
taskBatcher.submitTasks(safeTasks, config.timeout());
} catch (EsRejectedExecutionException e) {
@@ -534,7 +534,6 @@ public ClusterStateResponse newInstance() {

@Override
public void handleResponse(ClusterStateResponse response) {
assert transportService.getThreadPool().getThreadContext().isSystemContext() == false : "context is a system context";
try {
if (remoteClusterName.get() == null) {
assert response.getClusterName().value() != null;
@@ -574,7 +573,6 @@ public void handleResponse(ClusterStateResponse response) {

@Override
public void handleException(TransportException exp) {
assert transportService.getThreadPool().getThreadContext().isSystemContext() == false : "context is a system context";
logger.warn(() -> new ParameterizedMessage("fetching nodes from external cluster {} failed", clusterAlias), exp);
try {
IOUtils.closeWhileHandlingException(connection);
@@ -34,12 +34,14 @@
import org.elasticsearch.cluster.block.ClusterBlocks;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.BaseFuture;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.discovery.Discovery;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.MockLogAppender;
@@ -52,6 +54,7 @@
import org.junit.BeforeClass;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -168,6 +171,85 @@ public void onFailure(String source, Exception e) {
nonMaster.close();
}

public void testThreadContext() throws InterruptedException {
final TimedMasterService master = createTimedMasterService(true);
final CountDownLatch latch = new CountDownLatch(1);

try (ThreadContext.StoredContext ignored = threadPool.getThreadContext().stashContext()) {
final Map<String, String> expectedHeaders = Collections.singletonMap("test", "test");
threadPool.getThreadContext().putHeader(expectedHeaders);

final TimeValue ackTimeout = randomBoolean() ? TimeValue.ZERO : TimeValue.timeValueMillis(randomInt(10000));
final TimeValue masterTimeout = randomBoolean() ? TimeValue.ZERO : TimeValue.timeValueMillis(randomInt(10000));

master.submitStateUpdateTask("test", new AckedClusterStateUpdateTask<Void>(null, null) {
@Override
public ClusterState execute(ClusterState currentState) {
assertTrue(threadPool.getThreadContext().isSystemContext());
assertEquals(Collections.emptyMap(), threadPool.getThreadContext().getHeaders());

if (randomBoolean()) {
return ClusterState.builder(currentState).build();
} else if (randomBoolean()) {
return currentState;
} else {
throw new IllegalArgumentException("mock failure");
}
}

@Override
public void onFailure(String source, Exception e) {
assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
latch.countDown();
}

@Override
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
latch.countDown();
}

@Override
protected Void newResponse(boolean acknowledged) {
return null;
}

public TimeValue ackTimeout() {
return ackTimeout;
}

@Override
public TimeValue timeout() {
return masterTimeout;
}

@Override
public void onAllNodesAcked(@Nullable Exception e) {
assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
latch.countDown();
}

@Override
public void onAckTimeout() {
assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
latch.countDown();
}

});

assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
}

latch.await();

master.close();
}

/*
* test that a listener throwing an exception while handling a
* notification does not prevent publication notification to the
@@ -20,7 +20,6 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -293,7 +292,7 @@ public Builder deleteJob(String jobId, PersistentTasksCustomMetaData tasks) {
return this;
}

public Builder putDatafeed(DatafeedConfig datafeedConfig, ThreadContext threadContext) {
public Builder putDatafeed(DatafeedConfig datafeedConfig, Map<String, String> headers) {
if (datafeeds.containsKey(datafeedConfig.getId())) {
throw new ResourceAlreadyExistsException("A datafeed with id [" + datafeedConfig.getId() + "] already exists");
}
@@ -302,13 +301,13 @@ public Builder putDatafeed(DatafeedConfig datafeedConfig, ThreadContext threadCo
Job job = jobs.get(jobId);
DatafeedJobValidator.validate(datafeedConfig, job);

if (threadContext != null) {
if (headers.isEmpty() == false) {
// Adjust the request, adding security headers from the current thread context
DatafeedConfig.Builder builder = new DatafeedConfig.Builder(datafeedConfig);
Map<String, String> headers = threadContext.getHeaders().entrySet().stream()
Map<String, String> securityHeaders = headers.entrySet().stream()
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(headers);
builder.setHeaders(securityHeaders);
datafeedConfig = builder.build();
}

@@ -328,15 +327,15 @@ private void checkJobIsAvailableForDatafeed(String jobId) {
}
}

public Builder updateDatafeed(DatafeedUpdate update, PersistentTasksCustomMetaData persistentTasks, ThreadContext threadContext) {
public Builder updateDatafeed(DatafeedUpdate update, PersistentTasksCustomMetaData persistentTasks, Map<String, String> headers) {
String datafeedId = update.getId();
DatafeedConfig oldDatafeedConfig = datafeeds.get(datafeedId);
if (oldDatafeedConfig == null) {
throw ExceptionsHelper.missingDatafeedException(datafeedId);
}
checkDatafeedIsStopped(() -> Messages.getMessage(Messages.DATAFEED_CANNOT_UPDATE_IN_CURRENT_STATE, datafeedId,
DatafeedState.STARTED), datafeedId, persistentTasks);
DatafeedConfig newDatafeedConfig = update.apply(oldDatafeedConfig, threadContext);
DatafeedConfig newDatafeedConfig = update.apply(oldDatafeedConfig, headers);
if (newDatafeedConfig.getJobId().equals(oldDatafeedConfig.getJobId()) == false) {
checkJobIsAvailableForDatafeed(newDatafeedConfig.getJobId());
}
@@ -12,7 +12,6 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -264,7 +263,7 @@ ChunkingConfig getChunkingConfig() {
* Applies the update to the given {@link DatafeedConfig}
* @return a new {@link DatafeedConfig} that contains the update
*/
public DatafeedConfig apply(DatafeedConfig datafeedConfig, ThreadContext threadContext) {
public DatafeedConfig apply(DatafeedConfig datafeedConfig, Map<String, String> headers) {
if (id.equals(datafeedConfig.getId()) == false) {
throw new IllegalArgumentException("Cannot apply update to datafeedConfig with different id");
}
@@ -301,12 +300,12 @@ public DatafeedConfig apply(DatafeedConfig datafeedConfig, ThreadContext threadC
builder.setChunkingConfig(chunkingConfig);
}

if (threadContext != null) {
if (headers.isEmpty() == false) {
// Adjust the request, adding security headers from the current thread context
Map<String, String> headers = threadContext.getHeaders().entrySet().stream()
Map<String, String> securityHeaders = headers.entrySet().stream()
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(headers);
builder.setHeaders(securityHeaders);
}

return builder.build();
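
A hedged caller-side sketch (the call site is not part of this diff; variable names are assumptions): with the new signature the caller extracts the header map up front, and apply filters it down to the security-related headers itself.

```java
// Assumed call site, for illustration only.
Map<String, String> headers = threadPool.getThreadContext().getHeaders();
DatafeedConfig updatedConfig = datafeedUpdate.apply(existingConfig, headers);
```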
@@ -114,7 +114,7 @@ public void testApply_failBecauseTargetDatafeedHasDifferentId() {

public void testApply_givenEmptyUpdate() {
DatafeedConfig datafeed = DatafeedConfigTests.createRandomizedDatafeedConfig("foo");
DatafeedConfig updatedDatafeed = new DatafeedUpdate.Builder(datafeed.getId()).build().apply(datafeed, null);
DatafeedConfig updatedDatafeed = new DatafeedUpdate.Builder(datafeed.getId()).build().apply(datafeed, Collections.emptyMap());
assertThat(datafeed, equalTo(updatedDatafeed));
}

Expand All @@ -125,7 +125,7 @@ public void testApply_givenPartialUpdate() {

DatafeedUpdate.Builder updated = new DatafeedUpdate.Builder(datafeed.getId());
updated.setScrollSize(datafeed.getScrollSize() + 1);
DatafeedConfig updatedDatafeed = updated.build().apply(datafeed, Collections.emptyMap());
DatafeedConfig updatedDatafeed = update.build().apply(datafeed, Collections.emptyMap());

DatafeedConfig.Builder expectedDatafeed = new DatafeedConfig.Builder(datafeed);
expectedDatafeed.setScrollSize(datafeed.getScrollSize() + 1);
@@ -149,7 +149,7 @@ public void testApply_givenFullUpdateNoAggregations() {
update.setScrollSize(8000);
update.setChunkingConfig(ChunkingConfig.newManual(TimeValue.timeValueHours(1)));

DatafeedConfig updatedDatafeed = update.build().apply(datafeed, null);
DatafeedConfig updatedDatafeed = update.build().apply(datafeed, Collections.emptyMap());

assertThat(updatedDatafeed.getJobId(), equalTo("bar"));
assertThat(updatedDatafeed.getIndices(), equalTo(Collections.singletonList("i_2")));
@@ -175,7 +175,7 @@ public void testApply_givenAggregations() {
update.setAggregations(new AggregatorFactories.Builder().addAggregator(
AggregationBuilders.histogram("a").interval(300000).field("time").subAggregation(maxTime)));

DatafeedConfig updatedDatafeed = update.build().apply(datafeed, null);
DatafeedConfig updatedDatafeed = update.build().apply(datafeed, Collections.emptyMap());

assertThat(updatedDatafeed.getIndices(), equalTo(Collections.singletonList("i_1")));
assertThat(updatedDatafeed.getTypes(), equalTo(Collections.singletonList("t_1")));