diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java b/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java
index 280a1c9cac4b..737fea1b8c2a 100644
--- a/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java
+++ b/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java
@@ -45,6 +45,7 @@
"task.level-absolute-priority"})
public class TaskManagerConfig
{
+    private boolean threadPerDriverSchedulerEnabled = true; // enabled by default; changed through setThreadPerDriverSchedulerEnabled
private boolean perOperatorCpuTimerEnabled = true;
private boolean taskCpuTimerEnabled = true;
private boolean statisticsCpuTimerEnabled = true;
@@ -107,6 +108,18 @@ public class TaskManagerConfig
private BigDecimal levelTimeMultiplier = new BigDecimal(2.0);
+    @Config("experimental.thread-per-split-scheduler-enabled") // NOTE(review): key says "thread-per-split" but the Java names say "ThreadPerDriver" - confirm the intended property name
+    public TaskManagerConfig setThreadPerDriverSchedulerEnabled(boolean enabled) // setter bound to the config property above
+    {
+        this.threadPerDriverSchedulerEnabled = enabled;
+        return this; // returns this config object so setters can be chained
+    }
+
+    public boolean isThreadPerDriverSchedulerEnabled() // current value of the flag (true unless the setter changed it)
+    {
+        return threadPerDriverSchedulerEnabled;
+    }
+
@MinDuration("1ms")
@MaxDuration("10s")
@NotNull
diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/SplitProcessor.java b/core/trino-main/src/main/java/io/trino/execution/executor2/SplitProcessor.java
new file mode 100644
index 000000000000..a81af6a617b8
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/execution/executor2/SplitProcessor.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2;
+
+import com.google.common.base.Ticker;
+import com.google.common.util.concurrent.ListenableFuture;
+import io.airlift.concurrent.SetThreadName;
+import io.airlift.stats.CpuTimer;
+import io.airlift.units.Duration;
+import io.opentelemetry.api.trace.Span;
+import io.opentelemetry.api.trace.SpanBuilder;
+import io.opentelemetry.api.trace.Tracer;
+import io.opentelemetry.context.Context;
+import io.trino.execution.SplitRunner;
+import io.trino.execution.TaskId;
+import io.trino.execution.executor2.scheduler.Schedulable;
+import io.trino.execution.executor2.scheduler.SchedulerContext;
+import io.trino.tracing.TrinoAttributes;
+
+import java.util.concurrent.TimeUnit;
+
+import static java.util.Objects.requireNonNull;
+import static java.util.concurrent.TimeUnit.NANOSECONDS;
+
+class SplitProcessor // Schedulable that drives a single SplitRunner until it finishes, emitting tracing spans along the way
+        implements Schedulable
+{
+    private static final Duration SPLIT_RUN_QUANTA = new Duration(1, TimeUnit.SECONDS); // duration passed to every processFor() call
+
+    private final TaskId taskId;
+    private final int splitId;
+    private final SplitRunner split;
+    private final Tracer tracer;
+
+    public SplitProcessor(TaskId taskId, int splitId, SplitRunner split, Tracer tracer)
+    {
+        this.taskId = requireNonNull(taskId, "taskId is null");
+        this.splitId = splitId;
+        this.split = requireNonNull(split, "split is null");
+        this.tracer = requireNonNull(tracer, "tracer is null");
+    }
+
+    @Override
+    public void run(SchedulerContext context) // returns when the split is finished or when context.maybeYield()/block() return false
+    {
+        Span splitSpan = tracer.spanBuilder("split") // one span for the whole split, parented on the pipeline span
+                .setParent(Context.current().with(split.getPipelineSpan()))
+                .setAttribute(TrinoAttributes.QUERY_ID, taskId.getQueryId().toString())
+                .setAttribute(TrinoAttributes.STAGE_ID, taskId.getStageId().toString())
+                .setAttribute(TrinoAttributes.TASK_ID, taskId.toString())
+                .setAttribute(TrinoAttributes.PIPELINE_ID, taskId.getStageId() + "-" + split.getPipelineId())
+                .setAttribute(TrinoAttributes.SPLIT_ID, taskId + "-" + splitId)
+                .startSpan();
+
+        Span processSpan = newSpan(splitSpan, null); // first "process" span; a fresh one is started after every yield/block
+
+        CpuTimer timer = new CpuTimer(Ticker.systemTicker(), false);
+        long previousCpuNanos = 0;
+        long previousScheduledNanos = 0;
+        try (SetThreadName ignored = new SetThreadName("SplitRunner-%s-%s", taskId, splitId)) {
+            while (!split.isFinished()) {
+                ListenableFuture<Void> blocked = split.processFor(SPLIT_RUN_QUANTA); // type argument restored; assumes processFor returns ListenableFuture<Void>
+                CpuTimer.CpuDuration elapsed = timer.elapsedTime();
+
+                long scheduledNanos = elapsed.getWall().roundTo(NANOSECONDS);
+                processSpan.setAttribute(TrinoAttributes.SPLIT_SCHEDULED_TIME_NANOS, scheduledNanos - previousScheduledNanos); // delta since the previous iteration
+                previousScheduledNanos = scheduledNanos;
+
+                long cpuNanos = elapsed.getCpu().roundTo(NANOSECONDS);
+                processSpan.setAttribute(TrinoAttributes.SPLIT_CPU_TIME_NANOS, cpuNanos - previousCpuNanos); // delta since the previous iteration
+                previousCpuNanos = cpuNanos;
+
+                if (!split.isFinished()) {
+                    if (blocked.isDone()) { // nothing to wait for: offer to yield instead
+                        processSpan.addEvent("yield");
+                        processSpan.end();
+                        if (!context.maybeYield()) {
+                            return; // yield was unsuccessful; stop processing (the finally block closes the spans)
+                        }
+                    }
+                    else {
+                        processSpan.addEvent("blocked");
+                        processSpan.end();
+                        if (!context.block(blocked)) {
+                            return; // blocking was unsuccessful; stop processing (the finally block closes the spans)
+                        }
+                    }
+                    processSpan = newSpan(splitSpan, processSpan); // resume under a new span linked to the one just ended
+                }
+            }
+        }
+        finally {
+            processSpan.end(); // on the early-return paths this span was already ended; relies on a repeated end() being ignored
+
+            splitSpan.setAttribute(TrinoAttributes.SPLIT_CPU_TIME_NANOS, timer.elapsedTime().getCpu().roundTo(NANOSECONDS));
+            splitSpan.setAttribute(TrinoAttributes.SPLIT_SCHEDULED_TIME_NANOS, context.getScheduledNanos());
+            splitSpan.setAttribute(TrinoAttributes.SPLIT_BLOCK_TIME_NANOS, context.getBlockedNanos());
+            splitSpan.setAttribute(TrinoAttributes.SPLIT_WAIT_TIME_NANOS, context.getWaitNanos());
+            splitSpan.end();
+        }
+    }
+
+    private Span newSpan(Span parent, Span previous) // starts a "process" span under parent, linked to previous when it is non-null
+    {
+        SpanBuilder builder = tracer.spanBuilder("process")
+                .setParent(Context.current().with(parent));
+
+        if (previous != null) {
+            builder.addLink(previous.getSpanContext());
+        }
+
+        return builder.startSpan();
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/ThreadPerDriverTaskExecutor.java b/core/trino-main/src/main/java/io/trino/execution/executor2/ThreadPerDriverTaskExecutor.java
new file mode 100644
index 000000000000..2d65b0171506
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/execution/executor2/ThreadPerDriverTaskExecutor.java
@@ -0,0 +1,212 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Ticker;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.errorprone.annotations.ThreadSafe;
+import com.google.errorprone.annotations.concurrent.GuardedBy;
+import com.google.inject.Inject;
+import io.airlift.units.Duration;
+import io.opentelemetry.api.trace.Tracer;
+import io.trino.execution.SplitRunner;
+import io.trino.execution.TaskId;
+import io.trino.execution.TaskManagerConfig;
+import io.trino.execution.executor.RunningSplitInfo;
+import io.trino.execution.executor.TaskExecutor;
+import io.trino.execution.executor.TaskHandle;
+import io.trino.execution.executor2.scheduler.FairScheduler;
+import io.trino.execution.executor2.scheduler.Group;
+import io.trino.execution.executor2.scheduler.Schedulable;
+import io.trino.execution.executor2.scheduler.SchedulerContext;
+import io.trino.spi.VersionEmbedder;
+import jakarta.annotation.PostConstruct;
+import jakarta.annotation.PreDestroy;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.OptionalInt;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.DoubleSupplier;
+import java.util.function.Predicate;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
+import static java.util.Objects.requireNonNull;
+
+@ThreadSafe
+public class ThreadPerDriverTaskExecutor // TaskExecutor that hands every split to a FairScheduler as its own Schedulable
+        implements TaskExecutor
+{
+    private final FairScheduler scheduler;
+    private final Tracer tracer;
+    private final VersionEmbedder versionEmbedder;
+    private volatile boolean closed;
+
+    @Inject
+    public ThreadPerDriverTaskExecutor(TaskManagerConfig config, Tracer tracer, VersionEmbedder versionEmbedder)
+    {
+        this(tracer, versionEmbedder, new FairScheduler(config.getMaxWorkerThreads(), "SplitRunner-%d", Ticker.systemTicker()));
+    }
+
+    @VisibleForTesting
+    public ThreadPerDriverTaskExecutor(Tracer tracer, VersionEmbedder versionEmbedder, FairScheduler scheduler)
+    {
+        this.scheduler = requireNonNull(scheduler, "scheduler is null"); // fail fast, consistent with the other dependencies
+        this.tracer = requireNonNull(tracer, "tracer is null");
+        this.versionEmbedder = requireNonNull(versionEmbedder, "versionEmbedder is null");
+    }
+
+    @PostConstruct
+    @Override
+    public synchronized void start()
+    {
+        scheduler.start();
+    }
+
+    @PreDestroy
+    @Override
+    public synchronized void stop() // marks the executor closed, then closes the scheduler
+    {
+        closed = true;
+        scheduler.close();
+    }
+
+    @Override
+    public synchronized TaskHandle addTask( // creates one scheduler group per task; only taskId is used by this implementation
+            TaskId taskId,
+            DoubleSupplier utilizationSupplier,
+            int initialSplitConcurrency,
+            Duration splitConcurrencyAdjustFrequency,
+            OptionalInt maxDriversPerTask)
+    {
+        checkArgument(!closed, "Executor is already closed");
+
+        Group group = scheduler.createGroup(taskId.toString());
+        return new TaskEntry(taskId, group);
+    }
+
+    @Override
+    public synchronized void removeTask(TaskHandle handle) // does nothing when the entry is already destroyed
+    {
+        TaskEntry entry = (TaskEntry) handle;
+
+        if (!entry.isDestroyed()) {
+            scheduler.removeGroup(entry.group());
+            entry.destroy();
+        }
+    }
+
+    @Override
+    public synchronized List<ListenableFuture<Void>> enqueueSplits(TaskHandle handle, boolean intermediate, List<? extends SplitRunner> splits) // "intermediate" is not used here
+    {
+        checkArgument(!closed, "Executor is already closed");
+
+        TaskEntry entry = (TaskEntry) handle;
+
+        List<ListenableFuture<Void>> futures = new ArrayList<>(); // one completion future per submitted split, in input order
+        for (SplitRunner split : splits) {
+            entry.addSplit(split); // throws if the task has already been destroyed
+
+            int splitId = entry.nextSplitId();
+            ListenableFuture<Void> done = scheduler.submit(entry.group(), splitId, new VersionEmbedderBridge(versionEmbedder, new SplitProcessor(entry.taskId(), splitId, split, tracer)));
+            done.addListener(split::close, directExecutor()); // close the split once its scheduler future completes
+            futures.add(done);
+        }
+
+        return futures;
+    }
+
+    @Override
+    public String getMaxActiveSplitsInfo()
+    {
+        return ""; // TODO
+    }
+
+    @Override
+    public Set<TaskId> getStuckSplitTaskIds(Duration processingDurationThreshold, Predicate<RunningSplitInfo> filter)
+    {
+        // TODO
+        return ImmutableSet.of();
+    }
+
+    private static class TaskEntry // per-task handle: scheduler group, split id counter and the splits enqueued so far
+            implements TaskHandle
+    {
+        private final TaskId taskId;
+        private final Group group;
+        private final AtomicInteger nextSplitId = new AtomicInteger();
+        private volatile boolean destroyed;
+
+        @GuardedBy("this")
+        private final Set<SplitRunner> splits = new HashSet<>(); // never reassigned, so declared final
+
+        public TaskEntry(TaskId taskId, Group group)
+        {
+            this.taskId = taskId;
+            this.group = group;
+        }
+
+        public TaskId taskId()
+        {
+            return taskId;
+        }
+
+        public Group group()
+        {
+            return group;
+        }
+
+        public synchronized void destroy() // closes every recorded split, including ones whose completion listener already closed them
+        {
+            destroyed = true;
+
+            for (SplitRunner split : splits) {
+                split.close();
+            }
+        }
+
+        public synchronized void addSplit(SplitRunner split)
+        {
+            checkArgument(!destroyed, "Task already destroyed: %s", taskId);
+            splits.add(split);
+        }
+
+        public int nextSplitId() // first id handed out is 1 (incrementAndGet on a counter starting at 0)
+        {
+            return nextSplitId.incrementAndGet();
+        }
+
+        @Override
+        public boolean isDestroyed()
+        {
+            return destroyed;
+        }
+    }
+
+    private record VersionEmbedderBridge(VersionEmbedder versionEmbedder, Schedulable delegate) // runs the delegate through versionEmbedder.embedVersion
+            implements Schedulable
+    {
+        @Override
+        public void run(SchedulerContext context)
+        {
+            Runnable adapter = () -> delegate.run(context);
+            versionEmbedder.embedVersion(adapter).run();
+        }
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/BlockingSchedulingQueue.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/BlockingSchedulingQueue.java
new file mode 100644
index 000000000000..b32b8e42e47b
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/BlockingSchedulingQueue.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2.scheduler;
+
+import com.google.errorprone.annotations.ThreadSafe;
+import com.google.errorprone.annotations.concurrent.GuardedBy;
+
+import java.util.Set;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+
+@ThreadSafe
+final class BlockingSchedulingQueue<G, T> // lock-protected wrapper around SchedulingQueue; dequeue() waits until a task is available
+{
+    private final Lock lock = new ReentrantLock();
+    private final Condition notEmpty = lock.newCondition(); // signalled by enqueue(), awaited by dequeue()
+
+    @GuardedBy("lock")
+    private final SchedulingQueue<G, T> queue = new SchedulingQueue<>();
+
+    public void startGroup(G group)
+    {
+        lock.lock();
+        try {
+            queue.startGroup(group);
+        }
+        finally {
+            lock.unlock();
+        }
+    }
+
+    public Set<T> finishGroup(G group) // returns whatever the underlying queue reports for the finished group
+    {
+        lock.lock();
+        try {
+            return queue.finishGroup(group);
+        }
+        finally {
+            lock.unlock();
+        }
+    }
+
+    public Set<T> finishAll()
+    {
+        lock.lock();
+        try {
+            return queue.finishAll();
+        }
+        finally {
+            lock.unlock();
+        }
+    }
+
+    public boolean enqueue(G group, T task, long deltaWeight) // returns false (and enqueues nothing) when the group is unknown
+    {
+        lock.lock();
+        try {
+            if (!queue.containsGroup(group)) {
+                return false;
+            }
+
+            queue.enqueue(group, task, deltaWeight);
+            notEmpty.signal(); // wake one thread waiting in dequeue()
+
+            return true;
+        }
+        finally {
+            lock.unlock();
+        }
+    }
+
+    public boolean block(G group, T task, long deltaWeight) // returns false when the group is unknown
+    {
+        lock.lock();
+        try {
+            if (!queue.containsGroup(group)) {
+                return false;
+            }
+
+            queue.block(group, task, deltaWeight);
+            return true;
+        }
+        finally {
+            lock.unlock();
+        }
+    }
+
+    public T dequeue(long expectedWeight) // never returns null; waits on notEmpty while the underlying queue yields null
+            throws InterruptedException
+    {
+        lock.lock();
+        try {
+            T result;
+            do {
+                result = queue.dequeue(expectedWeight);
+                if (result == null) {
+                    notEmpty.await(); // releases the lock while waiting; the loop re-checks after every wake-up
+                }
+            }
+            while (result == null);
+
+            return result;
+        }
+        finally {
+            lock.unlock();
+        }
+    }
+
+    @Override
+    public String toString()
+    {
+        lock.lock();
+        try {
+            return queue.toString();
+        }
+        finally {
+            lock.unlock();
+        }
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/FairScheduler.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/FairScheduler.java
new file mode 100644
index 000000000000..30af34c7ae42
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/FairScheduler.java
@@ -0,0 +1,292 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2.scheduler;
+
+import com.google.common.base.Ticker;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import com.google.errorprone.annotations.ThreadSafe;
+import com.google.errorprone.annotations.concurrent.GuardedBy;
+import io.airlift.log.Logger;
+import io.trino.execution.executor2.scheduler.TaskControl.State;
+
+import java.util.Set;
+import java.util.StringJoiner;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * <h2>Implementation notes</h2>
+ *
+ * <ul>
+ *     <li>The TaskControl state machine is only modified by the task executor
+ *     thread (i.e., from within {@link FairScheduler#runTask(Schedulable, TaskControl)}). Other threads
+ *     can indirectly affect what the task executor thread does by marking the task as ready or cancelled
+ *     and unblocking the task executor thread, which will then act on that information.</li>
+ * </ul>
+ */
+@ThreadSafe
+public final class FairScheduler
+        implements AutoCloseable
+{
+    private static final Logger LOG = Logger.get(FairScheduler.class);
+
+    public static final long QUANTUM_NANOS = TimeUnit.MILLISECONDS.toNanos(1000); // yield() is a no-op until a task has run at least this long
+
+    private final ExecutorService schedulerExecutor; // runs the runScheduler() loop
+    private final ListeningExecutorService taskExecutor; // runs one runTask() call per submitted Schedulable
+    private final BlockingSchedulingQueue<Group, TaskControl> queue = new BlockingSchedulingQueue<>();
+    private final Reservation<TaskControl> concurrencyControl; // bounds the number of concurrently running tasks
+    private final Ticker ticker;
+
+    private final Gate paused = new Gate(true); // open by default; pause() closes it and resume() reopens it
+
+    @GuardedBy("this")
+    private boolean closed;
+
+    public FairScheduler(int maxConcurrentTasks, String threadNameFormat, Ticker ticker) // does not start the scheduler loop; call start()
+    {
+        this.ticker = requireNonNull(ticker, "ticker is null");
+
+        concurrencyControl = new Reservation<>(maxConcurrentTasks);
+
+        schedulerExecutor = Executors.newCachedThreadPool(new ThreadFactoryBuilder()
+                .setNameFormat("fair-scheduler-%d")
+                .setDaemon(true)
+                .build());
+
+        taskExecutor = MoreExecutors.listeningDecorator(Executors.newCachedThreadPool(new ThreadFactoryBuilder()
+                .setNameFormat(threadNameFormat)
+                .setDaemon(true)
+                .build()));
+    }
+
+    public static FairScheduler newInstance(int maxConcurrentTasks)
+    {
+        return newInstance(maxConcurrentTasks, Ticker.systemTicker());
+    }
+
+    public static FairScheduler newInstance(int maxConcurrentTasks, Ticker ticker) // creates a scheduler and starts it
+    {
+        FairScheduler scheduler = new FairScheduler(maxConcurrentTasks, "fair-scheduler-runner-%d", ticker);
+        scheduler.start();
+        return scheduler;
+    }
+
+    public void start()
+    {
+        schedulerExecutor.submit(this::runScheduler);
+    }
+
+    public void pause()
+    {
+        paused.close();
+    }
+
+    public void resume()
+    {
+        paused.open();
+    }
+
+    @Override
+    public synchronized void close() // idempotent: cancels every task known to the queue and shuts both executors down
+    {
+        if (closed) {
+            return;
+        }
+        closed = true;
+
+        Set<TaskControl> tasks = queue.finishAll();
+
+        for (TaskControl task : tasks) {
+            task.cancel();
+        }
+
+        taskExecutor.shutdownNow();
+        schedulerExecutor.shutdownNow();
+    }
+
+    public synchronized Group createGroup(String name)
+    {
+        checkArgument(!closed, "Already closed");
+
+        Group group = new Group(name);
+        queue.startGroup(group);
+
+        return group;
+    }
+
+    public synchronized void removeGroup(Group group) // cancels the tasks the queue returns for the group
+    {
+        checkArgument(!closed, "Already closed");
+
+        Set<TaskControl> tasks = queue.finishGroup(group);
+
+        for (TaskControl task : tasks) {
+            task.cancel();
+        }
+    }
+
+    public synchronized ListenableFuture<Void> submit(Group group, int id, Schedulable runner) // future completes when runTask() returns
+    {
+        checkArgument(!closed, "Already closed");
+
+        TaskControl task = new TaskControl(group, id, ticker);
+
+        return taskExecutor.submit(() -> runTask(runner, task), null);
+    }
+
+    private void runTask(Schedulable runner, TaskControl task)
+    {
+        task.setThread(Thread.currentThread());
+
+        if (!makeRunnableAndAwait(task, 0)) {
+            return;
+        }
+
+        SchedulerContext context = new SchedulerContext(this, task);
+        try {
+            runner.run(context);
+        }
+        catch (Exception e) {
+            LOG.error(e);
+        }
+        finally {
+            // If the runner exited due to an exception in user code or
+            // normally (not in response to an interruption during blocking or yield),
+            // it must have had a semaphore permit reserved, so release it.
+            if (task.getState() == State.RUNNING) {
+                concurrencyControl.release(task);
+            }
+        }
+    }
+
+    private boolean makeRunnableAndAwait(TaskControl task, long deltaWeight) // false when the task could not be queued or did not reach RUNNING
+    {
+        if (!task.transitionToWaiting()) {
+            return false;
+        }
+
+        if (!queue.enqueue(task.group(), task, deltaWeight)) {
+            return false;
+        }
+
+        // wait for the task to be scheduled
+        return awaitReadyAndTransitionToRunning(task);
+    }
+
+    /**
+     * @return false if the transition was unsuccessful due to the task being cancelled
+     */
+    private boolean awaitReadyAndTransitionToRunning(TaskControl task)
+    {
+        if (!task.awaitReady()) {
+            if (task.isReady()) {
+                // If the task was marked as ready (slot acquired) but then cancelled before
+                // awaitReady() was notified, we need to release the slot.
+                concurrencyControl.release(task);
+            }
+            return false;
+        }
+
+        if (!task.transitionToRunning()) {
+            concurrencyControl.release(task);
+            return false;
+        }
+
+        return true;
+    }
+
+    boolean yield(TaskControl task) // keeps running (returns true) until the task has used a full quantum
+    {
+        long delta = task.elapsed();
+        if (delta < QUANTUM_NANOS) {
+            return true;
+        }
+
+        concurrencyControl.release(task);
+
+        return makeRunnableAndAwait(task, delta);
+    }
+
+    boolean block(TaskControl task, ListenableFuture<?> future) // gives up the slot, waits for the future, then re-queues the task
+    {
+        long delta = task.elapsed();
+
+        concurrencyControl.release(task);
+
+        if (!task.transitionToBlocked()) {
+            return false;
+        }
+
+        if (!queue.block(task.group(), task, delta)) {
+            return false;
+        }
+
+        future.addListener(task::markUnblocked, MoreExecutors.directExecutor());
+        task.awaitUnblock();
+
+        return makeRunnableAndAwait(task, 0);
+    }
+
+    private void runScheduler() // loop: wait while paused, reserve a slot, dequeue a task and mark it ready; exits on interrupt
+    {
+        while (true) {
+            try {
+                paused.awaitOpen();
+                concurrencyControl.reserve();
+                TaskControl task = queue.dequeue(QUANTUM_NANOS);
+                concurrencyControl.register(task);
+                if (!task.markReady()) {
+                    concurrencyControl.release(task); // task could not be marked ready, so hand the slot back
+                }
+            }
+            catch (InterruptedException e) {
+                Thread.currentThread().interrupt();
+                return;
+            }
+        }
+    }
+
+    long getScheduledNanos(TaskControl task)
+    {
+        return task.getScheduledNanos();
+    }
+
+    long getWaitNanos(TaskControl task)
+    {
+        return task.getWaitNanos();
+    }
+
+    long getBlockedNanos(TaskControl task)
+    {
+        return task.getBlockedNanos();
+    }
+
+    @Override
+    public String toString()
+    {
+        return new StringJoiner(", ", FairScheduler.class.getSimpleName() + "[", "]")
+                .add("queue=" + queue)
+                .add("concurrencyControl=" + concurrencyControl)
+                .add("closed=" + closed) // NOTE(review): read without holding "this" although the field is @GuardedBy("this")
+                .toString();
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Gate.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Gate.java
new file mode 100644
index 000000000000..d41fe5754d04
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Gate.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2.scheduler;
+
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+
+public class Gate // open/closed flag guarded by a lock; awaitOpen() waits for as long as the gate is closed
+{
+    private final Lock lock = new ReentrantLock();
+    private final Condition opened = lock.newCondition(); // signalled by open()
+    private boolean open; // read and written only while holding lock
+
+    public Gate(boolean opened) // the argument is the initial open state
+    {
+        this.open = opened;
+    }
+
+    public void close() // later awaitOpen() calls will wait until open() is called
+    {
+        lock.lock();
+        try {
+            open = false;
+        }
+        finally {
+            lock.unlock();
+        }
+    }
+
+    public void open() // opens the gate and wakes every thread waiting in awaitOpen()
+    {
+        lock.lock();
+        try {
+            open = true;
+            opened.signalAll();
+        }
+        finally {
+            lock.unlock();
+        }
+    }
+
+    public void awaitOpen() // returns immediately when already open
+            throws InterruptedException
+    {
+        lock.lock();
+        try {
+            while (!open) { // loop re-checks the flag after every wake-up
+                opened.await();
+            }
+        }
+        finally {
+            lock.unlock();
+        }
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Group.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Group.java
new file mode 100644
index 000000000000..1700e097919d
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Group.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2.scheduler;
+
+public record Group(String name) // named scheduling group
+{
+    @Override
+    public String toString() // prints only the name instead of the default record format
+    {
+        return name;
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/PriorityQueue.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/PriorityQueue.java
new file mode 100644
index 000000000000..93fd08cfb08b
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/PriorityQueue.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2.scheduler;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeSet;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+
+final class PriorityQueue<T> // values ordered by a long priority (lowest first); not synchronized
+{
+    private final TreeSet<T> queue; // ordering; the comparator looks priorities up in index
+    private final Map<T, Entry<T>> index = new HashMap<>(); // membership and current priority of every queued value
+
+    public PriorityQueue()
+    {
+        queue = new TreeSet<>((a, b) -> {
+            int result = Long.compare(index.get(a).priority(), index.get(b).priority());
+            if (result == 0) {
+                result = Integer.compare(System.identityHashCode(a), System.identityHashCode(b)); // NOTE(review): identity hashes are not unique; a collision would make two distinct values compare as equal
+            }
+            return result;
+        });
+    }
+
+    public void add(T value, long priority) // throws IllegalArgumentException when the value is already queued
+    {
+        checkArgument(!index.containsKey(value), "Value already in queue: %s", value);
+        index.put(value, new Entry<>(priority, value)); // index first: the comparator needs the priority during queue.add
+        queue.add(value);
+    }
+
+    public void addOrReplace(T value, long priority)
+    {
+        if (index.containsKey(value)) {
+            queue.remove(value); // remove while index still holds the old priority, so the comparator can locate the value
+            index.put(value, new Entry<>(priority, value));
+            queue.add(value);
+        }
+        else {
+            add(value, priority);
+        }
+    }
+
+    public T poll() // removes and returns the lowest-priority value, or null when the queue is empty
+    {
+        T result = queue.pollFirst();
+        index.remove(result);
+
+        return result;
+    }
+
+    public void remove(T value) // throws IllegalArgumentException when the value is not queued
+    {
+        checkArgument(index.containsKey(value), "Value not in queue: %s", value);
+
+        queue.remove(value);
+        index.remove(value);
+    }
+
+    public void removeIfPresent(T value)
+    {
+        if (index.containsKey(value)) {
+            remove(value);
+        }
+    }
+
+    public boolean contains(T value)
+    {
+        return index.containsKey(value);
+    }
+
+    public boolean isEmpty()
+    {
+        return index.isEmpty();
+    }
+
+    public Set<T> values() // live key-set view of the index, not a copy
+    {
+        return index.keySet();
+    }
+
+    public long nextPriority() // priority of the head; throws IllegalStateException when empty
+    {
+        checkState(!queue.isEmpty(), "Queue is empty");
+        return index.get(queue.first()).priority();
+    }
+
+    public T peek() // head value without removing it, or null when empty
+    {
+        if (queue.isEmpty()) {
+            return null;
+        }
+        return queue.first();
+    }
+
+    @Override
+    public String toString()
+    {
+        return queue.toString();
+    }
+
+    private record Entry<T>(long priority, T value)
+    {
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Reservation.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Reservation.java
new file mode 100644
index 000000000000..7cfeaf540922
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Reservation.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2.scheduler;
+
+import com.google.common.collect.ImmutableSet;
+
+import java.util.HashSet;
+import java.util.Set;
+import java.util.StringJoiner;
+import java.util.concurrent.Semaphore;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+final class Reservation<T> // semaphore-backed pool of slots that also records which entries currently hold one
+{
+    private final Semaphore semaphore;
+    private final Set<T> reservations = new HashSet<>(); // accessed only from synchronized methods
+
+    public Reservation(int slots)
+    {
+        semaphore = new Semaphore(slots);
+    }
+
+    public int availablePermits()
+    {
+        return semaphore.availablePermits();
+    }
+
+    public void reserve() // waits for a free permit; the holder is recorded separately through register()
+            throws InterruptedException
+    {
+        semaphore.acquire();
+    }
+
+    public synchronized void register(T entry) // throws IllegalArgumentException when the entry is already registered
+    {
+        checkArgument(!reservations.contains(entry), "Already acquired: %s", entry);
+        reservations.add(entry);
+    }
+
+    public synchronized void release(T entry) // forgets the entry and returns one permit; throws when it is not registered
+    {
+        checkArgument(reservations.contains(entry), "Already released: %s", entry);
+        reservations.remove(entry);
+
+        semaphore.release();
+    }
+
+    public synchronized Set<T> reservations() // immutable snapshot of the current holders
+    {
+        return ImmutableSet.copyOf(reservations);
+    }
+
+    @Override
+    public synchronized String toString() // synchronized: formatting iterates the HashSet that register()/release() mutate under this monitor
+    {
+        return new StringJoiner(", ", Reservation.class.getSimpleName() + "[", "]")
+                .add("semaphore=" + semaphore)
+                .add("reservations=" + reservations)
+                .toString();
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Schedulable.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Schedulable.java
new file mode 100644
index 000000000000..1b2a0d0ab39b
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Schedulable.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2.scheduler;
+
+public interface Schedulable // unit of work run by the scheduler
+{
+    void run(SchedulerContext context); // context is the SchedulerContext supplied by the caller for this run
+}
diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/SchedulerContext.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/SchedulerContext.java
new file mode 100644
index 000000000000..ca623029e059
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/SchedulerContext.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2.scheduler;
+
+import com.google.common.util.concurrent.ListenableFuture;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+/**
+ * Handle through which a task interacts with the {@link FairScheduler}: it can yield,
+ * report that it is blocked on a future, and read its timing counters. Every call is
+ * delegated to the scheduler together with the task's {@link TaskControl}.
+ */
+public final class SchedulerContext
+{
+ private final FairScheduler scheduler;
+ private final TaskControl handle;
+
+ /**
+ * @param scheduler scheduler that every call is delegated to
+ * @param handle control object of the task this context belongs to
+ */
+ public SchedulerContext(FairScheduler scheduler, TaskControl handle)
+ {
+ this.scheduler = scheduler;
+ this.handle = handle;
+ }
+
+ /**
+ * Attempt to relinquish control to let other tasks run.
+ *
+ * @return false if the task was interrupted while yielding. The caller is expected to clean up and finish
+ * @throws IllegalArgumentException if the task is not in the RUNNING state
+ */
+ public boolean maybeYield()
+ {
+ checkArgument(handle.getState() == TaskControl.State.RUNNING, "Task is not running");
+
+ return scheduler.yield(handle);
+ }
+
+ /**
+ * Indicate that the current task is blocked. The task will become runnable
+ * when the provided future is completed.
+ *
+ * @param future future whose completion unblocks the task; its result type is irrelevant
+ * @return false if the task was interrupted while blocked. The caller is expected to clean up and finish
+ * @throws IllegalArgumentException if the task is not in the RUNNING state
+ */
+ public boolean block(ListenableFuture<?> future)
+ {
+ checkArgument(handle.getState() == TaskControl.State.RUNNING, "Task is not running");
+
+ return scheduler.block(handle, future);
+ }
+
+ /**
+ * @return the wait time the scheduler reports for this task, in nanoseconds
+ */
+ public long getWaitNanos()
+ {
+ return scheduler.getWaitNanos(handle);
+ }
+
+ /**
+ * @return the scheduled time the scheduler reports for this task, in nanoseconds
+ */
+ public long getScheduledNanos()
+ {
+ return scheduler.getScheduledNanos(handle);
+ }
+
+ /**
+ * @return the blocked time the scheduler reports for this task, in nanoseconds
+ */
+ public long getBlockedNanos()
+ {
+ return scheduler.getBlockedNanos(handle);
+ }
+}
diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/SchedulingQueue.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/SchedulingQueue.java
new file mode 100644
index 000000000000..9d7df3249116
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/SchedulingQueue.java
@@ -0,0 +1,531 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2.scheduler;
+
+import com.google.common.collect.ImmutableSet;
+import io.trino.annotation.NotThreadSafe;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.base.Verify.verify;
+import static io.trino.execution.executor2.scheduler.SchedulingQueue.State.BLOCKED;
+import static io.trino.execution.executor2.scheduler.SchedulingQueue.State.RUNNABLE;
+import static io.trino.execution.executor2.scheduler.SchedulingQueue.State.RUNNING;
+
+/**
+ * A queue of tasks that are scheduled for execution. Modeled after
+ * Completely Fair Scheduler.
+ * Tasks are grouped into scheduling groups. Within a group, tasks are ordered based
+ * on their relative weight. Groups are ordered relative to each other based on the
+ * accumulated weight of their tasks.
+ *
+ * <p>A task can be in one of three states:
+ * <ul>
+ * <li><b>runnable</b>: the task is ready to run and waiting to be dequeued</li>
+ * <li><b>running</b>: the task has been dequeued and is running</li>
+ * <li><b>blocked</b>: the task is blocked on some external event and is not running</li>
+ * </ul>
+ *
+ * <p>A group can be in one of three states:
+ * <ul>
+ * <li><b>runnable</b>: the group has at least one runnable task</li>
+ * <li><b>running</b>: all the tasks in the group are currently running</li>
+ * <li><b>blocked</b>: all the tasks in the group are currently blocked</li>
+ * </ul>
+ *
+ * The goal is to balance the consideration among groups to ensure the accumulated
+ * weight in the long run is equal among groups. Within a group, the goal is to
+ * balance the consideration among tasks to ensure the accumulated weight in the
+ * long run is equal among tasks within the group.
+ *
+ *
+ * <p>Groups start in the blocked state and transition to the runnable state when a task is
+ * added via the {@link #enqueue(Object, Object, long)} method.
+ *
+ * Tasks are dequeued via the {@link #dequeue(long)}. When all tasks in a group have
+ * been dequeued, the group transitions to the running state and is removed from the
+ * queue.
+ *
+ * When a task time slice completes, it needs to be re-enqueued via the
+ * {@link #enqueue(Object, Object, long)}, which includes the desired
+ * increment in relative weight to apply to the task for further prioritization.
+ * The weight increment is also applied to the group.
+ *
+ *
+ * If a task blocks, the caller must call the {@link #block(Object, Object, long)}
+ * method to indicate that the task is no longer running. A weight increment can be
+ * included for the portion of time the task was not blocked.
+ *
+ * <h2>Group state transitions</h2>
+ * <pre>
+ *                                                                blockTask()
+ *    finishTask()               enqueueTask()                   enqueueTask()
+ *        ┌───┐   ┌──────────────────────────────────────────┐       ┌────┐
+ *        │   │   │                                          │       │    │
+ *        │   ▼   │                                          ▼       ▼    │
+ *      ┌─┴───────┴─┐   all blocked        finishTask()   ┌────────────┐  │
+ *      │           │◄──────────────O◄────────────────────┤            ├──┘
+ * ────►│  BLOCKED  │               │                     │  RUNNABLE  │
+ *      │           │               │   ┌────────────────►│            │◄───┐
+ *      └───────────┘       not all │   │  enqueueTask()  └──────┬─────┘    │
+ *            ▲             blocked │   │                        │          │
+ *            │                     │   │           dequeueTask()│          │
+ *            │ all blocked         ▼   │                        │          │
+ *            │                   ┌─────┴─────┐                  ▼          │
+ *            │                   │           │◄─────────────────O──────────┘
+ *            O◄──────────────────┤  RUNNING  │    queue empty       queue
+ *            │      blockTask()  │           ├───┐                not empty
+ *            │                   └───────────┘   │
+ *            │                     ▲      ▲      │ finishTask()
+ *            └─────────────────────┘      └──────┘
+ *                not all blocked
+ * </pre>
+ *
+ * Implementation notes
+ *
+ * - TODO: Initial weight upon registration
+ * - TODO: Weight adjustment during blocking / unblocking
+ * - TODO: Uncommitted weight on dequeue
+ *
+ *
+ */
+@NotThreadSafe
+final class SchedulingQueue
+{
+ private final PriorityQueue runnableQueue = new PriorityQueue<>();
+ private final Map> groups = new HashMap<>();
+ private final PriorityQueue baselineWeights = new PriorityQueue<>();
+ private final Map blocked = new HashMap<>();
+
+ /**
+ * Registers a new group. The group is created with the current baseline weight
+ * and starts out in the BLOCKED state.
+ *
+ * @param group key of the group to register; must not already be registered
+ */
+ public void startGroup(G group)
+ {
+ checkArgument(!groups.containsKey(group), "Group already started: %s", group);
+
+ Group info = new Group<>(baselineWeight());
+ groups.put(group, info);
+
+ doTransition(group, info);
+ }
+
+ /**
+ * Removes a group from every internal structure (runnable queue, baseline weights,
+ * blocked map and group map).
+ *
+ * @param group key of a registered group
+ * @return the tasks that were still registered with the group
+ */
+ public Set finishGroup(G group)
+ {
+ checkArgument(groups.containsKey(group), "Unknown group: %s", group);
+
+ runnableQueue.removeIfPresent(group);
+ baselineWeights.removeIfPresent(group);
+ blocked.remove(group);
+ return groups.remove(group).tasks();
+ }
+
+ /**
+ * @return true if {@code group} has been started and not yet finished
+ */
+ public boolean containsGroup(G group)
+ {
+ return groups.containsKey(group);
+ }
+
+ /**
+ * Finishes every registered group.
+ *
+ * @return the union of the tasks returned by {@link #finishGroup} for each group
+ */
+ public Set finishAll()
+ {
+ // Iterate over a copy of the keys: finishGroup() removes entries from the map
+ Set groups = ImmutableSet.copyOf(this.groups.keySet());
+ return groups.stream()
+ .map(this::finishGroup)
+ .flatMap(Collection::stream)
+ .collect(Collectors.toSet());
+ }
+
+ /**
+ * Removes a task from its group and re-evaluates the group's position in the
+ * queue according to its resulting state.
+ *
+ * @param group key of a registered group
+ * @param task task that belongs to the group
+ */
+ public void finish(G group, T task)
+ {
+ checkArgument(groups.containsKey(group), "Unknown group: %s", group);
+ verifyState(group);
+
+ Group info = groups.get(group);
+ info.finish(task);
+
+ doTransition(group, info);
+ }
+
+ /**
+ * Adds a task to the runnable queue of its group, making the group RUNNABLE.
+ * Used both for new tasks and for tasks returning from a time slice or from
+ * being blocked.
+ *
+ * @param group key of a registered group
+ * @param task task to make runnable
+ * @param deltaWeight weight increment applied to the task and to the group
+ */
+ public void enqueue(G group, T task, long deltaWeight)
+ {
+ checkArgument(groups.containsKey(group), "Unknown group: %s", group);
+ verifyState(group);
+
+ Group info = groups.get(group);
+
+ State previousState = info.state();
+ info.enqueue(task, deltaWeight);
+
+ if (previousState == BLOCKED) {
+ // When transitioning from blocked, set the baseline weight to the minimum current weight
+ // to avoid the newly unblocked group from monopolizing the queue while it catches up
+ // NOTE(review): the code adds the delta saved when the group became blocked on top of
+ // the group's current weight and does not read the current baseline here — confirm
+ // this matches the intent described above
+ Blocked blockedGroup = blocked.remove(group);
+ info.adjustWeight(blockedGroup.savedDelta());
+ }
+
+ checkState(info.state() == RUNNABLE);
+ doTransition(group, info);
+ }
+
+ /**
+ * Marks a task as blocked and re-evaluates the state of its group.
+ *
+ * @param group key of a registered group that is currently RUNNABLE or RUNNING
+ * @param task task of that group which is not in the runnable queue
+ * @param deltaWeight weight increment applied to the task and to the group
+ */
+ public void block(G group, T task, long deltaWeight)
+ {
+ Group info = groups.get(group);
+ checkArgument(info != null, "Unknown group: %s", group);
+ checkArgument(info.state() == RUNNABLE || info.state() == RUNNING, "Group is already blocked: %s", group);
+ verifyState(group);
+
+ info.block(task, deltaWeight);
+ doTransition(group, info);
+ }
+
+ /**
+ * Takes the next task from the group at the head of the runnable queue.
+ *
+ * @param expectedWeight weight recorded as the task's uncommitted weight and added
+ * to the group's weight until the task is enqueued or blocked again
+ * @return the dequeued task, or null if no group is runnable
+ */
+ public T dequeue(long expectedWeight)
+ {
+ G group = runnableQueue.poll();
+
+ if (group == null) {
+ return null;
+ }
+
+ Group info = groups.get(group);
+ verify(info.state() == RUNNABLE, "Group is not runnable: %s", group);
+
+ T task = info.dequeue(expectedWeight);
+ verify(task != null);
+
+ // The group was polled above; doTransition() puts it back if it is still RUNNABLE
+ doTransition(group, info);
+
+ return task;
+ }
+
+ /**
+ * Looks at the task that the next {@link #dequeue(long)} would take from the group
+ * at the head of the runnable queue, without modifying the queue.
+ *
+ * @return the head task of the head group, or null if no group is runnable
+ */
+ public T peek()
+ {
+ G group = runnableQueue.peek();
+
+ if (group == null) {
+ return null;
+ }
+
+ Group info = groups.get(group);
+ verify(info.state() == RUNNABLE, "Group is not runnable: %s", group);
+
+ T task = info.peek();
+ checkState(task != null);
+
+ return task;
+ }
+
+ /**
+ * Synchronizes the queue-level structures with the group's current state and
+ * then checks the resulting invariants via {@link #verifyState}.
+ */
+ private void doTransition(G group, Group info)
+ {
+ switch (info.state()) {
+ case RUNNABLE -> transitionToRunnable(group, info);
+ case RUNNING -> transitionToRunning(group, info);
+ case BLOCKED -> transitionToBlocked(group, info);
+ }
+
+ verifyState(group);
+ }
+
+ // RUNNING groups still take part in the baseline weight computation, so their entry
+ // in baselineWeights is refreshed with the group's current weight.
+ private void transitionToRunning(G group, Group info)
+ {
+ checkArgument(info.state() == RUNNING);
+
+ baselineWeights.addOrReplace(group, info.weight());
+ }
+
+ // Records how far the group's weight is from the current baseline, then removes
+ // the group from both the baseline weights and the runnable queue.
+ private void transitionToBlocked(G group, Group info)
+ {
+ checkArgument(info.state() == BLOCKED);
+
+ // The delta is computed before the group is removed from baselineWeights,
+ // so the group's own entry (if any) still contributes to the baseline here
+ blocked.put(group, new Blocked(info.weight() - baselineWeight()));
+ baselineWeights.removeIfPresent(group);
+ runnableQueue.removeIfPresent(group);
+ }
+
+ // Places (or repositions) the group in the runnable queue and refreshes its
+ // baseline weight entry, both keyed by the group's current weight.
+ private void transitionToRunnable(G group, Group info)
+ {
+ checkArgument(info.state() == RUNNABLE);
+
+ runnableQueue.addOrReplace(group, info.weight());
+ baselineWeights.addOrReplace(group, info.weight());
+ }
+
+ /**
+ * @param group key of a registered group
+ * @return the current state of the group
+ */
+ public State state(G group)
+ {
+ Group info = groups.get(group);
+ checkArgument(info != null, "Unknown group: %s", group);
+
+ return info.state();
+ }
+
+ // Baseline weight across groups: the priority at the head of baselineWeights,
+ // or 0 when no group is tracked there.
+ private long baselineWeight()
+ {
+ if (baselineWeights.isEmpty()) {
+ return 0;
+ }
+
+ return baselineWeights.nextPriority();
+ }
+
+ /**
+ * Checks that the queue-level structures agree with the group's state:
+ * BLOCKED groups are in neither the runnable queue nor the baseline weights,
+ * RUNNABLE groups are in both, and RUNNING groups are only in the baseline weights.
+ */
+ private void verifyState(G groupKey)
+ {
+ Group group = groups.get(groupKey);
+ checkArgument(group != null, "Unknown group: %s", groupKey);
+
+ switch (group.state()) {
+ case BLOCKED -> {
+ checkState(!runnableQueue.contains(groupKey), "Group in BLOCKED state should not be in queue: %s", groupKey);
+ checkState(!baselineWeights.contains(groupKey));
+ }
+ case RUNNABLE -> {
+ checkState(runnableQueue.contains(groupKey), "Group in RUNNABLE state should be in queue: %s", groupKey);
+ checkState(baselineWeights.contains(groupKey));
+ }
+ case RUNNING -> {
+ checkState(!runnableQueue.contains(groupKey), "Group in RUNNING state should not be in queue: %s", groupKey);
+ checkState(baselineWeights.contains(groupKey));
+ }
+ }
+ }
+
+ // Multi-line debug dump: the baseline weight, then every group with its state and
+ // weights, followed by each of the group's tasks. "=>" marks the group and the task
+ // at the head of their respective runnable queues.
+ @Override
+ public String toString()
+ {
+ StringBuilder builder = new StringBuilder();
+
+ builder.append("Baseline weight: %s\n".formatted(baselineWeight()));
+ builder.append("\n");
+
+ for (Map.Entry> entry : groups.entrySet()) {
+ G group = entry.getKey();
+ Group info = entry.getValue();
+
+ switch (entry.getValue().state()) {
+ case BLOCKED -> builder.append(" - %s [BLOCKED, saved delta = %s]\n".formatted(
+ group,
+ blocked.get(entry.getKey()).savedDelta()));
+ case RUNNING, RUNNABLE -> builder.append(" %s %s [%s, weight = %s, baseline = %s]\n".formatted(
+ group == runnableQueue.peek() ? "=>" : " -",
+ group,
+ info.state(),
+ info.weight(),
+ info.baselineWeight()));
+ }
+
+ for (Map.Entry taskEntry : info.tasks.entrySet()) {
+ T task = taskEntry.getKey();
+ Task taskInfo = taskEntry.getValue();
+
+ // A task is reported as RUNNING when it is neither blocked nor queued
+ if (info.blocked.containsKey(task)) {
+ builder.append(" %s [BLOCKED, saved delta = %s]\n".formatted(task, info.blocked.get(task).savedDelta()));
+ }
+ else if (info.runnableQueue.contains(task)) {
+ builder.append(" %s %s [RUNNABLE, weight = %s]\n".formatted(
+ task == info.peek() ? "=>" : " ",
+ task,
+ taskInfo.weight()));
+ }
+ else {
+ builder.append(" %s [RUNNING, weight = %s, uncommitted = %s]\n".formatted(
+ task,
+ taskInfo.weight(),
+ taskInfo.uncommittedWeight()));
+ }
+ }
+ }
+
+ return builder.toString();
+ }
+
+ /**
+ * Bookkeeping for one scheduling group: its accumulated weight, its tasks, the queue
+ * of runnable tasks, the blocked tasks and the group {@link State} derived from them.
+ *
+ * @param <T> type of the task keys
+ */
+ private static class Group<T>
+ {
+ private State state;
+ private long weight;
+ private final Map<T, Task> tasks = new HashMap<>();
+ private final PriorityQueue<T> runnableQueue = new PriorityQueue<>();
+ private final Map<T, Blocked> blocked = new HashMap<>();
+ private final PriorityQueue<T> baselineWeights = new PriorityQueue<>();
+
+ /**
+ * @param weight initial accumulated weight of the group; the group starts BLOCKED
+ */
+ public Group(long weight)
+ {
+ this.state = BLOCKED;
+ this.weight = weight;
+ }
+
+ /**
+ * Makes a task runnable. Unknown tasks are registered first; blocked tasks are
+ * unblocked and get the delta saved at block time added to their weight.
+ *
+ * @param task task to add to the runnable queue
+ * @param deltaWeight weight committed to the task; it replaces the task's
+ * uncommitted weight in the group total
+ */
+ public void enqueue(T task, long deltaWeight)
+ {
+ Task info = tasks.get(task);
+
+ if (info == null) {
+ // New tasks get assigned the baseline weight so that they don't monopolize the queue
+ // while they catch up
+ info = new Task(baselineWeight());
+ tasks.put(task, info);
+ }
+ else if (blocked.containsKey(task)) {
+ Blocked blockedTask = blocked.remove(task);
+ info.adjustWeight(blockedTask.savedDelta());
+ }
+
+ weight -= info.uncommittedWeight();
+ weight += deltaWeight;
+
+ info.commitWeight(deltaWeight);
+ runnableQueue.add(task, info.weight());
+ baselineWeights.addOrReplace(task, info.weight());
+
+ updateState();
+ }
+
+ /**
+ * Removes and returns the task at the head of the runnable queue.
+ *
+ * @param expectedWeight weight recorded as the task's uncommitted weight and
+ * added to the group total
+ * @return the dequeued task
+ */
+ public T dequeue(long expectedWeight)
+ {
+ checkArgument(state == RUNNABLE);
+
+ T task = runnableQueue.poll();
+
+ Task info = tasks.get(task);
+ info.setUncommittedWeight(expectedWeight);
+ weight += expectedWeight;
+
+ baselineWeights.addOrReplace(task, info.weight());
+
+ updateState();
+
+ return task;
+ }
+
+ /**
+ * Removes a task from the group regardless of whether it is runnable, running
+ * or blocked.
+ *
+ * @param task a task registered with this group
+ */
+ public void finish(T task)
+ {
+ checkArgument(tasks.containsKey(task), "Unknown task: %s", task);
+ tasks.remove(task);
+ // A task can finish while it is blocked. Its entry must be dropped, otherwise
+ // updateState() compares a stale blocked count against the remaining tasks and
+ // may report BLOCKED for a group that still has runnable or running tasks
+ blocked.remove(task);
+ runnableQueue.removeIfPresent(task);
+ baselineWeights.removeIfPresent(task);
+
+ updateState();
+ }
+
+ /**
+ * Marks a task that is not in the runnable queue as blocked.
+ *
+ * @param task a task registered with this group
+ * @param deltaWeight weight committed to the task and added to the group total
+ */
+ public void block(T task, long deltaWeight)
+ {
+ checkArgument(tasks.containsKey(task), "Unknown task: %s", task);
+ checkArgument(!runnableQueue.contains(task), "Task is already in queue: %s", task);
+
+ weight += deltaWeight;
+
+ Task info = tasks.get(task);
+ info.commitWeight(deltaWeight);
+
+ // NOTE(review): the saved delta is derived from the group's weight rather than the
+ // task's, and the task's uncommitted weight is not subtracted from the group total
+ // as enqueue() does — confirm both are intended
+ blocked.put(task, new Blocked(weight - baselineWeight()));
+ baselineWeights.remove(task);
+
+ updateState();
+ }
+
+ // Baseline weight across the group's tasks: the priority at the head of
+ // baselineWeights, or 0 when no task is tracked there
+ private long baselineWeight()
+ {
+ if (baselineWeights.isEmpty()) {
+ return 0;
+ }
+
+ return baselineWeights.nextPriority();
+ }
+
+ // Shifts the group's accumulated weight by delta
+ public void adjustWeight(long delta)
+ {
+ weight += delta;
+ }
+
+ // Derives the group state: BLOCKED when every registered task is blocked (including
+ // the case of no tasks), RUNNING when nothing is queued, RUNNABLE otherwise
+ private void updateState()
+ {
+ if (blocked.size() == tasks.size()) {
+ state = BLOCKED;
+ }
+ else if (runnableQueue.isEmpty()) {
+ state = RUNNING;
+ }
+ else {
+ state = RUNNABLE;
+ }
+ }
+
+ public long weight()
+ {
+ return weight;
+ }
+
+ // Live view of the registered task keys
+ public Set<T> tasks()
+ {
+ return tasks.keySet();
+ }
+
+ public State state()
+ {
+ return state;
+ }
+
+ // Head of the runnable queue without removing it
+ public T peek()
+ {
+ return runnableQueue.peek();
+ }
+ }
+
+ /**
+ * Weight bookkeeping for a single task: a committed weight plus a provisional
+ * ("uncommitted") amount that is set on dequeue and cleared by the next commit.
+ */
+ private static class Task
+ {
+ private long weight;
+ private long uncommittedWeight;
+
+ public Task(long initialWeight)
+ {
+ this.weight = initialWeight;
+ }
+
+ // Folds delta into the committed weight and discards the provisional weight
+ public void commitWeight(long delta)
+ {
+ this.uncommittedWeight = 0;
+ this.weight = this.weight + delta;
+ }
+
+ // Shifts the committed weight and leaves the provisional weight untouched
+ public void adjustWeight(long delta)
+ {
+ this.weight = this.weight + delta;
+ }
+
+ // Effective weight: committed plus provisional
+ public long weight()
+ {
+ long effective = this.weight;
+ effective += this.uncommittedWeight;
+ return effective;
+ }
+
+ public void setUncommittedWeight(long weight)
+ {
+ uncommittedWeight = weight;
+ }
+
+ public long uncommittedWeight()
+ {
+ return this.uncommittedWeight;
+ }
+ }
+
+ /**
+ * State of a scheduling group, derived from the state of its tasks.
+ */
+ public enum State
+ {
+ BLOCKED, // every registered task is blocked (also the state of a group without tasks)
+ RUNNING, // no task is waiting in the runnable queue and not every task is blocked
+ RUNNABLE // some tasks are enqueued and ready to run
+ }
+
+ // Marker stored for a blocked group or task. savedDelta is the distance from the
+ // baseline weight computed at the moment of blocking; it is added back on unblock.
+ private record Blocked(long savedDelta)
+ {
+ }
+}
diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/TaskControl.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/TaskControl.java
new file mode 100644
index 000000000000..732fe4990530
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/TaskControl.java
@@ -0,0 +1,348 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2.scheduler;
+
+import com.google.common.base.Ticker;
+import com.google.errorprone.annotations.concurrent.GuardedBy;
+
+import java.util.Objects;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Equality is based on group and id for the purpose of adding to the scheduling queue.
+ */
+final class TaskControl
+{
+ private final Group group;
+ private final int id;
+ private final Ticker ticker;
+
+ private final Lock lock = new ReentrantLock();
+
+ @GuardedBy("lock")
+ private final Condition wakeup = lock.newCondition();
+
+ @GuardedBy("lock")
+ private boolean ready;
+
+ @GuardedBy("lock")
+ private boolean blocked;
+
+ @GuardedBy("lock")
+ private boolean cancelled;
+
+ @GuardedBy("lock")
+ private State state;
+
+ private volatile long periodStart;
+ private final AtomicLong scheduledNanos = new AtomicLong();
+ private final AtomicLong blockedNanos = new AtomicLong();
+ private final AtomicLong waitNanos = new AtomicLong();
+ private volatile Thread thread;
+
+ public TaskControl(Group group, int id, Ticker ticker)
+ {
+ this.group = requireNonNull(group, "group is null");
+ this.id = id;
+ this.ticker = requireNonNull(ticker, "ticker is null");
+ this.state = State.NEW;
+ this.ready = false;
+ this.periodStart = ticker.read();
+ }
+
+ public void setThread(Thread thread)
+ {
+ this.thread = thread;
+ }
+
+ /**
+ * Marks the task as cancelled and signals the wakeup condition so that a thread
+ * waiting in {@link #awaitReady()} or {@link #awaitUnblock()} re-checks the flag.
+ * The task's thread is not interrupted (see the TODO below).
+ */
+ public void cancel()
+ {
+ lock.lock();
+ try {
+ cancelled = true;
+ wakeup.signal();
+
+ // TODO: it should be possible to interrupt the thread, but
+ // it appears that it's not safe to do so. It can cause the query
+ // to get stuck (e.g., AbstractDistributedEngineOnlyQueries.testSelectiveLimit)
+ //
+ // Thread thread = this.thread;
+ // if (thread != null) {
+ // thread.interrupt();
+ // }
+ }
+ finally {
+ lock.unlock();
+ }
+ }
+
+ /**
+ * Called by the scheduler thread when the task is ready to run. It
+ * causes anyone blocking in {@link #awaitReady()} to wake up.
+ *
+ * @return false if the task was already cancelled
+ */
+ public boolean markReady()
+ {
+ lock.lock();
+ try {
+ if (cancelled) {
+ return false;
+ }
+ ready = true;
+ wakeup.signal();
+ }
+ finally {
+ lock.unlock();
+ }
+
+ return true;
+ }
+
+ public void markNotReady()
+ {
+ lock.lock();
+ try {
+ ready = false;
+ }
+ finally {
+ lock.unlock();
+ }
+ }
+
+ public boolean isReady()
+ {
+ lock.lock();
+ try {
+ return ready;
+ }
+ finally {
+ lock.unlock();
+ }
+ }
+
+ /**
+ * @return false if the operation was interrupted due to cancellation
+ */
+ // Waits on the wakeup condition until the task is marked ready or cancelled.
+ public boolean awaitReady()
+ {
+ lock.lock();
+ try {
+ while (!ready && !cancelled) {
+ try {
+ wakeup.await();
+ }
+ catch (InterruptedException e) {
+ // The interrupt is swallowed: the loop re-checks the flags and keeps waiting,
+ // so only markReady() or cancel() can end the wait
+ }
+ }
+
+ return !cancelled;
+ }
+ finally {
+ lock.unlock();
+ }
+ }
+
+ public void markUnblocked()
+ {
+ lock.lock();
+ try {
+ blocked = false;
+ wakeup.signal();
+ }
+ finally {
+ lock.unlock();
+ }
+ }
+
+ public void markBlocked()
+ {
+ lock.lock();
+ try {
+ blocked = true;
+ }
+ finally {
+ lock.unlock();
+ }
+ }
+
+ /**
+ * Waits on the wakeup condition until the task is no longer marked blocked or
+ * has been cancelled. The return value does not say which of the two happened.
+ */
+ public void awaitUnblock()
+ {
+ lock.lock();
+ try {
+ while (blocked && !cancelled) {
+ try {
+ wakeup.await();
+ }
+ catch (InterruptedException e) {
+ // The interrupt is swallowed: the loop re-checks the flags and keeps waiting,
+ // so only markUnblocked() or cancel() can end the wait
+ }
+ }
+ }
+ finally {
+ lock.unlock();
+ }
+ }
+
+ /**
+ * @return false if the transition was unsuccessful due to the task being interrupted
+ */
+ public boolean transitionToBlocked()
+ {
+ boolean success = transitionTo(State.BLOCKED);
+
+ if (success) {
+ markBlocked();
+ }
+
+ return success;
+ }
+
+ /**
+ * @return false if the transition was unsuccessful due to the task being interrupted
+ */
+ public boolean transitionToWaiting()
+ {
+ boolean success = transitionTo(State.WAITING);
+
+ if (success) {
+ markNotReady();
+ }
+
+ return success;
+ }
+
+ /**
+ * @return false if the transition was unsuccessful due to the task being interrupted
+ */
+ public boolean transitionToRunning()
+ {
+ return transitionTo(State.RUNNING);
+ }
+
+ /**
+ * Closes the current accounting period and moves to {@code state}, unless the task
+ * has been cancelled, in which case the state becomes INTERRUPTED instead.
+ *
+ * @return true if the requested state was entered, false if the task was cancelled
+ */
+ private boolean transitionTo(State state)
+ {
+ lock.lock();
+ try {
+ // Must run before this.state changes: the elapsed time is attributed to the old state
+ recordPeriodEnd();
+
+ if (cancelled) {
+ this.state = State.INTERRUPTED;
+ return false;
+ }
+ else {
+ this.state = state;
+ return true;
+ }
+ }
+ finally {
+ lock.unlock();
+ }
+ }
+
+ // Adds the time elapsed since periodStart to the counter that matches the current
+ // state (RUNNING -> scheduled, BLOCKED -> blocked, NEW/WAITING -> wait) and starts a
+ // new period. Time spent in INTERRUPTED is not accumulated anywhere.
+ private void recordPeriodEnd()
+ {
+ long now = ticker.read();
+ long elapsed = now - periodStart;
+ switch (state) {
+ case RUNNING -> scheduledNanos.addAndGet(elapsed);
+ case BLOCKED -> blockedNanos.addAndGet(elapsed);
+ case NEW, WAITING -> waitNanos.addAndGet(elapsed);
+ default -> {
+ // make checkstyle happy. TODO: remove this
+ }
+ }
+ periodStart = now;
+ }
+
+ public Group group()
+ {
+ return group;
+ }
+
+ public State getState()
+ {
+ lock.lock();
+ try {
+ return state;
+ }
+ finally {
+ lock.unlock();
+ }
+ }
+
+ /**
+ * @return ticker time elapsed since the start of the current accounting period,
+ * i.e. since the last state transition attempt (or construction)
+ */
+ public long elapsed()
+ {
+ return ticker.read() - periodStart;
+ }
+
+ public long getWaitNanos()
+ {
+ return waitNanos.get();
+ }
+
+ public long getScheduledNanos()
+ {
+ return scheduledNanos.get();
+ }
+
+ public long getBlockedNanos()
+ {
+ return blockedNanos.get();
+ }
+
+ @Override
+ public boolean equals(Object o)
+ {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ TaskControl that = (TaskControl) o;
+ return id == that.id && group.equals(that.group);
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(group, id);
+ }
+
+ @Override
+ public String toString()
+ {
+ lock.lock();
+ try {
+ return group.name() + "-" + id + " [" + state + "]";
+ }
+ finally {
+ lock.unlock();
+ }
+ }
+
+ /**
+ * Lifecycle state of a task as tracked by its {@link TaskControl}.
+ */
+ public enum State
+ {
+ // initial state set by the constructor; time spent here counts as wait time
+ NEW,
+ // time spent here counts as wait time
+ WAITING,
+ // time spent here counts as scheduled time
+ RUNNING,
+ // time spent here counts as blocked time
+ BLOCKED,
+ // entered when a state transition is attempted after the task was cancelled
+ INTERRUPTED
+ }
+}
diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/group-state-diagram.dot b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/group-state-diagram.dot
new file mode 100644
index 000000000000..bc1346753ee0
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/group-state-diagram.dot
@@ -0,0 +1,24 @@
+digraph Group {
+ node [shape=box]; // states are drawn as boxes
+
+ start [shape=point];
+ split1 [shape=point]; // decision after dequeueTask()
+ split2 [shape=point]; // decision after blockTask() while running
+ split3 [shape=point]; // decision after finishTask() while runnable
+
+ start -> blocked;
+ blocked -> runnable [label="enqueueTask()"];
+ runnable -> runnable [label="enqueueTask()\nblockTask()"];
+ runnable -> split1 [label="dequeueTask()"];
+ split1 -> runnable [label="queue not empty"];
+ split1 -> running [label="queue empty"];
+ running -> split2 [label="blockTask()"];
+ running -> runnable [label="enqueueTask()"];
+ split2 -> blocked [label="all blocked"];
+ split2 -> running [label="not all blocked"];
+ blocked -> blocked [label="finishTask()"];
+ running -> running [label="finishTask()"];
+ runnable -> split3 [label="finishTask()"];
+ split3 -> blocked [label="all blocked"];
+ split3 -> running [label="all running"];
+}
diff --git a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java
index 6d8c876fd61d..8ffb27347218 100644
--- a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java
+++ b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java
@@ -52,6 +52,7 @@
import io.trino.execution.executor.MultilevelSplitQueue;
import io.trino.execution.executor.TaskExecutor;
import io.trino.execution.executor.TimeSharingTaskExecutor;
+import io.trino.execution.executor2.ThreadPerDriverTaskExecutor;
import io.trino.execution.scheduler.NodeScheduler;
import io.trino.execution.scheduler.NodeSchedulerConfig;
import io.trino.execution.scheduler.TopologyAwareNodeSelectorModule;
@@ -309,10 +310,6 @@ protected void setup(Binder binder)
newOptionalBinder(binder, VersionEmbedder.class).setDefault().to(EmbedVersion.class).in(Scopes.SINGLETON);
newExporter(binder).export(SqlTaskManager.class).withGeneratedName();
- binder.bind(TaskExecutor.class)
- .to(TimeSharingTaskExecutor.class)
- .in(Scopes.SINGLETON);
-
newExporter(binder).export(TaskExecutor.class).withGeneratedName();
binder.bind(MultilevelSplitQueue.class).in(Scopes.SINGLETON);
newExporter(binder).export(MultilevelSplitQueue.class).withGeneratedName();
@@ -323,6 +320,20 @@ protected void setup(Binder binder)
binder.bind(PageFunctionCompiler.class).in(Scopes.SINGLETON);
newExporter(binder).export(PageFunctionCompiler.class).withGeneratedName();
configBinder(binder).bindConfig(TaskManagerConfig.class);
+
+ // TODO: use conditional module
+ TaskManagerConfig taskManagerConfig = buildConfigObject(TaskManagerConfig.class);
+ if (taskManagerConfig.isThreadPerDriverSchedulerEnabled()) {
+ binder.bind(TaskExecutor.class)
+ .to(ThreadPerDriverTaskExecutor.class)
+ .in(Scopes.SINGLETON);
+ }
+ else {
+ binder.bind(TaskExecutor.class)
+ .to(TimeSharingTaskExecutor.class)
+ .in(Scopes.SINGLETON);
+ }
+
if (retryPolicy == TASK) {
configBinder(binder).bindConfigDefaults(TaskManagerConfig.class, TaskManagerConfig::applyFaultTolerantExecutionDefaults);
}
diff --git a/core/trino-main/src/main/java/io/trino/tracing/TrinoAttributes.java b/core/trino-main/src/main/java/io/trino/tracing/TrinoAttributes.java
index c5cbcd9fdf6f..e9ce0836e9da 100644
--- a/core/trino-main/src/main/java/io/trino/tracing/TrinoAttributes.java
+++ b/core/trino-main/src/main/java/io/trino/tracing/TrinoAttributes.java
@@ -54,6 +54,7 @@ private TrinoAttributes() {}
public static final AttributeKey SPLIT_SCHEDULED_TIME_NANOS = longKey("trino.split.scheduled_time_nanos");
public static final AttributeKey SPLIT_CPU_TIME_NANOS = longKey("trino.split.cpu_time_nanos");
public static final AttributeKey SPLIT_WAIT_TIME_NANOS = longKey("trino.split.wait_time_nanos");
+ public static final AttributeKey SPLIT_BLOCK_TIME_NANOS = longKey("trino.split.block_time_nanos");
public static final AttributeKey SPLIT_BLOCKED = booleanKey("trino.split.blocked");
public static final AttributeKey EVENT_STATE = stringKey("state");
diff --git a/core/trino-main/src/test/java/io/trino/execution/BaseTestSqlTaskManager.java b/core/trino-main/src/test/java/io/trino/execution/BaseTestSqlTaskManager.java
index 58e468ea2a1f..4241afb6e643 100644
--- a/core/trino-main/src/test/java/io/trino/execution/BaseTestSqlTaskManager.java
+++ b/core/trino-main/src/test/java/io/trino/execution/BaseTestSqlTaskManager.java
@@ -195,6 +195,7 @@ public void testAbort()
try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) {
TaskId taskId = newTaskId();
TaskInfo taskInfo = createTask(sqlTaskManager, taskId, PipelinedOutputBuffers.createInitial(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds());
+
assertEquals(taskInfo.getTaskStatus().getState(), TaskState.RUNNING);
assertNull(taskInfo.getStats().getEndTime());
diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManagerThreadPerDriver.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManagerThreadPerDriver.java
new file mode 100644
index 000000000000..67007faf1486
--- /dev/null
+++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManagerThreadPerDriver.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution;
+
+import com.google.common.base.Ticker;
+import io.airlift.tracing.Tracing;
+import io.trino.execution.executor.TaskExecutor;
+import io.trino.execution.executor2.ThreadPerDriverTaskExecutor;
+import io.trino.execution.executor2.scheduler.FairScheduler;
+
+import static io.trino.version.EmbedVersion.testingVersionEmbedder;
+
+/**
+ * Runs the tests inherited from {@link BaseTestSqlTaskManager} against a
+ * {@link ThreadPerDriverTaskExecutor}.
+ */
+public class TestSqlTaskManagerThreadPerDriver
+ extends BaseTestSqlTaskManager
+{
+ // Supplies the executor under test: a ThreadPerDriverTaskExecutor with a no-op tracer,
+ // the testing version embedder and a FairScheduler on the system ticker.
+ // NOTE(review): 8 is presumably the scheduler's thread count and "Runner-%d" its
+ // thread name format — confirm against the FairScheduler constructor
+ @Override
+ protected TaskExecutor createTaskExecutor()
+ {
+ return new ThreadPerDriverTaskExecutor(
+ Tracing.noopTracer(),
+ testingVersionEmbedder(),
+ new FairScheduler(8, "Runner-%d", Ticker.systemTicker()));
+ }
+}
diff --git a/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java
index 48e6603948d7..48bbf7c9d5cf 100644
--- a/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java
+++ b/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java
@@ -41,6 +41,7 @@ public class TestTaskManagerConfig
public void testDefaults()
{
assertRecordedDefaults(recordDefaults(TaskManagerConfig.class)
+ .setThreadPerDriverSchedulerEnabled(true)
.setInitialSplitsPerNode(Runtime.getRuntime().availableProcessors() * 2)
.setSplitConcurrencyAdjustmentInterval(new Duration(100, TimeUnit.MILLISECONDS))
.setStatusRefreshMaxWait(new Duration(1, TimeUnit.SECONDS))
@@ -87,6 +88,7 @@ public void testExplicitPropertyMappings()
int maxWriterCount = DEFAULT_SCALE_WRITERS_MAX_WRITER_COUNT == 32 ? 16 : 32;
int partitionedWriterCount = DEFAULT_PARTITIONED_WRITER_COUNT == 64 ? 32 : 64;
Map properties = ImmutableMap.builder()
+ .put("experimental.thread-per-split-scheduler-enabled", "false")
.put("task.initial-splits-per-node", "1")
.put("task.split-concurrency-adjustment-interval", "1s")
.put("task.status-refresh-max-wait", "2s")
@@ -127,6 +129,7 @@ public void testExplicitPropertyMappings()
.buildOrThrow();
TaskManagerConfig expected = new TaskManagerConfig()
+ .setThreadPerDriverSchedulerEnabled(false)
.setInitialSplitsPerNode(1)
.setSplitConcurrencyAdjustmentInterval(new Duration(1, TimeUnit.SECONDS))
.setStatusRefreshMaxWait(new Duration(2, TimeUnit.SECONDS))
diff --git a/core/trino-main/src/test/java/io/trino/execution/executor2/TestThreadPerDriverTaskExecutor.java b/core/trino-main/src/test/java/io/trino/execution/executor2/TestThreadPerDriverTaskExecutor.java
new file mode 100644
index 000000000000..47bca43cc0e3
--- /dev/null
+++ b/core/trino-main/src/test/java/io/trino/execution/executor2/TestThreadPerDriverTaskExecutor.java
@@ -0,0 +1,258 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.AbstractFuture;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import io.airlift.testing.TestingTicker;
+import io.airlift.units.Duration;
+import io.opentelemetry.api.trace.Span;
+import io.trino.execution.SplitRunner;
+import io.trino.execution.StageId;
+import io.trino.execution.TaskId;
+import io.trino.execution.TaskManagerConfig;
+import io.trino.execution.executor.TaskHandle;
+import io.trino.execution.executor2.scheduler.FairScheduler;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+
+import java.util.List;
+import java.util.OptionalInt;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Phaser;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+
+import static io.airlift.tracing.Tracing.noopTracer;
+import static io.trino.version.EmbedVersion.testingVersionEmbedder;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class TestThreadPerDriverTaskExecutor
+{
+    @Test
+    @Timeout(10)
+    public void testCancellationWhileProcessing()
+            throws ExecutionException, InterruptedException
+    {
+        ThreadPerDriverTaskExecutor executor = new ThreadPerDriverTaskExecutor(new TaskManagerConfig(), noopTracer(), testingVersionEmbedder());
+        executor.start();
+        try {
+            TaskId taskId = new TaskId(new StageId("query", 1), 1, 1);
+            TaskHandle task = executor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty());
+
+            CountDownLatch started = new CountDownLatch(1);
+
+            SplitRunner split = new TestingSplitRunner(ImmutableList.of(duration -> {
+                started.countDown();
+                try {
+                    Thread.currentThread().join(); // a thread joining itself only returns when interrupted
+                }
+                catch (InterruptedException e) {
+                    Thread.currentThread().interrupt(); // keep the interrupt status visible to the caller
+                }
+
+                return Futures.immediateVoidFuture();
+            }));
+
+            ListenableFuture<Void> splitDone = executor.enqueueSplits(task, false, ImmutableList.of(split)).get(0);
+
+            started.await(); // the split is now inside processFor()
+            executor.removeTask(task);
+
+            splitDone.get();
+            assertThat(split.isFinished()).isTrue();
+        }
+        finally {
+            executor.stop();
+        }
+    }
+
+    @Test
+    @Timeout(10)
+    public void testBlocking()
+            throws ExecutionException, InterruptedException
+    {
+        ThreadPerDriverTaskExecutor executor = new ThreadPerDriverTaskExecutor(new TaskManagerConfig(), noopTracer(), testingVersionEmbedder());
+        executor.start();
+
+        try {
+            TaskId taskId = new TaskId(new StageId("query", 1), 1, 1);
+            TaskHandle task = executor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty());
+
+            TestFuture blocked = new TestFuture();
+
+            SplitRunner split = new TestingSplitRunner(ImmutableList.of(
+                    duration -> blocked, // first invocation reports the split as blocked
+                    duration -> Futures.immediateVoidFuture()));
+
+            ListenableFuture<Void> splitDone = executor.enqueueSplits(task, false, ImmutableList.of(split)).get(0);
+
+            blocked.awaitListenerAdded(); // wait until a listener is registered on the blocked future
+            blocked.set(null); // unblock the split
+
+            splitDone.get();
+            assertThat(split.isFinished()).isTrue();
+        }
+        finally {
+            executor.stop();
+        }
+    }
+
+    @Test
+    @Timeout(10)
+    public void testYielding()
+            throws ExecutionException, InterruptedException
+    {
+        TestingTicker ticker = new TestingTicker();
+        FairScheduler scheduler = new FairScheduler(1, "Runner-%d", ticker);
+        ThreadPerDriverTaskExecutor executor = new ThreadPerDriverTaskExecutor(noopTracer(), testingVersionEmbedder(), scheduler);
+        executor.start();
+
+        try {
+            TaskId taskId = new TaskId(new StageId("query", 1), 1, 1);
+            TaskHandle task = executor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty());
+
+            Phaser phaser = new Phaser(2); // two parties: the test thread and the split
+            SplitRunner split = new TestingSplitRunner(ImmutableList.of(
+                    duration -> {
+                        phaser.arriveAndAwaitAdvance(); // wait to start
+                        phaser.arriveAndAwaitAdvance(); // wait to advance time
+                        return Futures.immediateVoidFuture();
+                    },
+                    duration -> {
+                        phaser.arriveAndAwaitAdvance(); // second invocation: signals that the split ran again
+                        return Futures.immediateVoidFuture();
+                    }));
+
+            ListenableFuture<Void> splitDone = executor.enqueueSplits(task, false, ImmutableList.of(split)).get(0);
+
+            phaser.arriveAndAwaitAdvance(); // wait for split to start
+
+            // cause the task to yield
+            ticker.increment(FairScheduler.QUANTUM_NANOS * 2, TimeUnit.NANOSECONDS);
+            phaser.arriveAndAwaitAdvance();
+
+            // wait for reschedule: the second invocation arrives and the phaser completes its third advance
+            assertThat(phaser.arriveAndAwaitAdvance()).isEqualTo(3);
+
+            splitDone.get();
+            assertThat(split.isFinished()).isTrue();
+        }
+        finally {
+            executor.stop();
+        }
+    }
+
+    private static class TestFuture
+            extends AbstractFuture<Void>
+    {
+        private final CountDownLatch listenerAdded = new CountDownLatch(1); // released on the first addListener() call
+
+        @Override
+        public void addListener(Runnable listener, Executor executor)
+        {
+            super.addListener(listener, executor);
+            listenerAdded.countDown();
+        }
+
+        @Override
+        public boolean set(Void value) // widened to public so tests can complete the future
+        {
+            return super.set(value);
+        }
+
+        public void awaitListenerAdded()
+                throws InterruptedException
+        {
+            listenerAdded.await();
+        }
+    }
+
+    private static class TestingSplitRunner
+            implements SplitRunner
+    {
+        private final List<Function<Duration, ListenableFuture<Void>>> invocations; // one entry per processFor() call
+        private int invocation; // index of the next entry to run
+        private volatile boolean finished;
+        private volatile Thread runnerThread; // non-null only while an invocation is executing
+
+        public TestingSplitRunner(List<Function<Duration, ListenableFuture<Void>>> invocations)
+        {
+            this.invocations = invocations;
+        }
+
+        @Override
+        public final int getPipelineId()
+        {
+            return 0;
+        }
+
+        @Override
+        public final Span getPipelineSpan()
+        {
+            return Span.getInvalid();
+        }
+
+        @Override
+        public final boolean isFinished()
+        {
+            return finished;
+        }
+
+        @Override
+        public final ListenableFuture<Void> processFor(Duration duration)
+        {
+            ListenableFuture<Void> blocked;
+
+            runnerThread = Thread.currentThread(); // published so close() can interrupt a running invocation
+            try {
+                blocked = invocations.get(invocation).apply(duration);
+            }
+            finally {
+                runnerThread = null;
+            }
+
+            invocation++;
+
+            if (invocation == invocations.size()) {
+                finished = true; // every scripted invocation has run
+            }
+
+            return blocked;
+        }
+
+        @Override
+        public final String getInfo()
+        {
+            return "";
+        }
+
+        @Override
+        public final void close()
+        {
+            finished = true;
+
+            Thread runnerThread = this.runnerThread; // read the volatile field once
+
+            if (runnerThread != null) {
+                runnerThread.interrupt();
+            }
+        }
+    }
+}
diff --git a/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestFairScheduler.java b/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestFairScheduler.java
new file mode 100644
index 000000000000..446be1d4e49e
--- /dev/null
+++ b/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestFairScheduler.java
@@ -0,0 +1,227 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2.scheduler;
+
+import com.google.common.util.concurrent.AbstractFuture;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.SettableFuture;
+import io.airlift.testing.TestingTicker;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class TestFairScheduler
+{
+    @Test
+    public void testBasic()
+            throws ExecutionException, InterruptedException
+    {
+        try (FairScheduler scheduler = FairScheduler.newInstance(1)) {
+            Group group = scheduler.createGroup("G1");
+
+            AtomicBoolean ran = new AtomicBoolean();
+            ListenableFuture<Void> done = scheduler.submit(group, 1, context -> ran.set(true));
+
+            done.get();
+            assertThat(ran.get())
+                    .describedAs("Ran task")
+                    .isTrue();
+        }
+    }
+
+    @Test
+    @Timeout(5)
+    public void testYield()
+            throws ExecutionException, InterruptedException
+    {
+        TestingTicker ticker = new TestingTicker();
+        try (FairScheduler scheduler = FairScheduler.newInstance(1, ticker)) {
+            Group group = scheduler.createGroup("G");
+
+            CountDownLatch task1Started = new CountDownLatch(1);
+            AtomicBoolean task2Ran = new AtomicBoolean();
+
+            ListenableFuture<Void> task1 = scheduler.submit(group, 1, context -> {
+                task1Started.countDown();
+                while (!task2Ran.get()) { // keep offering to yield until task 2 has run
+                    if (!context.maybeYield()) {
+                        return;
+                    }
+                }
+            });
+
+            task1Started.await();
+
+            ListenableFuture<Void> task2 = scheduler.submit(group, 2, context -> {
+                task2Ran.set(true);
+            });
+
+            while (!task2.isDone()) { // keep advancing the test clock until task 2 has completed
+                ticker.increment(FairScheduler.QUANTUM_NANOS * 2, TimeUnit.NANOSECONDS);
+            }
+
+            task1.get();
+        }
+    }
+
+    @Test
+    public void testBlocking()
+            throws InterruptedException, ExecutionException
+    {
+        try (FairScheduler scheduler = FairScheduler.newInstance(1)) {
+            Group group = scheduler.createGroup("G");
+
+            CountDownLatch task1Started = new CountDownLatch(1);
+            CountDownLatch task2Submitted = new CountDownLatch(1);
+            CountDownLatch task2Started = new CountDownLatch(1);
+            AtomicBoolean task2Ran = new AtomicBoolean();
+
+            SettableFuture<Void> task1Blocked = SettableFuture.create();
+
+            ListenableFuture<Void> task1 = scheduler.submit(group, 1, context -> {
+                try {
+                    task1Started.countDown();
+                    task2Submitted.await();
+
+                    assertThat(task2Ran.get())
+                            .describedAs("Task 2 run")
+                            .isFalse();
+
+                    context.block(task1Blocked); // the test completes this future only after task 2 has started
+
+                    assertThat(task2Ran.get())
+                            .describedAs("Task 2 run")
+                            .isTrue();
+                }
+                catch (InterruptedException e) {
+                    Thread.currentThread().interrupt();
+                    throw new RuntimeException(e);
+                }
+            });
+
+            task1Started.await();
+
+            ListenableFuture<Void> task2 = scheduler.submit(group, 2, context -> {
+                task2Started.countDown();
+                task2Ran.set(true);
+            });
+
+            task2Submitted.countDown();
+            task2Started.await();
+
+            // unblock task 1
+            task1Blocked.set(null);
+
+            task1.get();
+            task2.get();
+        }
+    }
+
+    @Test
+    public void testCancelWhileYielding()
+            throws InterruptedException, ExecutionException
+    {
+        TestingTicker ticker = new TestingTicker();
+        try (FairScheduler scheduler = FairScheduler.newInstance(1, ticker)) {
+            Group group = scheduler.createGroup("G");
+
+            CountDownLatch task1Started = new CountDownLatch(1);
+            CountDownLatch task1TimeAdvanced = new CountDownLatch(1);
+
+            ListenableFuture<Void> task1 = scheduler.submit(group, 1, context -> {
+                try {
+                    task1Started.countDown();
+                    task1TimeAdvanced.await();
+
+                    assertThat(context.maybeYield())
+                            .describedAs("Cancelled while yielding")
+                            .isFalse();
+                }
+                catch (InterruptedException e) {
+                    Thread.currentThread().interrupt();
+                    throw new RuntimeException(e);
+                }
+            });
+
+            task1Started.await();
+            scheduler.pause(); // prevent rescheduling after yield
+
+            ticker.increment(FairScheduler.QUANTUM_NANOS * 2, TimeUnit.NANOSECONDS);
+            task1TimeAdvanced.countDown();
+
+            scheduler.removeGroup(group);
+            task1.get();
+        }
+    }
+
+    @Test
+    public void testCancelWhileBlocking()
+            throws InterruptedException, ExecutionException
+    {
+        TestingTicker ticker = new TestingTicker();
+        try (FairScheduler scheduler = FairScheduler.newInstance(1, ticker)) {
+            Group group = scheduler.createGroup("G");
+
+            CountDownLatch task1Started = new CountDownLatch(1);
+            TestFuture task1Blocked = new TestFuture(); // never completed by this test
+
+            ListenableFuture<Void> task1 = scheduler.submit(group, 1, context -> {
+                task1Started.countDown();
+
+                assertThat(context.block(task1Blocked))
+                        .describedAs("Cancelled while blocking")
+                        .isFalse();
+            });
+
+            task1Started.await();
+
+            task1Blocked.awaitListenerAdded(); // When the listener is added, we know the task is blocked
+
+            scheduler.removeGroup(group);
+            task1.get();
+        }
+    }
+
+    private static class TestFuture
+            extends AbstractFuture<Void>
+    {
+        private final CountDownLatch listenerAdded = new CountDownLatch(1); // released on the first addListener() call
+
+        @Override
+        public void addListener(Runnable listener, Executor executor)
+        {
+            super.addListener(listener, executor);
+            listenerAdded.countDown();
+        }
+
+        @Override
+        public boolean set(Void value) // widened to public so tests can complete the future
+        {
+            return super.set(value);
+        }
+
+        public void awaitListenerAdded()
+                throws InterruptedException
+        {
+            listenerAdded.await();
+        }
+    }
+}
diff --git a/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestPriorityQueue.java b/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestPriorityQueue.java
new file mode 100644
index 000000000000..ab21233f86a0
--- /dev/null
+++ b/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestPriorityQueue.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2.scheduler;
+
+import com.google.common.collect.ImmutableSet;
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+public class TestPriorityQueue
+{
+    @Test
+    public void testEmpty()
+    {
+        PriorityQueue<String> queue = new PriorityQueue<>();
+
+        assertThat(queue.poll()).isNull();
+        assertThat(queue.isEmpty()).isTrue();
+    }
+
+    @Test
+    public void testNotEmpty()
+    {
+        PriorityQueue<String> queue = new PriorityQueue<>();
+
+        queue.add("hello", 1);
+        assertThat(queue.isEmpty()).isFalse();
+    }
+
+    @Test
+    public void testDuplicate()
+    {
+        PriorityQueue<String> queue = new PriorityQueue<>();
+
+        queue.add("hello", 1);
+        assertThatThrownBy(() -> queue.add("hello", 2)) // same value again, even with a different priority
+                .isInstanceOf(IllegalArgumentException.class);
+    }
+
+    @Test
+    public void testOrder()
+    {
+        PriorityQueue<String> queue = new PriorityQueue<>();
+
+        queue.add("jumps", 5);
+        queue.add("fox", 4);
+        queue.add("over", 6);
+        queue.add("brown", 3);
+        queue.add("dog", 8);
+        queue.add("the", 1);
+        queue.add("lazy", 7);
+        queue.add("quick", 2);
+
+        assertThat(queue.poll()).isEqualTo("the"); // expected in ascending priority order
+        assertThat(queue.poll()).isEqualTo("quick");
+        assertThat(queue.poll()).isEqualTo("brown");
+        assertThat(queue.poll()).isEqualTo("fox");
+        assertThat(queue.poll()).isEqualTo("jumps");
+        assertThat(queue.poll()).isEqualTo("over");
+        assertThat(queue.poll()).isEqualTo("lazy");
+        assertThat(queue.poll()).isEqualTo("dog");
+        assertThat(queue.poll()).isNull();
+    }
+
+    @Test
+    public void testInterleaved()
+    {
+        PriorityQueue<String> queue = new PriorityQueue<>();
+
+        queue.add("jumps", 5);
+        queue.add("over", 6);
+        queue.add("fox", 4);
+
+        assertThat(queue.poll()).isEqualTo("fox");
+        assertThat(queue.poll()).isEqualTo("jumps");
+
+        queue.add("brown", 3);
+        queue.add("dog", 8);
+        queue.add("the", 1);
+
+        assertThat(queue.poll()).isEqualTo("the");
+        assertThat(queue.poll()).isEqualTo("brown");
+        assertThat(queue.poll()).isEqualTo("over"); // left over from the first batch
+
+        queue.add("lazy", 7);
+        queue.add("quick", 2);
+
+        assertThat(queue.poll()).isEqualTo("quick");
+        assertThat(queue.poll()).isEqualTo("lazy");
+        assertThat(queue.poll()).isEqualTo("dog");
+        assertThat(queue.poll()).isNull();
+    }
+
+    @Test
+    public void testRemove()
+    {
+        PriorityQueue<String> queue = new PriorityQueue<>();
+
+        queue.add("fox", 4);
+        queue.add("brown", 3);
+        queue.add("the", 1);
+        queue.add("quick", 2);
+
+        queue.remove("brown");
+
+        assertThat(queue.poll()).isEqualTo("the");
+        assertThat(queue.poll()).isEqualTo("quick");
+        assertThat(queue.poll()).isEqualTo("fox");
+        assertThat(queue.poll()).isNull();
+    }
+
+    @Test
+    public void testRemoveMissing()
+    {
+        PriorityQueue<String> queue = new PriorityQueue<>();
+
+        queue.add("the", 1);
+        queue.add("quick", 2);
+        queue.add("brown", 3);
+
+        assertThatThrownBy(() -> queue.remove("fox")) // "fox" was never added
+                .isInstanceOf(IllegalArgumentException.class);
+    }
+
+    @Test
+    public void testContains()
+    {
+        PriorityQueue<String> queue = new PriorityQueue<>();
+
+        queue.add("the", 1);
+        queue.add("quick", 2);
+        queue.add("brown", 3);
+
+        assertThat(queue.contains("quick")).isTrue();
+        assertThat(queue.contains("fox")).isFalse();
+    }
+
+    @Test
+    public void testRecycle()
+    {
+        PriorityQueue<String> queue = new PriorityQueue<>();
+
+        queue.add("hello", 1);
+        assertThat(queue.poll()).isEqualTo("hello");
+
+        queue.add("hello", 2); // a polled value can be added again
+        assertThat(queue.poll()).isEqualTo("hello");
+    }
+
+    @Test
+    public void testValues()
+    {
+        PriorityQueue<String> queue = new PriorityQueue<>();
+
+        assertThat(queue.values()).isEmpty();
+
+        queue.add("hello", 1);
+        queue.add("world", 2);
+
+        assertThat(queue.values())
+                .isEqualTo(ImmutableSet.of("hello", "world"));
+    }
+
+    @Test
+    public void testNextPriority()
+    {
+        PriorityQueue<String> queue = new PriorityQueue<>();
+
+        assertThatThrownBy(queue::nextPriority) // empty queue
+                .isInstanceOf(IllegalStateException.class);
+
+        queue.add("hello", 10);
+        queue.add("world", 20);
+
+        assertThat(queue.nextPriority()).isEqualTo(10);
+
+        queue.poll();
+        assertThat(queue.nextPriority()).isEqualTo(20);
+
+        queue.poll();
+        assertThatThrownBy(queue::nextPriority) // empty again after polling both values
+                .isInstanceOf(IllegalStateException.class);
+    }
+}
diff --git a/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestSchedulingQueue.java b/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestSchedulingQueue.java
new file mode 100644
index 000000000000..4fdfc7570934
--- /dev/null
+++ b/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestSchedulingQueue.java
@@ -0,0 +1,286 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.executor2.scheduler;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class TestSchedulingQueue
+{
+    @Test
+    public void testEmpty()
+    {
+        SchedulingQueue<String, String> queue = new SchedulingQueue<>();
+
+        assertThat(queue.dequeue(1)).isNull();
+    }
+
+    @Test
+    public void testSingleGroup()
+    {
+        SchedulingQueue<String, String> queue = new SchedulingQueue<>();
+
+        queue.startGroup("G1");
+
+        queue.enqueue("G1", "T1", 1);
+        queue.enqueue("G1", "T2", 3);
+        queue.enqueue("G1", "T3", 5);
+        queue.enqueue("G1", "T4", 7);
+
+        assertThat(queue.dequeue(1)).isEqualTo("T1");
+        assertThat(queue.dequeue(1)).isEqualTo("T2");
+        assertThat(queue.dequeue(1)).isEqualTo("T3");
+        assertThat(queue.dequeue(1)).isEqualTo("T4");
+
+        queue.enqueue("G1", "T1", 10); // second round: every task re-enqueued with the same value
+        queue.enqueue("G1", "T2", 10);
+        queue.enqueue("G1", "T3", 10);
+        queue.enqueue("G1", "T4", 10);
+
+        assertThat(queue.dequeue(1)).isEqualTo("T1");
+        assertThat(queue.dequeue(1)).isEqualTo("T2");
+        assertThat(queue.dequeue(1)).isEqualTo("T3");
+        assertThat(queue.dequeue(1)).isEqualTo("T4");
+
+        queue.enqueue("G1", "T1", 16); // third round: values descend from T1 to T4
+        queue.enqueue("G1", "T2", 12);
+        queue.enqueue("G1", "T3", 8);
+        queue.enqueue("G1", "T4", 4);
+
+        assertThat(queue.dequeue(1)).isEqualTo("T4");
+        assertThat(queue.dequeue(1)).isEqualTo("T3");
+        assertThat(queue.dequeue(1)).isEqualTo("T2");
+        assertThat(queue.dequeue(1)).isEqualTo("T1");
+
+        queue.finish("G1", "T1");
+        queue.finish("G1", "T2");
+        queue.finish("G1", "T3");
+        queue.finish("G1", "T4");
+
+        assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.BLOCKED);
+    }
+
+    @Test
+    public void testBasic()
+    {
+        SchedulingQueue<String, String> queue = new SchedulingQueue<>();
+
+        queue.startGroup("G1");
+        queue.startGroup("G2");
+
+        queue.enqueue("G1", "T1.0", 1);
+        queue.enqueue("G1", "T1.1", 2);
+        queue.enqueue("G2", "T2.0", 3);
+        queue.enqueue("G2", "T2.1", 4);
+
+        assertThat(queue.dequeue(1)).isEqualTo("T1.0");
+        assertThat(queue.dequeue(1)).isEqualTo("T1.1");
+        assertThat(queue.dequeue(1)).isEqualTo("T2.0");
+        assertThat(queue.dequeue(1)).isEqualTo("T2.1");
+
+        queue.enqueue("G1", "T1.0", 10);
+        queue.enqueue("G1", "T1.1", 20);
+        queue.enqueue("G2", "T2.0", 15);
+        queue.enqueue("G2", "T2.1", 5);
+
+        assertThat(queue.dequeue(1)).isEqualTo("T2.1");
+        assertThat(queue.dequeue(1)).isEqualTo("T2.0");
+        assertThat(queue.dequeue(1)).isEqualTo("T1.0");
+        assertThat(queue.dequeue(1)).isEqualTo("T1.1");
+
+        queue.enqueue("G1", "T1.0", 100);
+        queue.enqueue("G2", "T2.0", 90);
+        assertThat(queue.dequeue(1)).isEqualTo("T2.0");
+        assertThat(queue.dequeue(1)).isEqualTo("T1.0");
+    }
+
+    @Test
+    public void testSomeEmptyGroups()
+    {
+        SchedulingQueue<String, String> queue = new SchedulingQueue<>();
+
+        queue.startGroup("G1"); // G1 never receives a task
+        queue.startGroup("G2");
+
+        queue.enqueue("G2", "T1", 0);
+
+        assertThat(queue.dequeue(1)).isEqualTo("T1");
+    }
+
+    @Test
+    public void testDelayedCreation()
+    {
+        SchedulingQueue<String, String> queue = new SchedulingQueue<>();
+
+        queue.startGroup("G1");
+        queue.startGroup("G2");
+
+        queue.enqueue("G1", "T1.0", 100);
+        queue.enqueue("G2", "T2.0", 200);
+
+        queue.startGroup("G3"); // new group gets a priority baseline equal to the minimum current priority
+        queue.enqueue("G3", "T3.0", 50);
+
+        assertThat(queue.dequeue(1)).isEqualTo("T1.0");
+        assertThat(queue.dequeue(1)).isEqualTo("T3.0");
+        assertThat(queue.dequeue(1)).isEqualTo("T2.0");
+    }
+
+    @Test
+    public void testDelayedCreationWhileAllRunning()
+    {
+        SchedulingQueue<String, String> queue = new SchedulingQueue<>();
+
+        queue.startGroup("G1");
+        queue.startGroup("G2");
+
+        queue.enqueue("G1", "T1.0", 0);
+
+        queue.enqueue("G2", "T2.0", 100);
+        queue.dequeue(50); // dequeue both tasks so nothing is left queued when G3 is created
+        queue.dequeue(50);
+
+        queue.startGroup("G3"); // new group gets a priority baseline equal to the minimum current priority
+        queue.enqueue("G3", "T3.0", 10);
+
+        queue.enqueue("G1", "T1.0", 50);
+        queue.enqueue("G2", "T2.0", 50);
+
+        assertThat(queue.dequeue(1)).isEqualTo("T1.0");
+        assertThat(queue.dequeue(1)).isEqualTo("T3.0");
+        assertThat(queue.dequeue(1)).isEqualTo("T2.0");
+    }
+
+    @Test
+    public void testGroupState()
+    {
+        SchedulingQueue<String, String> queue = new SchedulingQueue<>();
+
+        // initial state with no tasks
+        queue.startGroup("G1");
+        assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.BLOCKED);
+
+        // after adding a task, it should be runnable
+        queue.enqueue("G1", "T1", 0);
+        assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.RUNNABLE);
+        queue.enqueue("G1", "T2", 0);
+        assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.RUNNABLE);
+
+        // after dequeueing, still runnable if there's at least one runnable task
+        queue.dequeue(1);
+        assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.RUNNABLE);
+
+        // after all tasks are dequeued, it should be running
+        queue.dequeue(1);
+        assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.RUNNING);
+
+        // still running while at least one task is running and there are no runnable tasks
+        queue.block("G1", "T1", 1);
+        assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.RUNNING);
+
+        // blocked when all tasks are blocked
+        queue.block("G1", "T2", 1);
+        assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.BLOCKED);
+
+        // back to runnable after unblocking
+        queue.enqueue("G1", "T1", 1);
+        assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.RUNNABLE);
+    }
+
+    @Test
+    public void testNonGreedyDeque()
+    {
+        SchedulingQueue<String, String> queue = new SchedulingQueue<>();
+
+        queue.startGroup("G1");
+        queue.startGroup("G2");
+
+        queue.enqueue("G1", "T1.0", 0);
+        queue.enqueue("G2", "T2.0", 1);
+
+        queue.enqueue("G1", "T1.1", 2);
+        queue.enqueue("G1", "T1.2", 3);
+
+        queue.enqueue("G2", "T2.1", 2);
+        queue.enqueue("G2", "T2.2", 3);
+
+        assertThat(queue.dequeue(2)).isEqualTo("T1.0"); // expected to alternate between G1 and G2
+        assertThat(queue.dequeue(2)).isEqualTo("T2.0");
+        assertThat(queue.dequeue(2)).isEqualTo("T1.1");
+        assertThat(queue.dequeue(2)).isEqualTo("T2.1");
+        assertThat(queue.dequeue(2)).isEqualTo("T1.2");
+        assertThat(queue.dequeue(2)).isEqualTo("T2.2");
+        assertThat(queue.dequeue(2)).isNull();
+    }
+
+    @Test
+    public void testFinishTask()
+    {
+        SchedulingQueue<String, String> queue = new SchedulingQueue<>();
+
+        queue.startGroup("G1");
+        queue.enqueue("G1", "T1", 0);
+        queue.enqueue("G1", "T2", 1);
+        queue.enqueue("G1", "T3", 2);
+
+        assertThat(queue.peek()).isEqualTo("T1");
+        queue.finish("G1", "T1");
+        assertThat(queue.peek()).isEqualTo("T2");
+
+        // check that the group becomes not-runnable
+        queue.finish("G1", "T2");
+        queue.finish("G1", "T3");
+        assertThat(queue.peek()).isNull();
+
+        // check that the group becomes runnable again
+        queue.enqueue("G1", "T4", 0);
+        assertThat(queue.peek()).isEqualTo("T4");
+    }
+
+    @Test
+    public void testFinishGroup()
+    {
+        SchedulingQueue<String, String> queue = new SchedulingQueue<>();
+
+        queue.startGroup("G1");
+        queue.enqueue("G1", "T1.1", 0);
+        assertThat(queue.peek()).isEqualTo("T1.1");
+
+        queue.startGroup("G2");
+        queue.enqueue("G2", "T2.1", 1);
+        assertThat(queue.peek()).isEqualTo("T1.1");
+
+        queue.finishGroup("G1"); // finishing G1 leaves only G2's task at the head
+        assertThat(queue.peek()).isEqualTo("T2.1");
+    }
+
+    record Group(int id) // NOTE(review): not referenced by any test in this class; confirm it is still needed
+    {
+        @Override
+        public String toString()
+        {
+            return "G" + id;
+        }
+    }
+
+    record Task(Group group, int id) // NOTE(review): not referenced by any test in this class; confirm it is still needed
+    {
+        @Override
+        public String toString()
+        {
+            return "T" + group + "." + id;
+        }
+    }
+}
diff --git a/pom.xml b/pom.xml
index 4bbf755b3bed..961fbd5d31ec 100644
--- a/pom.xml
+++ b/pom.xml
@@ -2411,7 +2411,8 @@
-Xep:StreamResourceLeak:ERROR \
-Xep:UnnecessaryMethodReference:ERROR \
-Xep:UnnecessaryOptionalGet:ERROR \
- -Xep:UnusedVariable:ERROR \
+
+
-Xep:UseEnumSwitch:ERROR \
-XepExcludedPaths:.*/target/generated-(|test-)sources/.*
diff --git a/testing/trino-server-dev/etc/config.properties b/testing/trino-server-dev/etc/config.properties
index b786657e8d2b..6692959ccb50 100644
--- a/testing/trino-server-dev/etc/config.properties
+++ b/testing/trino-server-dev/etc/config.properties
@@ -60,3 +60,6 @@ plugin.bundles=\
../../plugin/trino-mysql-event-listener/pom.xml
node-scheduler.include-coordinator=true
+
+tracing.enabled=true
+tracing.exporter.endpoint=http://localhost:4317