From acf0fe1c2c0a47dd03fb57913397cae4debec359 Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Fri, 11 Feb 2022 12:08:17 +0530 Subject: [PATCH] Fix race conditions in updating RemoteTaskStats In the existing implementation, the shared currentRequestStartNanos object may be updated by the next request after the current request future is done but before the callback handler has read the value of currentRequestStartNanos set before the current request was started. --- .../ContinuousTaskStatusFetcher.java | 89 ++++++++++--------- .../remotetask/DynamicFiltersFetcher.java | 82 ++++++++--------- .../server/remotetask/TaskInfoFetcher.java | 78 ++++++++-------- 3 files changed, 124 insertions(+), 125 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/server/remotetask/ContinuousTaskStatusFetcher.java b/core/trino-main/src/main/java/io/trino/server/remotetask/ContinuousTaskStatusFetcher.java index 42591d7821f8..9174b8413511 100644 --- a/core/trino-main/src/main/java/io/trino/server/remotetask/ContinuousTaskStatusFetcher.java +++ b/core/trino-main/src/main/java/io/trino/server/remotetask/ContinuousTaskStatusFetcher.java @@ -33,7 +33,6 @@ import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import static com.google.common.base.Strings.isNullOrEmpty; @@ -51,7 +50,6 @@ import static java.util.Objects.requireNonNull; class ContinuousTaskStatusFetcher - implements SimpleHttpResponseCallback { private static final Logger log = Logger.get(ContinuousTaskStatusFetcher.class); @@ -67,8 +65,6 @@ class ContinuousTaskStatusFetcher private final RequestErrorTracker errorTracker; private final RemoteTaskStats stats; - private final AtomicLong currentRequestStartNanos = new AtomicLong(); - @GuardedBy("this") private boolean running; @@ -154,8 +150,7 @@ private synchronized void scheduleNextRequest() errorTracker.startRequest(); future = httpClient.executeAsync(request, createFullJsonResponseHandler(taskStatusCodec)); - currentRequestStartNanos.set(System.nanoTime()); - Futures.addCallback(future, new SimpleHttpResponseHandler<>(this, request.getUri(), stats), executor); + Futures.addCallback(future, new SimpleHttpResponseHandler<>(new TaskStatusResponseCallback(), request.getUri(), stats), executor); } TaskStatus getTaskStatus() @@ -163,52 +158,58 @@ TaskStatus getTaskStatus() return taskStatus.get(); } - @Override - public void success(TaskStatus value) + private class TaskStatusResponseCallback + implements SimpleHttpResponseCallback { - try (SetThreadName ignored = new SetThreadName("ContinuousTaskStatusFetcher-%s", taskId)) { - updateStats(currentRequestStartNanos.get()); - try { - updateTaskStatus(value); - errorTracker.requestSucceeded(); - } - finally { - scheduleNextRequest(); + private final long requestStartNanos = System.nanoTime(); + + @Override + public void success(TaskStatus value) + { + try (SetThreadName ignored = new SetThreadName("ContinuousTaskStatusFetcher-%s", taskId)) { + updateStats(requestStartNanos); + try { + updateTaskStatus(value); + errorTracker.requestSucceeded(); + } + finally { + scheduleNextRequest(); + } } } - } - @Override - public void failed(Throwable cause) - { - try (SetThreadName ignored = new SetThreadName("ContinuousTaskStatusFetcher-%s", taskId)) { - updateStats(currentRequestStartNanos.get()); - try { - // if task not already done, record error - TaskStatus taskStatus = getTaskStatus(); - if (!taskStatus.getState().isDone()) { - errorTracker.requestFailed(cause); + @Override + public void failed(Throwable cause) + { + try (SetThreadName ignored = new SetThreadName("ContinuousTaskStatusFetcher-%s", taskId)) { + updateStats(requestStartNanos); + try { + // if task not already done, record error + TaskStatus taskStatus = getTaskStatus(); + if (!taskStatus.getState().isDone()) { + errorTracker.requestFailed(cause); + } + } + catch (Error e) { + onFail.accept(e); + throw e; + } + catch (RuntimeException e) { + onFail.accept(e); + } + finally { + scheduleNextRequest(); } - } - catch (Error e) { - onFail.accept(e); - throw e; - } - catch (RuntimeException e) { - onFail.accept(e); - } - finally { - scheduleNextRequest(); } } - } - @Override - public void fatal(Throwable cause) - { - try (SetThreadName ignored = new SetThreadName("ContinuousTaskStatusFetcher-%s", taskId)) { - updateStats(currentRequestStartNanos.get()); - onFail.accept(cause); + @Override + public void fatal(Throwable cause) + { + try (SetThreadName ignored = new SetThreadName("ContinuousTaskStatusFetcher-%s", taskId)) { + updateStats(requestStartNanos); + onFail.accept(cause); + } } } diff --git a/core/trino-main/src/main/java/io/trino/server/remotetask/DynamicFiltersFetcher.java b/core/trino-main/src/main/java/io/trino/server/remotetask/DynamicFiltersFetcher.java index 2cb741342865..4b1c0a5929d8 100644 --- a/core/trino-main/src/main/java/io/trino/server/remotetask/DynamicFiltersFetcher.java +++ b/core/trino-main/src/main/java/io/trino/server/remotetask/DynamicFiltersFetcher.java @@ -29,7 +29,6 @@ import java.net.URI; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import static com.google.common.net.HttpHeaders.CONTENT_TYPE; @@ -45,7 +44,6 @@ import static java.util.Objects.requireNonNull; class DynamicFiltersFetcher - implements SimpleHttpResponseCallback { private final TaskId taskId; private final URI taskUri; @@ -57,7 +55,6 @@ class DynamicFiltersFetcher private final RequestErrorTracker errorTracker; private final RemoteTaskStats stats; private final DynamicFilterService dynamicFilterService; - private final AtomicLong currentRequestStartNanos = new AtomicLong(); @GuardedBy("this") private long dynamicFiltersVersion = INITIAL_DYNAMIC_FILTERS_VERSION; @@ -158,52 +155,57 @@ private synchronized void fetchDynamicFiltersIfNecessary() errorTracker.startRequest(); future = httpClient.executeAsync(request, createFullJsonResponseHandler(dynamicFilterDomainsCodec)); - currentRequestStartNanos.set(System.nanoTime()); - addCallback(future, new SimpleHttpResponseHandler<>(this, request.getUri(), stats), executor); + addCallback(future, new SimpleHttpResponseHandler<>(new DynamicFiltersResponseCallback(), request.getUri(), stats), executor); } - @Override - public void success(VersionedDynamicFilterDomains newDynamicFilterDomains) + private class DynamicFiltersResponseCallback + implements SimpleHttpResponseCallback { - try (SetThreadName ignored = new SetThreadName("DynamicFiltersFetcher-%s", taskId)) { - updateStats(currentRequestStartNanos.get()); - try { - updateDynamicFilterDomains(newDynamicFilterDomains); - errorTracker.requestSucceeded(); - } - finally { - fetchDynamicFiltersIfNecessary(); + private final long requestStartNanos = System.nanoTime(); + + @Override + public void success(VersionedDynamicFilterDomains newDynamicFilterDomains) + { + try (SetThreadName ignored = new SetThreadName("DynamicFiltersFetcher-%s", taskId)) { + updateStats(requestStartNanos); + try { + updateDynamicFilterDomains(newDynamicFilterDomains); + errorTracker.requestSucceeded(); + } + finally { + fetchDynamicFiltersIfNecessary(); + } } } - } - @Override - public void failed(Throwable cause) - { - try (SetThreadName ignored = new SetThreadName("DynamicFiltersFetcher-%s", taskId)) { - updateStats(currentRequestStartNanos.get()); - try { - errorTracker.requestFailed(cause); - } - catch (Error e) { - onFail.accept(e); - throw e; - } - catch (RuntimeException e) { - onFail.accept(e); - } - finally { - fetchDynamicFiltersIfNecessary(); + @Override + public void failed(Throwable cause) + { + try (SetThreadName ignored = new SetThreadName("DynamicFiltersFetcher-%s", taskId)) { + updateStats(requestStartNanos); + try { + errorTracker.requestFailed(cause); + } + catch (Error e) { + onFail.accept(e); + throw e; + } + catch (RuntimeException e) { + onFail.accept(e); + } + finally { + fetchDynamicFiltersIfNecessary(); + } } } - } - @Override - public void fatal(Throwable cause) - { - try (SetThreadName ignored = new SetThreadName("DynamicFiltersFetcher-%s", taskId)) { - updateStats(currentRequestStartNanos.get()); - onFail.accept(cause); + @Override + public void fatal(Throwable cause) + { + try (SetThreadName ignored = new SetThreadName("DynamicFiltersFetcher-%s", taskId)) { + updateStats(requestStartNanos); + onFail.accept(cause); + } } } diff --git a/core/trino-main/src/main/java/io/trino/server/remotetask/TaskInfoFetcher.java b/core/trino-main/src/main/java/io/trino/server/remotetask/TaskInfoFetcher.java index 5a080023d4fc..bf4add43be39 100644 --- a/core/trino-main/src/main/java/io/trino/server/remotetask/TaskInfoFetcher.java +++ b/core/trino-main/src/main/java/io/trino/server/remotetask/TaskInfoFetcher.java @@ -49,7 +49,6 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; public class TaskInfoFetcher - implements SimpleHttpResponseCallback { private final TaskId taskId; private final Consumer onFail; @@ -67,10 +66,6 @@ public class TaskInfoFetcher private final RequestErrorTracker errorTracker; private final boolean summarizeTaskInfo; - - @GuardedBy("this") - private final AtomicLong currentRequestStartNanos = new AtomicLong(); - private final RemoteTaskStats stats; @GuardedBy("this") @@ -212,8 +207,7 @@ private synchronized void sendNextRequest() errorTracker.startRequest(); future = httpClient.executeAsync(request, createFullJsonResponseHandler(taskInfoCodec)); - currentRequestStartNanos.set(System.nanoTime()); - Futures.addCallback(future, new SimpleHttpResponseHandler<>(this, request.getUri(), stats), executor); + Futures.addCallback(future, new SimpleHttpResponseHandler<>(new TaskInfoResponseCallback(), request.getUri(), stats), executor); } synchronized void updateTaskInfo(TaskInfo newTaskInfo) @@ -247,49 +241,51 @@ synchronized void updateTaskInfo(TaskInfo newTaskInfo) } } - @Override - public void success(TaskInfo newValue) + private class TaskInfoResponseCallback + implements SimpleHttpResponseCallback { - try (SetThreadName ignored = new SetThreadName("TaskInfoFetcher-%s", taskId)) { - lastUpdateNanos.set(System.nanoTime()); + private final long requestStartNanos = System.nanoTime(); - long startNanos; - synchronized (this) { - startNanos = this.currentRequestStartNanos.get(); + @Override + public void success(TaskInfo newValue) + { + try (SetThreadName ignored = new SetThreadName("TaskInfoFetcher-%s", taskId)) { + lastUpdateNanos.set(System.nanoTime()); + + updateStats(requestStartNanos); + errorTracker.requestSucceeded(); + updateTaskInfo(newValue); } - updateStats(startNanos); - errorTracker.requestSucceeded(); - updateTaskInfo(newValue); } - } - - @Override - public void failed(Throwable cause) - { - try (SetThreadName ignored = new SetThreadName("TaskInfoFetcher-%s", taskId)) { - lastUpdateNanos.set(System.nanoTime()); - try { - // if task not already done, record error - if (!isDone(getTaskInfo())) { - errorTracker.requestFailed(cause); + @Override + public void failed(Throwable cause) + { + try (SetThreadName ignored = new SetThreadName("TaskInfoFetcher-%s", taskId)) { + lastUpdateNanos.set(System.nanoTime()); + + try { + // if task not already done, record error + if (!isDone(getTaskInfo())) { + errorTracker.requestFailed(cause); + } + } + catch (Error e) { + onFail.accept(e); + throw e; + } + catch (RuntimeException e) { + onFail.accept(e); } - } - catch (Error e) { - onFail.accept(e); - throw e; - } - catch (RuntimeException e) { - onFail.accept(e); } } - } - @Override - public void fatal(Throwable cause) - { - try (SetThreadName ignored = new SetThreadName("TaskInfoFetcher-%s", taskId)) { - onFail.accept(cause); + @Override + public void fatal(Throwable cause) + { + try (SetThreadName ignored = new SetThreadName("TaskInfoFetcher-%s", taskId)) { + onFail.accept(cause); + } } }