diff --git a/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java b/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java index 2543be4811c1e..8927adfd43458 100644 --- a/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java +++ b/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java @@ -730,7 +730,7 @@ public void submitStateUpdateTasks(final String source, return; } final ThreadContext threadContext = threadPool.getThreadContext(); - final Supplier supplier = threadContext.newRestorableContext(false); + final Supplier supplier = threadContext.newRestorableContext(true); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { threadContext.markAsSystemContext(); diff --git a/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java index 20587d31f5359..1ef548bd68114 100644 --- a/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java @@ -54,6 +54,7 @@ import org.junit.BeforeClass; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -177,6 +178,8 @@ public void testThreadContext() throws InterruptedException { try (ThreadContext.StoredContext ignored = threadPool.getThreadContext().stashContext()) { final Map expectedHeaders = Collections.singletonMap("test", "test"); + final Map> expectedResponseHeaders = Collections.singletonMap("testResponse", + Arrays.asList("testResponse")); threadPool.getThreadContext().putHeader(expectedHeaders); final TimeValue ackTimeout = randomBoolean() ? TimeValue.ZERO : TimeValue.timeValueMillis(randomInt(10000)); @@ -187,6 +190,8 @@ public void testThreadContext() throws InterruptedException { public ClusterState execute(ClusterState currentState) { assertTrue(threadPool.getThreadContext().isSystemContext()); assertEquals(Collections.emptyMap(), threadPool.getThreadContext().getHeaders()); + threadPool.getThreadContext().addResponseHeader("testResponse", "testResponse"); + assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders()); if (randomBoolean()) { return ClusterState.builder(currentState).build(); @@ -201,6 +206,7 @@ public ClusterState execute(ClusterState currentState) { public void onFailure(String source, Exception e) { assertFalse(threadPool.getThreadContext().isSystemContext()); assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders()); + assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders()); latch.countDown(); } @@ -208,6 +214,7 @@ public void onFailure(String source, Exception e) { public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { assertFalse(threadPool.getThreadContext().isSystemContext()); assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders()); + assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders()); latch.countDown(); } @@ -229,6 +236,7 @@ public TimeValue timeout() { public void onAllNodesAcked(@Nullable Exception e) { assertFalse(threadPool.getThreadContext().isSystemContext()); assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders()); + assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders()); latch.countDown(); } @@ -236,6 +244,7 @@ public void onAllNodesAcked(@Nullable Exception e) { public void onAckTimeout() { assertFalse(threadPool.getThreadContext().isSystemContext()); assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders()); + assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders()); latch.countDown(); } @@ -243,6 +252,7 @@ public void onAckTimeout() { assertFalse(threadPool.getThreadContext().isSystemContext()); assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders()); + assertEquals(Collections.emptyMap(), threadPool.getThreadContext().getResponseHeaders()); } latch.await();