diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java b/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java index 280a1c9cac4b..737fea1b8c2a 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java +++ b/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java @@ -45,6 +45,7 @@ "task.level-absolute-priority"}) public class TaskManagerConfig { + private boolean threadPerDriverSchedulerEnabled = true; private boolean perOperatorCpuTimerEnabled = true; private boolean taskCpuTimerEnabled = true; private boolean statisticsCpuTimerEnabled = true; @@ -107,6 +108,18 @@ public class TaskManagerConfig private BigDecimal levelTimeMultiplier = new BigDecimal(2.0); + @Config("experimental.thread-per-split-scheduler-enabled") + public TaskManagerConfig setThreadPerDriverSchedulerEnabled(boolean enabled) + { + this.threadPerDriverSchedulerEnabled = enabled; + return this; + } + + public boolean isThreadPerDriverSchedulerEnabled() + { + return threadPerDriverSchedulerEnabled; + } + @MinDuration("1ms") @MaxDuration("10s") @NotNull diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/SplitProcessor.java b/core/trino-main/src/main/java/io/trino/execution/executor2/SplitProcessor.java new file mode 100644 index 000000000000..a81af6a617b8 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/executor2/SplitProcessor.java @@ -0,0 +1,125 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2; + +import com.google.common.base.Ticker; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.concurrent.SetThreadName; +import io.airlift.stats.CpuTimer; +import io.airlift.units.Duration; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanBuilder; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.trino.execution.SplitRunner; +import io.trino.execution.TaskId; +import io.trino.execution.executor2.scheduler.Schedulable; +import io.trino.execution.executor2.scheduler.SchedulerContext; +import io.trino.tracing.TrinoAttributes; + +import java.util.concurrent.TimeUnit; + +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.NANOSECONDS; + +class SplitProcessor + implements Schedulable +{ + private static final Duration SPLIT_RUN_QUANTA = new Duration(1, TimeUnit.SECONDS); + + private final TaskId taskId; + private final int splitId; + private final SplitRunner split; + private final Tracer tracer; + + public SplitProcessor(TaskId taskId, int splitId, SplitRunner split, Tracer tracer) + { + this.taskId = requireNonNull(taskId, "taskId is null"); + this.splitId = splitId; + this.split = requireNonNull(split, "split is null"); + this.tracer = requireNonNull(tracer, "tracer is null"); + } + + @Override + public void run(SchedulerContext context) + { + Span splitSpan = tracer.spanBuilder("split") + .setParent(Context.current().with(split.getPipelineSpan())) + .setAttribute(TrinoAttributes.QUERY_ID, taskId.getQueryId().toString()) + .setAttribute(TrinoAttributes.STAGE_ID, taskId.getStageId().toString()) + .setAttribute(TrinoAttributes.TASK_ID, taskId.toString()) + .setAttribute(TrinoAttributes.PIPELINE_ID, taskId.getStageId() + "-" + split.getPipelineId()) + .setAttribute(TrinoAttributes.SPLIT_ID, taskId + "-" + splitId) + .startSpan(); + + Span processSpan = newSpan(splitSpan, null); + + CpuTimer timer = new CpuTimer(Ticker.systemTicker(), false); + long previousCpuNanos = 0; + long previousScheduledNanos = 0; + try (SetThreadName ignored = new SetThreadName("SplitRunner-%s-%s", taskId, splitId)) { + while (!split.isFinished()) { + ListenableFuture blocked = split.processFor(SPLIT_RUN_QUANTA); + CpuTimer.CpuDuration elapsed = timer.elapsedTime(); + + long scheduledNanos = elapsed.getWall().roundTo(NANOSECONDS); + processSpan.setAttribute(TrinoAttributes.SPLIT_SCHEDULED_TIME_NANOS, scheduledNanos - previousScheduledNanos); + previousScheduledNanos = scheduledNanos; + + long cpuNanos = elapsed.getCpu().roundTo(NANOSECONDS); + processSpan.setAttribute(TrinoAttributes.SPLIT_CPU_TIME_NANOS, cpuNanos - previousCpuNanos); + previousCpuNanos = cpuNanos; + + if (!split.isFinished()) { + if (blocked.isDone()) { + processSpan.addEvent("yield"); + processSpan.end(); + if (!context.maybeYield()) { + return; + } + } + else { + processSpan.addEvent("blocked"); + processSpan.end(); + if (!context.block(blocked)) { + return; + } + } + processSpan = newSpan(splitSpan, processSpan); + } + } + } + finally { + processSpan.end(); + + splitSpan.setAttribute(TrinoAttributes.SPLIT_CPU_TIME_NANOS, timer.elapsedTime().getCpu().roundTo(NANOSECONDS)); + splitSpan.setAttribute(TrinoAttributes.SPLIT_SCHEDULED_TIME_NANOS, context.getScheduledNanos()); + splitSpan.setAttribute(TrinoAttributes.SPLIT_BLOCK_TIME_NANOS, context.getBlockedNanos()); + splitSpan.setAttribute(TrinoAttributes.SPLIT_WAIT_TIME_NANOS, context.getWaitNanos()); + splitSpan.end(); + } + } + + private Span newSpan(Span parent, Span previous) + { + SpanBuilder builder = tracer.spanBuilder("process") + .setParent(Context.current().with(parent)); + + if (previous != null) { + builder.addLink(previous.getSpanContext()); + } + + return builder.startSpan(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/ThreadPerDriverTaskExecutor.java b/core/trino-main/src/main/java/io/trino/execution/executor2/ThreadPerDriverTaskExecutor.java new file mode 100644 index 000000000000..2d65b0171506 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/executor2/ThreadPerDriverTaskExecutor.java @@ -0,0 +1,212 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Ticker; +import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import com.google.inject.Inject; +import io.airlift.units.Duration; +import io.opentelemetry.api.trace.Tracer; +import io.trino.execution.SplitRunner; +import io.trino.execution.TaskId; +import io.trino.execution.TaskManagerConfig; +import io.trino.execution.executor.RunningSplitInfo; +import io.trino.execution.executor.TaskExecutor; +import io.trino.execution.executor.TaskHandle; +import io.trino.execution.executor2.scheduler.FairScheduler; +import io.trino.execution.executor2.scheduler.Group; +import io.trino.execution.executor2.scheduler.Schedulable; +import io.trino.execution.executor2.scheduler.SchedulerContext; +import io.trino.spi.VersionEmbedder; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.OptionalInt; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.DoubleSupplier; +import java.util.function.Predicate; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static java.util.Objects.requireNonNull; + +@ThreadSafe +public class ThreadPerDriverTaskExecutor + implements TaskExecutor +{ + private final FairScheduler scheduler; + private final Tracer tracer; + private final VersionEmbedder versionEmbedder; + private volatile boolean closed; + + @Inject + public ThreadPerDriverTaskExecutor(TaskManagerConfig config, Tracer tracer, VersionEmbedder versionEmbedder) + { + this(tracer, versionEmbedder, new FairScheduler(config.getMaxWorkerThreads(), "SplitRunner-%d", Ticker.systemTicker())); + } + + @VisibleForTesting + public ThreadPerDriverTaskExecutor(Tracer tracer, VersionEmbedder versionEmbedder, FairScheduler scheduler) + { + this.scheduler = scheduler; + this.tracer = requireNonNull(tracer, "tracer is null"); + this.versionEmbedder = requireNonNull(versionEmbedder, "versionEmbedder is null"); + } + + @PostConstruct + @Override + public synchronized void start() + { + scheduler.start(); + } + + @PreDestroy + @Override + public synchronized void stop() + { + closed = true; + scheduler.close(); + } + + @Override + public synchronized TaskHandle addTask( + TaskId taskId, + DoubleSupplier utilizationSupplier, + int initialSplitConcurrency, + Duration splitConcurrencyAdjustFrequency, + OptionalInt maxDriversPerTask) + { + checkArgument(!closed, "Executor is already closed"); + + Group group = scheduler.createGroup(taskId.toString()); + return new TaskEntry(taskId, group); + } + + @Override + public synchronized void removeTask(TaskHandle handle) + { + TaskEntry entry = (TaskEntry) handle; + + if (!entry.isDestroyed()) { + scheduler.removeGroup(entry.group()); + entry.destroy(); + } + } + + @Override + public synchronized List> enqueueSplits(TaskHandle handle, boolean intermediate, List splits) + { + checkArgument(!closed, "Executor is already closed"); + + TaskEntry entry = (TaskEntry) handle; + + List> futures = new ArrayList<>(); + for (SplitRunner split : splits) { + entry.addSplit(split); + + int splitId = entry.nextSplitId(); + ListenableFuture done = scheduler.submit(entry.group(), splitId, new VersionEmbedderBridge(versionEmbedder, new SplitProcessor(entry.taskId(), splitId, split, tracer))); + done.addListener(split::close, directExecutor()); + futures.add(done); + } + + return futures; + } + + @Override + public String getMaxActiveSplitsInfo() + { + return ""; // TODO + } + + @Override + public Set getStuckSplitTaskIds(Duration processingDurationThreshold, Predicate filter) + { + // TODO + return ImmutableSet.of(); + } + + private static class TaskEntry + implements TaskHandle + { + private final TaskId taskId; + private final Group group; + private final AtomicInteger nextSplitId = new AtomicInteger(); + private volatile boolean destroyed; + + @GuardedBy("this") + private Set splits = new HashSet<>(); + + public TaskEntry(TaskId taskId, Group group) + { + this.taskId = taskId; + this.group = group; + } + + public TaskId taskId() + { + return taskId; + } + + public Group group() + { + return group; + } + + public synchronized void destroy() + { + destroyed = true; + + for (SplitRunner split : splits) { + split.close(); + } + } + + public synchronized void addSplit(SplitRunner split) + { + checkArgument(!destroyed, "Task already destroyed: %s", taskId); + splits.add(split); + } + + public int nextSplitId() + { + return nextSplitId.incrementAndGet(); + } + + @Override + public boolean isDestroyed() + { + return destroyed; + } + } + + private record VersionEmbedderBridge(VersionEmbedder versionEmbedder, Schedulable delegate) + implements Schedulable + { + @Override + public void run(SchedulerContext context) + { + Runnable adapter = () -> delegate.run(context); + versionEmbedder.embedVersion(adapter).run(); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/BlockingSchedulingQueue.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/BlockingSchedulingQueue.java new file mode 100644 index 000000000000..b32b8e42e47b --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/BlockingSchedulingQueue.java @@ -0,0 +1,132 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2.scheduler; + +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; + +import java.util.Set; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +@ThreadSafe +final class BlockingSchedulingQueue +{ + private final Lock lock = new ReentrantLock(); + private final Condition notEmpty = lock.newCondition(); + + @GuardedBy("lock") + private final SchedulingQueue queue = new SchedulingQueue<>(); + + public void startGroup(G group) + { + lock.lock(); + try { + queue.startGroup(group); + } + finally { + lock.unlock(); + } + } + + public Set finishGroup(G group) + { + lock.lock(); + try { + return queue.finishGroup(group); + } + finally { + lock.unlock(); + } + } + + public Set finishAll() + { + lock.lock(); + try { + return queue.finishAll(); + } + finally { + lock.unlock(); + } + } + + public boolean enqueue(G group, T task, long deltaWeight) + { + lock.lock(); + try { + if (!queue.containsGroup(group)) { + return false; + } + + queue.enqueue(group, task, deltaWeight); + notEmpty.signal(); + + return true; + } + finally { + lock.unlock(); + } + } + + public boolean block(G group, T task, long deltaWeight) + { + lock.lock(); + try { + if (!queue.containsGroup(group)) { + return false; + } + + queue.block(group, task, deltaWeight); + return true; + } + finally { + lock.unlock(); + } + } + + public T dequeue(long expectedWeight) + throws InterruptedException + { + lock.lock(); + try { + T result; + do { + result = queue.dequeue(expectedWeight); + if (result == null) { + notEmpty.await(); + } + } + while (result == null); + + return result; + } + finally { + lock.unlock(); + } + } + + @Override + public String toString() + { + lock.lock(); + try { + return queue.toString(); + } + finally { + lock.unlock(); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/FairScheduler.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/FairScheduler.java new file mode 100644 index 000000000000..30af34c7ae42 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/FairScheduler.java @@ -0,0 +1,292 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2.scheduler; + +import com.google.common.base.Ticker; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.airlift.log.Logger; +import io.trino.execution.executor2.scheduler.TaskControl.State; + +import java.util.Set; +import java.util.StringJoiner; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +/** + *

Implementation nodes

+ * + *
    + *
  • The TaskControl state machine is only modified by the task executor + * thread (i.e., from within {@link FairScheduler#runTask(Schedulable, TaskControl)} )}). Other threads + * can indirectly affect what the task executor thread does by marking the task as ready or cancelled + * and unblocking the task executor thread, which will then act on that information.
  • + *
+ */ +@ThreadSafe +public final class FairScheduler + implements AutoCloseable +{ + private static final Logger LOG = Logger.get(FairScheduler.class); + + public static final long QUANTUM_NANOS = TimeUnit.MILLISECONDS.toNanos(1000); + + private final ExecutorService schedulerExecutor; + private final ListeningExecutorService taskExecutor; + private final BlockingSchedulingQueue queue = new BlockingSchedulingQueue<>(); + private final Reservation concurrencyControl; + private final Ticker ticker; + + private final Gate paused = new Gate(true); + + @GuardedBy("this") + private boolean closed; + + public FairScheduler(int maxConcurrentTasks, String threadNameFormat, Ticker ticker) + { + this.ticker = requireNonNull(ticker, "ticker is null"); + + concurrencyControl = new Reservation<>(maxConcurrentTasks); + + schedulerExecutor = Executors.newCachedThreadPool(new ThreadFactoryBuilder() + .setNameFormat("fair-scheduler-%d") + .setDaemon(true) + .build()); + + taskExecutor = MoreExecutors.listeningDecorator(Executors.newCachedThreadPool(new ThreadFactoryBuilder() + .setNameFormat(threadNameFormat) + .setDaemon(true) + .build())); + } + + public static FairScheduler newInstance(int maxConcurrentTasks) + { + return newInstance(maxConcurrentTasks, Ticker.systemTicker()); + } + + public static FairScheduler newInstance(int maxConcurrentTasks, Ticker ticker) + { + FairScheduler scheduler = new FairScheduler(maxConcurrentTasks, "fair-scheduler-runner-%d", ticker); + scheduler.start(); + return scheduler; + } + + public void start() + { + schedulerExecutor.submit(this::runScheduler); + } + + public void pause() + { + paused.close(); + } + + public void resume() + { + paused.open(); + } + + @Override + public synchronized void close() + { + if (closed) { + return; + } + closed = true; + + Set tasks = queue.finishAll(); + + for (TaskControl task : tasks) { + task.cancel(); + } + + taskExecutor.shutdownNow(); + schedulerExecutor.shutdownNow(); + } + + public synchronized Group createGroup(String name) + { + checkArgument(!closed, "Already closed"); + + Group group = new Group(name); + queue.startGroup(group); + + return group; + } + + public synchronized void removeGroup(Group group) + { + checkArgument(!closed, "Already closed"); + + Set tasks = queue.finishGroup(group); + + for (TaskControl task : tasks) { + task.cancel(); + } + } + + public synchronized ListenableFuture submit(Group group, int id, Schedulable runner) + { + checkArgument(!closed, "Already closed"); + + TaskControl task = new TaskControl(group, id, ticker); + + return taskExecutor.submit(() -> runTask(runner, task), null); + } + + private void runTask(Schedulable runner, TaskControl task) + { + task.setThread(Thread.currentThread()); + + if (!makeRunnableAndAwait(task, 0)) { + return; + } + + SchedulerContext context = new SchedulerContext(this, task); + try { + runner.run(context); + } + catch (Exception e) { + LOG.error(e); + } + finally { + // If the runner exited due to an exception in user code or + // normally (not in response to an interruption during blocking or yield), + // it must have had a semaphore permit reserved, so release it. + if (task.getState() == State.RUNNING) { + concurrencyControl.release(task); + } + } + } + + private boolean makeRunnableAndAwait(TaskControl task, long deltaWeight) + { + if (!task.transitionToWaiting()) { + return false; + } + + if (!queue.enqueue(task.group(), task, deltaWeight)) { + return false; + } + + // wait for the task to be scheduled + return awaitReadyAndTransitionToRunning(task); + } + + /** + * @return false if the transition was unsuccessful due to the task being cancelled + */ + private boolean awaitReadyAndTransitionToRunning(TaskControl task) + { + if (!task.awaitReady()) { + if (task.isReady()) { + // If the task was marked as ready (slot acquired) but then cancelled before + // awaitReady() was notified, we need to release the slot. + concurrencyControl.release(task); + } + return false; + } + + if (!task.transitionToRunning()) { + concurrencyControl.release(task); + return false; + } + + return true; + } + + boolean yield(TaskControl task) + { + long delta = task.elapsed(); + if (delta < QUANTUM_NANOS) { + return true; + } + + concurrencyControl.release(task); + + return makeRunnableAndAwait(task, delta); + } + + boolean block(TaskControl task, ListenableFuture future) + { + long delta = task.elapsed(); + + concurrencyControl.release(task); + + if (!task.transitionToBlocked()) { + return false; + } + + if (!queue.block(task.group(), task, delta)) { + return false; + } + + future.addListener(task::markUnblocked, MoreExecutors.directExecutor()); + task.awaitUnblock(); + + return makeRunnableAndAwait(task, 0); + } + + private void runScheduler() + { + while (true) { + try { + paused.awaitOpen(); + concurrencyControl.reserve(); + TaskControl task = queue.dequeue(QUANTUM_NANOS); + concurrencyControl.register(task); + if (!task.markReady()) { + concurrencyControl.release(task); + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + } + } + + long getScheduledNanos(TaskControl task) + { + return task.getScheduledNanos(); + } + + long getWaitNanos(TaskControl task) + { + return task.getWaitNanos(); + } + + long getBlockedNanos(TaskControl task) + { + return task.getBlockedNanos(); + } + + @Override + public String toString() + { + return new StringJoiner(", ", FairScheduler.class.getSimpleName() + "[", "]") + .add("queue=" + queue) + .add("concurrencyControl=" + concurrencyControl) + .add("closed=" + closed) + .toString(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Gate.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Gate.java new file mode 100644 index 000000000000..d41fe5754d04 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Gate.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2.scheduler; + +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +public class Gate +{ + private final Lock lock = new ReentrantLock(); + private final Condition opened = lock.newCondition(); + private boolean open; + + public Gate(boolean opened) + { + this.open = opened; + } + + public void close() + { + lock.lock(); + try { + open = false; + } + finally { + lock.unlock(); + } + } + + public void open() + { + lock.lock(); + try { + open = true; + opened.signalAll(); + } + finally { + lock.unlock(); + } + } + + public void awaitOpen() + throws InterruptedException + { + lock.lock(); + try { + while (!open) { + opened.await(); + } + } + finally { + lock.unlock(); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Group.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Group.java new file mode 100644 index 000000000000..1700e097919d --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Group.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2.scheduler; + +public record Group(String name) +{ + @Override + public String toString() + { + return name; + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/PriorityQueue.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/PriorityQueue.java new file mode 100644 index 000000000000..93fd08cfb08b --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/PriorityQueue.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2.scheduler; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.TreeSet; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; + +final class PriorityQueue +{ + private final TreeSet queue; + private final Map> index = new HashMap<>(); + + public PriorityQueue() + { + queue = new TreeSet<>((a, b) -> { + int result = Long.compare(index.get(a).priority(), index.get(b).priority()); + if (result == 0) { + result = Integer.compare(System.identityHashCode(a), System.identityHashCode(b)); + } + return result; + }); + } + + public void add(T value, long priority) + { + checkArgument(!index.containsKey(value), "Value already in queue: %s", value); + index.put(value, new Entry<>(priority, value)); + queue.add(value); + } + + public void addOrReplace(T value, long priority) + { + if (index.containsKey(value)) { + queue.remove(value); + index.put(value, new Entry<>(priority, value)); + queue.add(value); + } + else { + add(value, priority); + } + } + + public T poll() + { + T result = queue.pollFirst(); + index.remove(result); + + return result; + } + + public void remove(T value) + { + checkArgument(index.containsKey(value), "Value not in queue: %s", value); + + queue.remove(value); + index.remove(value); + } + + public void removeIfPresent(T value) + { + if (index.containsKey(value)) { + remove(value); + } + } + + public boolean contains(T value) + { + return index.containsKey(value); + } + + public boolean isEmpty() + { + return index.isEmpty(); + } + + public Set values() + { + return index.keySet(); + } + + public long nextPriority() + { + checkState(!queue.isEmpty(), "Queue is empty"); + return index.get(queue.first()).priority(); + } + + public T peek() + { + if (queue.isEmpty()) { + return null; + } + return queue.first(); + } + + @Override + public String toString() + { + return queue.toString(); + } + + private record Entry(long priority, T value) + { + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Reservation.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Reservation.java new file mode 100644 index 000000000000..7cfeaf540922 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Reservation.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2.scheduler; + +import com.google.common.collect.ImmutableSet; + +import java.util.HashSet; +import java.util.Set; +import java.util.StringJoiner; +import java.util.concurrent.Semaphore; + +import static com.google.common.base.Preconditions.checkArgument; + +final class Reservation +{ + private final Semaphore semaphore; + private final Set reservations = new HashSet<>(); + + public Reservation(int slots) + { + semaphore = new Semaphore(slots); + } + + public int availablePermits() + { + return semaphore.availablePermits(); + } + + public void reserve() + throws InterruptedException + { + semaphore.acquire(); + } + + public synchronized void register(T entry) + { + checkArgument(!reservations.contains(entry), "Already acquired: %s", entry); + reservations.add(entry); + } + + public synchronized void release(T entry) + { + checkArgument(reservations.contains(entry), "Already released: %s", entry); + reservations.remove(entry); + + semaphore.release(); + } + + public synchronized Set reservations() + { + return ImmutableSet.copyOf(reservations); + } + + @Override + public String toString() + { + return new StringJoiner(", ", Reservation.class.getSimpleName() + "[", "]") + .add("semaphore=" + semaphore) + .add("reservations=" + reservations) + .toString(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Schedulable.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Schedulable.java new file mode 100644 index 000000000000..1b2a0d0ab39b --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/Schedulable.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2.scheduler; + +public interface Schedulable +{ + void run(SchedulerContext context); +} diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/SchedulerContext.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/SchedulerContext.java new file mode 100644 index 000000000000..ca623029e059 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/SchedulerContext.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2.scheduler; + +import com.google.common.util.concurrent.ListenableFuture; + +import static com.google.common.base.Preconditions.checkArgument; + +public final class SchedulerContext +{ + private final FairScheduler scheduler; + private final TaskControl handle; + + public SchedulerContext(FairScheduler scheduler, TaskControl handle) + { + this.scheduler = scheduler; + this.handle = handle; + } + + /** + * Attempt to relinquish control to let other tasks run. + * + * @return false if the task was interrupted while yielding. The caller is expected to clean up and finish + */ + public boolean maybeYield() + { + checkArgument(handle.getState() == TaskControl.State.RUNNING, "Task is not running"); + + return scheduler.yield(handle); + } + + /** + * Indicate that the current task is blocked. The task will become runnable + * when the provided future is completed. + * + * @return false if the task was interrupted while blocked. The caller is expected to clean up and finish + */ + public boolean block(ListenableFuture future) + { + checkArgument(handle.getState() == TaskControl.State.RUNNING, "Task is not running"); + + return scheduler.block(handle, future); + } + + public long getWaitNanos() + { + return scheduler.getWaitNanos(handle); + } + + public long getScheduledNanos() + { + return scheduler.getScheduledNanos(handle); + } + + public long getBlockedNanos() + { + return scheduler.getBlockedNanos(handle); + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/SchedulingQueue.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/SchedulingQueue.java new file mode 100644 index 000000000000..9d7df3249116 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/SchedulingQueue.java @@ -0,0 +1,531 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2.scheduler; + +import com.google.common.collect.ImmutableSet; +import io.trino.annotation.NotThreadSafe; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static io.trino.execution.executor2.scheduler.SchedulingQueue.State.BLOCKED; +import static io.trino.execution.executor2.scheduler.SchedulingQueue.State.RUNNABLE; +import static io.trino.execution.executor2.scheduler.SchedulingQueue.State.RUNNING; + +/** + *

A queue of tasks that are scheduled for execution. Modeled after + * Completely Fair Scheduler. + * Tasks are grouped into scheduling groups. Within a group, tasks are ordered based + * on their relative weight. Groups are ordered relative to each other based on the + * accumulated weight of their tasks.

+ * + *

A task can be in one of three states: + *

    + *
  • runnable: the task is ready to run and waiting to be dequeued + *
  • running: the task has been dequeued and is running + *
  • blocked: the task is blocked on some external event and is not running + *
+ *

+ *

+ * A group can be in one of three states: + *

    + *
  • runnable: the group has at least one runnable task + *
  • running: all the tasks in the group are currently running + *
  • blocked: all the tasks in the group are currently blocked + *
+ *

+ *

+ * The goal is to balance the consideration among groups to ensure the accumulated + * weight in the long run is equal among groups. Within a group, the goal is to + * balance the consideration among tasks to ensure the accumulated weight in the + * long run is equal among tasks within the group. + * + *

Groups start in the blocked state and transition to the runnable state when a task is + * added via the {@link #enqueue(Object, Object, long)} method.

+ * + *

Tasks are dequeued via the {@link #dequeue(long)}. When all tasks in a group have + * been dequeued, the group transitions to the running state and is removed from the + * queue.

+ * + *

When a task time slice completes, it needs to be re-enqueued via the + * {@link #enqueue(Object, Object, long)}, which includes the desired + * increment in relative weight to apply to the task for further prioritization. + * The weight increment is also applied to the group. + *

+ * + *

If a task blocks, the caller must call the {@link #block(Object, Object, long)} + * method to indicate that the task is no longer running. A weight increment can be + * included for the portion of time the task was not blocked.

+ *
+ *

Group state transitions

+ *
+ *                                                                 blockTask()
+ *    finishTask()               enqueueTask()                     enqueueTask()
+ *        ┌───┐   ┌──────────────────────────────────────────┐       ┌────┐
+ *        │   │   │                                          │       │    │
+ *        │   ▼   │                                          ▼       ▼    │
+ *      ┌─┴───────┴─┐   all blocked        finishTask()   ┌────────────┐  │
+ *      │           │◄──────────────O◄────────────────────┤            ├──┘
+ * ────►│  BLOCKED  │               │                     │  RUNNABLE  │
+ *      │           │               │   ┌────────────────►│            │◄───┐
+ *      └───────────┘       not all │   │  enqueueTask()  └──────┬─────┘    │
+ *            ▲             blocked │   │                        │          │
+ *            │                     │   │           dequeueTask()│          │
+ *            │ all blocked         ▼   │                        │          │
+ *            │                   ┌─────┴─────┐                  ▼          │
+ *            │                   │           │◄─────────────────O──────────┘
+ *            O◄──────────────────┤  RUNNING  │      queue empty     queue
+ *            │      blockTask()  │           ├───┐                 not empty
+ *            │                   └───────────┘   │
+ *            │                     ▲      ▲      │ finishTask()
+ *            └─────────────────────┘      └──────┘
+ *                not all blocked
+ * 
+ * + *

Implementation notes

+ *
    + *
  • TODO: Initial weight upon registration
  • + *
  • TODO: Weight adjustment during blocking / unblocking
  • + *
  • TODO: Uncommitted weight on dequeue
  • + *
+ *

+ */ +@NotThreadSafe +final class SchedulingQueue +{ + private final PriorityQueue runnableQueue = new PriorityQueue<>(); + private final Map> groups = new HashMap<>(); + private final PriorityQueue baselineWeights = new PriorityQueue<>(); + private final Map blocked = new HashMap<>(); + + public void startGroup(G group) + { + checkArgument(!groups.containsKey(group), "Group already started: %s", group); + + Group info = new Group<>(baselineWeight()); + groups.put(group, info); + + doTransition(group, info); + } + + public Set finishGroup(G group) + { + checkArgument(groups.containsKey(group), "Unknown group: %s", group); + + runnableQueue.removeIfPresent(group); + baselineWeights.removeIfPresent(group); + blocked.remove(group); + return groups.remove(group).tasks(); + } + + public boolean containsGroup(G group) + { + return groups.containsKey(group); + } + + public Set finishAll() + { + Set groups = ImmutableSet.copyOf(this.groups.keySet()); + return groups.stream() + .map(this::finishGroup) + .flatMap(Collection::stream) + .collect(Collectors.toSet()); + } + + public void finish(G group, T task) + { + checkArgument(groups.containsKey(group), "Unknown group: %s", group); + verifyState(group); + + Group info = groups.get(group); + info.finish(task); + + doTransition(group, info); + } + + public void enqueue(G group, T task, long deltaWeight) + { + checkArgument(groups.containsKey(group), "Unknown group: %s", group); + verifyState(group); + + Group info = groups.get(group); + + State previousState = info.state(); + info.enqueue(task, deltaWeight); + + if (previousState == BLOCKED) { + // When transitioning from blocked, set the baseline weight to the minimum current weight + // to avoid the newly unblocked group from monopolizing the queue while it catches up + Blocked blockedGroup = blocked.remove(group); + info.adjustWeight(blockedGroup.savedDelta()); + } + + checkState(info.state() == RUNNABLE); + doTransition(group, info); + } + + public void block(G group, T task, long deltaWeight) + { + Group info = groups.get(group); + checkArgument(info != null, "Unknown group: %s", group); + checkArgument(info.state() == RUNNABLE || info.state() == RUNNING, "Group is already blocked: %s", group); + verifyState(group); + + info.block(task, deltaWeight); + doTransition(group, info); + } + + public T dequeue(long expectedWeight) + { + G group = runnableQueue.poll(); + + if (group == null) { + return null; + } + + Group info = groups.get(group); + verify(info.state() == RUNNABLE, "Group is not runnable: %s", group); + + T task = info.dequeue(expectedWeight); + verify(task != null); + + doTransition(group, info); + + return task; + } + + public T peek() + { + G group = runnableQueue.peek(); + + if (group == null) { + return null; + } + + Group info = groups.get(group); + verify(info.state() == RUNNABLE, "Group is not runnable: %s", group); + + T task = info.peek(); + checkState(task != null); + + return task; + } + + private void doTransition(G group, Group info) + { + switch (info.state()) { + case RUNNABLE -> transitionToRunnable(group, info); + case RUNNING -> transitionToRunning(group, info); + case BLOCKED -> transitionToBlocked(group, info); + } + + verifyState(group); + } + + private void transitionToRunning(G group, Group info) + { + checkArgument(info.state() == RUNNING); + + baselineWeights.addOrReplace(group, info.weight()); + } + + private void transitionToBlocked(G group, Group info) + { + checkArgument(info.state() == BLOCKED); + + blocked.put(group, new Blocked(info.weight() - baselineWeight())); + baselineWeights.removeIfPresent(group); + runnableQueue.removeIfPresent(group); + } + + private void transitionToRunnable(G group, Group info) + { + checkArgument(info.state() == RUNNABLE); + + runnableQueue.addOrReplace(group, info.weight()); + baselineWeights.addOrReplace(group, info.weight()); + } + + public State state(G group) + { + Group info = groups.get(group); + checkArgument(info != null, "Unknown group: %s", group); + + return info.state(); + } + + private long baselineWeight() + { + if (baselineWeights.isEmpty()) { + return 0; + } + + return baselineWeights.nextPriority(); + } + + private void verifyState(G groupKey) + { + Group group = groups.get(groupKey); + checkArgument(group != null, "Unknown group: %s", groupKey); + + switch (group.state()) { + case BLOCKED -> { + checkState(!runnableQueue.contains(groupKey), "Group in BLOCKED state should not be in queue: %s", groupKey); + checkState(!baselineWeights.contains(groupKey)); + } + case RUNNABLE -> { + checkState(runnableQueue.contains(groupKey), "Group in RUNNABLE state should be in queue: %s", groupKey); + checkState(baselineWeights.contains(groupKey)); + } + case RUNNING -> { + checkState(!runnableQueue.contains(groupKey), "Group in RUNNING state should not be in queue: %s", groupKey); + checkState(baselineWeights.contains(groupKey)); + } + } + } + + @Override + public String toString() + { + StringBuilder builder = new StringBuilder(); + + builder.append("Baseline weight: %s\n".formatted(baselineWeight())); + builder.append("\n"); + + for (Map.Entry> entry : groups.entrySet()) { + G group = entry.getKey(); + Group info = entry.getValue(); + + switch (entry.getValue().state()) { + case BLOCKED -> builder.append(" - %s [BLOCKED, saved delta = %s]\n".formatted( + group, + blocked.get(entry.getKey()).savedDelta())); + case RUNNING, RUNNABLE -> builder.append(" %s %s [%s, weight = %s, baseline = %s]\n".formatted( + group == runnableQueue.peek() ? "=>" : " -", + group, + info.state(), + info.weight(), + info.baselineWeight())); + } + + for (Map.Entry taskEntry : info.tasks.entrySet()) { + T task = taskEntry.getKey(); + Task taskInfo = taskEntry.getValue(); + + if (info.blocked.containsKey(task)) { + builder.append(" %s [BLOCKED, saved delta = %s]\n".formatted(task, info.blocked.get(task).savedDelta())); + } + else if (info.runnableQueue.contains(task)) { + builder.append(" %s %s [RUNNABLE, weight = %s]\n".formatted( + task == info.peek() ? "=>" : " ", + task, + taskInfo.weight())); + } + else { + builder.append(" %s [RUNNING, weight = %s, uncommitted = %s]\n".formatted( + task, + taskInfo.weight(), + taskInfo.uncommittedWeight())); + } + } + } + + return builder.toString(); + } + + private static class Group + { + private State state; + private long weight; + private final Map tasks = new HashMap<>(); + private final PriorityQueue runnableQueue = new PriorityQueue<>(); + private final Map blocked = new HashMap<>(); + private final PriorityQueue baselineWeights = new PriorityQueue<>(); + + public Group(long weight) + { + this.state = BLOCKED; + this.weight = weight; + } + + public void enqueue(T task, long deltaWeight) + { + Task info = tasks.get(task); + + if (info == null) { + // New tasks get assigned the baseline weight so that they don't monopolize the queue + // while they catch up + info = new Task(baselineWeight()); + tasks.put(task, info); + } + else if (blocked.containsKey(task)) { + Blocked blockedTask = blocked.remove(task); + info.adjustWeight(blockedTask.savedDelta()); + } + + weight -= info.uncommittedWeight(); + weight += deltaWeight; + + info.commitWeight(deltaWeight); + runnableQueue.add(task, info.weight()); + baselineWeights.addOrReplace(task, info.weight()); + + updateState(); + } + + public T dequeue(long expectedWeight) + { + checkArgument(state == RUNNABLE); + + T task = runnableQueue.poll(); + + Task info = tasks.get(task); + info.setUncommittedWeight(expectedWeight); + weight += expectedWeight; + + baselineWeights.addOrReplace(task, info.weight()); + + updateState(); + + return task; + } + + public void finish(T task) + { + checkArgument(tasks.containsKey(task), "Unknown task: %s", task); + tasks.remove(task); + runnableQueue.removeIfPresent(task); + baselineWeights.removeIfPresent(task); + + updateState(); + } + + public void block(T task, long deltaWeight) + { + checkArgument(tasks.containsKey(task), "Unknown task: %s", task); + checkArgument(!runnableQueue.contains(task), "Task is already in queue: %s", task); + + weight += deltaWeight; + + Task info = tasks.get(task); + info.commitWeight(deltaWeight); + + blocked.put(task, new Blocked(weight - baselineWeight())); + baselineWeights.remove(task); + + updateState(); + } + + private long baselineWeight() + { + if (baselineWeights.isEmpty()) { + return 0; + } + + return baselineWeights.nextPriority(); + } + + public void adjustWeight(long delta) + { + weight += delta; + } + + private void updateState() + { + if (blocked.size() == tasks.size()) { + state = BLOCKED; + } + else if (runnableQueue.isEmpty()) { + state = RUNNING; + } + else { + state = RUNNABLE; + } + } + + public long weight() + { + return weight; + } + + public Set tasks() + { + return tasks.keySet(); + } + + public State state() + { + return state; + } + + public T peek() + { + return runnableQueue.peek(); + } + } + + private static class Task + { + private long weight; + private long uncommittedWeight; + + public Task(long initialWeight) + { + weight = initialWeight; + } + + public void commitWeight(long delta) + { + weight += delta; + uncommittedWeight = 0; + } + + public void adjustWeight(long delta) + { + weight += delta; + } + + public long weight() + { + return weight + uncommittedWeight; + } + + public void setUncommittedWeight(long weight) + { + this.uncommittedWeight = weight; + } + + public long uncommittedWeight() + { + return uncommittedWeight; + } + } + + public enum State + { + BLOCKED, // all tasks are blocked + RUNNING, // all tasks are dequeued and running + RUNNABLE // some tasks are enqueued and ready to run + } + + private record Blocked(long savedDelta) + { + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/TaskControl.java b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/TaskControl.java new file mode 100644 index 000000000000..732fe4990530 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/TaskControl.java @@ -0,0 +1,348 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2.scheduler; + +import com.google.common.base.Ticker; +import com.google.errorprone.annotations.concurrent.GuardedBy; + +import java.util.Objects; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +import static java.util.Objects.requireNonNull; + +/** + * Equality is based on group and id for the purpose of adding to the scheduling queue. + */ +final class TaskControl +{ + private final Group group; + private final int id; + private final Ticker ticker; + + private final Lock lock = new ReentrantLock(); + + @GuardedBy("lock") + private final Condition wakeup = lock.newCondition(); + + @GuardedBy("lock") + private boolean ready; + + @GuardedBy("lock") + private boolean blocked; + + @GuardedBy("lock") + private boolean cancelled; + + @GuardedBy("lock") + private State state; + + private volatile long periodStart; + private final AtomicLong scheduledNanos = new AtomicLong(); + private final AtomicLong blockedNanos = new AtomicLong(); + private final AtomicLong waitNanos = new AtomicLong(); + private volatile Thread thread; + + public TaskControl(Group group, int id, Ticker ticker) + { + this.group = requireNonNull(group, "group is null"); + this.id = id; + this.ticker = requireNonNull(ticker, "ticker is null"); + this.state = State.NEW; + this.ready = false; + this.periodStart = ticker.read(); + } + + public void setThread(Thread thread) + { + this.thread = thread; + } + + public void cancel() + { + lock.lock(); + try { + cancelled = true; + wakeup.signal(); + + // TODO: it should be possible to interrupt the thread, but + // it appears that it's not safe to do so. It can cause the query + // to get stuck (e.g., AbstractDistributedEngineOnlyQueries.testSelectiveLimit) + // + // Thread thread = this.thread; + // if (thread != null) { + // thread.interrupt(); + // } + } + finally { + lock.unlock(); + } + } + + /** + * Called by the scheduler thread when the task is ready to run. It + * causes anyone blocking in {@link #awaitReady()} to wake up. + * + * @return false if the task was already cancelled + */ + public boolean markReady() + { + lock.lock(); + try { + if (cancelled) { + return false; + } + ready = true; + wakeup.signal(); + } + finally { + lock.unlock(); + } + + return true; + } + + public void markNotReady() + { + lock.lock(); + try { + ready = false; + } + finally { + lock.unlock(); + } + } + + public boolean isReady() + { + lock.lock(); + try { + return ready; + } + finally { + lock.unlock(); + } + } + + /** + * @return false if the operation was interrupted due to cancellation + */ + public boolean awaitReady() + { + lock.lock(); + try { + while (!ready && !cancelled) { + try { + wakeup.await(); + } + catch (InterruptedException e) { + } + } + + return !cancelled; + } + finally { + lock.unlock(); + } + } + + public void markUnblocked() + { + lock.lock(); + try { + blocked = false; + wakeup.signal(); + } + finally { + lock.unlock(); + } + } + + public void markBlocked() + { + lock.lock(); + try { + blocked = true; + } + finally { + lock.unlock(); + } + } + + public void awaitUnblock() + { + lock.lock(); + try { + while (blocked && !cancelled) { + try { + wakeup.await(); + } + catch (InterruptedException e) { + } + } + } + finally { + lock.unlock(); + } + } + + /** + * @return false if the transition was unsuccessful due to the task being interrupted + */ + public boolean transitionToBlocked() + { + boolean success = transitionTo(State.BLOCKED); + + if (success) { + markBlocked(); + } + + return success; + } + + /** + * @return false if the transition was unsuccessful due to the task being interrupted + */ + public boolean transitionToWaiting() + { + boolean success = transitionTo(State.WAITING); + + if (success) { + markNotReady(); + } + + return success; + } + + /** + * @return false if the transition was unsuccessful due to the task being interrupted + */ + public boolean transitionToRunning() + { + return transitionTo(State.RUNNING); + } + + private boolean transitionTo(State state) + { + lock.lock(); + try { + recordPeriodEnd(); + + if (cancelled) { + this.state = State.INTERRUPTED; + return false; + } + else { + this.state = state; + return true; + } + } + finally { + lock.unlock(); + } + } + + private void recordPeriodEnd() + { + long now = ticker.read(); + long elapsed = now - periodStart; + switch (state) { + case RUNNING -> scheduledNanos.addAndGet(elapsed); + case BLOCKED -> blockedNanos.addAndGet(elapsed); + case NEW, WAITING -> waitNanos.addAndGet(elapsed); + default -> { + // make checkstyle happy. TODO: remove this + } + } + periodStart = now; + } + + public Group group() + { + return group; + } + + public State getState() + { + lock.lock(); + try { + return state; + } + finally { + lock.unlock(); + } + } + + public long elapsed() + { + return ticker.read() - periodStart; + } + + public long getWaitNanos() + { + return waitNanos.get(); + } + + public long getScheduledNanos() + { + return scheduledNanos.get(); + } + + public long getBlockedNanos() + { + return blockedNanos.get(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TaskControl that = (TaskControl) o; + return id == that.id && group.equals(that.group); + } + + @Override + public int hashCode() + { + return Objects.hash(group, id); + } + + @Override + public String toString() + { + lock.lock(); + try { + return group.name() + "-" + id + " [" + state + "]"; + } + finally { + lock.unlock(); + } + } + + public enum State + { + NEW, + WAITING, + RUNNING, + BLOCKED, + INTERRUPTED + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/group-state-diagram.dot b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/group-state-diagram.dot new file mode 100644 index 000000000000..bc1346753ee0 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/executor2/scheduler/group-state-diagram.dot @@ -0,0 +1,24 @@ +digraph Group { + node [shape=box]; + + start [shape=point]; + split1 [shape=point]; + split2 [shape=point]; + + + start -> blocked; + blocked -> runnable [label="enqueueTask()"]; + runnable -> runnable [label="enqueueTask()\nblockTask()"]; + runnable -> split1 [label="dequeueTask()"]; + split1 -> runnable [label="queue not empty"]; + split1 -> running [label="queue empty"]; + running -> split2 [label="blockTask()"]; + running -> runnable [label="enqueueTask()"]; + split2 -> blocked [label="all blocked"]; + split2 -> running [label="not all blocked"]; + blocked -> blocked [label="finishTask()"]; + running -> running [label="finishTask()"]; + runnable -> split3 [label="finishTask()"]; + split3 -> blocked [label="all blocked"]; + split3 -> running [label="all running"]; +} diff --git a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java index 6d8c876fd61d..8ffb27347218 100644 --- a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java +++ b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java @@ -52,6 +52,7 @@ import io.trino.execution.executor.MultilevelSplitQueue; import io.trino.execution.executor.TaskExecutor; import io.trino.execution.executor.TimeSharingTaskExecutor; +import io.trino.execution.executor2.ThreadPerDriverTaskExecutor; import io.trino.execution.scheduler.NodeScheduler; import io.trino.execution.scheduler.NodeSchedulerConfig; import io.trino.execution.scheduler.TopologyAwareNodeSelectorModule; @@ -309,10 +310,6 @@ protected void setup(Binder binder) newOptionalBinder(binder, VersionEmbedder.class).setDefault().to(EmbedVersion.class).in(Scopes.SINGLETON); newExporter(binder).export(SqlTaskManager.class).withGeneratedName(); - binder.bind(TaskExecutor.class) - .to(TimeSharingTaskExecutor.class) - .in(Scopes.SINGLETON); - newExporter(binder).export(TaskExecutor.class).withGeneratedName(); binder.bind(MultilevelSplitQueue.class).in(Scopes.SINGLETON); newExporter(binder).export(MultilevelSplitQueue.class).withGeneratedName(); @@ -323,6 +320,20 @@ protected void setup(Binder binder) binder.bind(PageFunctionCompiler.class).in(Scopes.SINGLETON); newExporter(binder).export(PageFunctionCompiler.class).withGeneratedName(); configBinder(binder).bindConfig(TaskManagerConfig.class); + + // TODO: use conditional module + TaskManagerConfig taskManagerConfig = buildConfigObject(TaskManagerConfig.class); + if (taskManagerConfig.isThreadPerDriverSchedulerEnabled()) { + binder.bind(TaskExecutor.class) + .to(ThreadPerDriverTaskExecutor.class) + .in(Scopes.SINGLETON); + } + else { + binder.bind(TaskExecutor.class) + .to(TimeSharingTaskExecutor.class) + .in(Scopes.SINGLETON); + } + if (retryPolicy == TASK) { configBinder(binder).bindConfigDefaults(TaskManagerConfig.class, TaskManagerConfig::applyFaultTolerantExecutionDefaults); } diff --git a/core/trino-main/src/main/java/io/trino/tracing/TrinoAttributes.java b/core/trino-main/src/main/java/io/trino/tracing/TrinoAttributes.java index c5cbcd9fdf6f..e9ce0836e9da 100644 --- a/core/trino-main/src/main/java/io/trino/tracing/TrinoAttributes.java +++ b/core/trino-main/src/main/java/io/trino/tracing/TrinoAttributes.java @@ -54,6 +54,7 @@ private TrinoAttributes() {} public static final AttributeKey SPLIT_SCHEDULED_TIME_NANOS = longKey("trino.split.scheduled_time_nanos"); public static final AttributeKey SPLIT_CPU_TIME_NANOS = longKey("trino.split.cpu_time_nanos"); public static final AttributeKey SPLIT_WAIT_TIME_NANOS = longKey("trino.split.wait_time_nanos"); + public static final AttributeKey SPLIT_BLOCK_TIME_NANOS = longKey("trino.split.block_time_nanos"); public static final AttributeKey SPLIT_BLOCKED = booleanKey("trino.split.blocked"); public static final AttributeKey EVENT_STATE = stringKey("state"); diff --git a/core/trino-main/src/test/java/io/trino/execution/BaseTestSqlTaskManager.java b/core/trino-main/src/test/java/io/trino/execution/BaseTestSqlTaskManager.java index 58e468ea2a1f..4241afb6e643 100644 --- a/core/trino-main/src/test/java/io/trino/execution/BaseTestSqlTaskManager.java +++ b/core/trino-main/src/test/java/io/trino/execution/BaseTestSqlTaskManager.java @@ -195,6 +195,7 @@ public void testAbort() try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) { TaskId taskId = newTaskId(); TaskInfo taskInfo = createTask(sqlTaskManager, taskId, PipelinedOutputBuffers.createInitial(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds()); + assertEquals(taskInfo.getTaskStatus().getState(), TaskState.RUNNING); assertNull(taskInfo.getStats().getEndTime()); diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManagerThreadPerDriver.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManagerThreadPerDriver.java new file mode 100644 index 000000000000..67007faf1486 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManagerThreadPerDriver.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution; + +import com.google.common.base.Ticker; +import io.airlift.tracing.Tracing; +import io.trino.execution.executor.TaskExecutor; +import io.trino.execution.executor2.ThreadPerDriverTaskExecutor; +import io.trino.execution.executor2.scheduler.FairScheduler; + +import static io.trino.version.EmbedVersion.testingVersionEmbedder; + +public class TestSqlTaskManagerThreadPerDriver + extends BaseTestSqlTaskManager +{ + @Override + protected TaskExecutor createTaskExecutor() + { + return new ThreadPerDriverTaskExecutor( + Tracing.noopTracer(), + testingVersionEmbedder(), + new FairScheduler(8, "Runner-%d", Ticker.systemTicker())); + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java index 48e6603948d7..48bbf7c9d5cf 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java @@ -41,6 +41,7 @@ public class TestTaskManagerConfig public void testDefaults() { assertRecordedDefaults(recordDefaults(TaskManagerConfig.class) + .setThreadPerDriverSchedulerEnabled(true) .setInitialSplitsPerNode(Runtime.getRuntime().availableProcessors() * 2) .setSplitConcurrencyAdjustmentInterval(new Duration(100, TimeUnit.MILLISECONDS)) .setStatusRefreshMaxWait(new Duration(1, TimeUnit.SECONDS)) @@ -87,6 +88,7 @@ public void testExplicitPropertyMappings() int maxWriterCount = DEFAULT_SCALE_WRITERS_MAX_WRITER_COUNT == 32 ? 16 : 32; int partitionedWriterCount = DEFAULT_PARTITIONED_WRITER_COUNT == 64 ? 32 : 64; Map properties = ImmutableMap.builder() + .put("experimental.thread-per-split-scheduler-enabled", "false") .put("task.initial-splits-per-node", "1") .put("task.split-concurrency-adjustment-interval", "1s") .put("task.status-refresh-max-wait", "2s") @@ -127,6 +129,7 @@ public void testExplicitPropertyMappings() .buildOrThrow(); TaskManagerConfig expected = new TaskManagerConfig() + .setThreadPerDriverSchedulerEnabled(false) .setInitialSplitsPerNode(1) .setSplitConcurrencyAdjustmentInterval(new Duration(1, TimeUnit.SECONDS)) .setStatusRefreshMaxWait(new Duration(2, TimeUnit.SECONDS)) diff --git a/core/trino-main/src/test/java/io/trino/execution/executor2/TestThreadPerDriverTaskExecutor.java b/core/trino-main/src/test/java/io/trino/execution/executor2/TestThreadPerDriverTaskExecutor.java new file mode 100644 index 000000000000..47bca43cc0e3 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/executor2/TestThreadPerDriverTaskExecutor.java @@ -0,0 +1,258 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.AbstractFuture; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.testing.TestingTicker; +import io.airlift.units.Duration; +import io.opentelemetry.api.trace.Span; +import io.trino.execution.SplitRunner; +import io.trino.execution.StageId; +import io.trino.execution.TaskId; +import io.trino.execution.TaskManagerConfig; +import io.trino.execution.executor.TaskHandle; +import io.trino.execution.executor2.scheduler.FairScheduler; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.List; +import java.util.OptionalInt; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.Phaser; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +import static io.airlift.tracing.Tracing.noopTracer; +import static io.trino.version.EmbedVersion.testingVersionEmbedder; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestThreadPerDriverTaskExecutor +{ + @Test + @Timeout(10) + public void testCancellationWhileProcessing() + throws ExecutionException, InterruptedException + { + ThreadPerDriverTaskExecutor executor = new ThreadPerDriverTaskExecutor(new TaskManagerConfig(), noopTracer(), testingVersionEmbedder()); + executor.start(); + try { + TaskId taskId = new TaskId(new StageId("query", 1), 1, 1); + TaskHandle task = executor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); + + CountDownLatch started = new CountDownLatch(1); + + SplitRunner split = new TestingSplitRunner(ImmutableList.of(duration -> { + started.countDown(); + try { + Thread.currentThread().join(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + return Futures.immediateVoidFuture(); + })); + + ListenableFuture splitDone = executor.enqueueSplits(task, false, ImmutableList.of(split)).get(0); + + started.await(); + executor.removeTask(task); + + splitDone.get(); + assertThat(split.isFinished()).isTrue(); + } + finally { + executor.stop(); + } + } + + @Test + @Timeout(10) + public void testBlocking() + throws ExecutionException, InterruptedException + { + ThreadPerDriverTaskExecutor executor = new ThreadPerDriverTaskExecutor(new TaskManagerConfig(), noopTracer(), testingVersionEmbedder()); + executor.start(); + + try { + TaskId taskId = new TaskId(new StageId("query", 1), 1, 1); + TaskHandle task = executor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); + + TestFuture blocked = new TestFuture(); + + SplitRunner split = new TestingSplitRunner(ImmutableList.of( + duration -> blocked, + duration -> Futures.immediateVoidFuture())); + + ListenableFuture splitDone = executor.enqueueSplits(task, false, ImmutableList.of(split)).get(0); + + blocked.awaitListenerAdded(); + blocked.set(null); // unblock the split + + splitDone.get(); + assertThat(split.isFinished()).isTrue(); + } + finally { + executor.stop(); + } + } + + @Test + @Timeout(10) + public void testYielding() + throws ExecutionException, InterruptedException + { + TestingTicker ticker = new TestingTicker(); + FairScheduler scheduler = new FairScheduler(1, "Runner-%d", ticker); + ThreadPerDriverTaskExecutor executor = new ThreadPerDriverTaskExecutor(noopTracer(), testingVersionEmbedder(), scheduler); + executor.start(); + + try { + TaskId taskId = new TaskId(new StageId("query", 1), 1, 1); + TaskHandle task = executor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); + + Phaser phaser = new Phaser(2); + SplitRunner split = new TestingSplitRunner(ImmutableList.of( + duration -> { + phaser.arriveAndAwaitAdvance(); // wait to start + phaser.arriveAndAwaitAdvance(); // wait to advance time + return Futures.immediateVoidFuture(); + }, + duration -> { + phaser.arriveAndAwaitAdvance(); + return Futures.immediateVoidFuture(); + })); + + ListenableFuture splitDone = executor.enqueueSplits(task, false, ImmutableList.of(split)).get(0); + + phaser.arriveAndAwaitAdvance(); // wait for split to start + + // cause the task to yield + ticker.increment(FairScheduler.QUANTUM_NANOS * 2, TimeUnit.NANOSECONDS); + phaser.arriveAndAwaitAdvance(); + + // wait for reschedule + assertThat(phaser.arriveAndAwaitAdvance()).isEqualTo(3); // wait for reschedule + + splitDone.get(); + assertThat(split.isFinished()).isTrue(); + } + finally { + executor.stop(); + } + } + + private static class TestFuture + extends AbstractFuture + { + private final CountDownLatch listenerAdded = new CountDownLatch(1); + + @Override + public void addListener(Runnable listener, Executor executor) + { + super.addListener(listener, executor); + listenerAdded.countDown(); + } + + @Override + public boolean set(Void value) + { + return super.set(value); + } + + public void awaitListenerAdded() + throws InterruptedException + { + listenerAdded.await(); + } + } + + private static class TestingSplitRunner + implements SplitRunner + { + private final List>> invocations; + private int invocation; + private volatile boolean finished; + private volatile Thread runnerThread; + + public TestingSplitRunner(List>> invocations) + { + this.invocations = invocations; + } + + @Override + public final int getPipelineId() + { + return 0; + } + + @Override + public final Span getPipelineSpan() + { + return Span.getInvalid(); + } + + @Override + public final boolean isFinished() + { + return finished; + } + + @Override + public final ListenableFuture processFor(Duration duration) + { + ListenableFuture blocked; + + runnerThread = Thread.currentThread(); + try { + blocked = invocations.get(invocation).apply(duration); + } + finally { + runnerThread = null; + } + + invocation++; + + if (invocation == invocations.size()) { + finished = true; + } + + return blocked; + } + + @Override + public final String getInfo() + { + return ""; + } + + @Override + public final void close() + { + finished = true; + + Thread runnerThread = this.runnerThread; + + if (runnerThread != null) { + runnerThread.interrupt(); + } + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestFairScheduler.java b/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestFairScheduler.java new file mode 100644 index 000000000000..446be1d4e49e --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestFairScheduler.java @@ -0,0 +1,227 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2.scheduler; + +import com.google.common.util.concurrent.AbstractFuture; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import io.airlift.testing.TestingTicker; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.assertj.core.api.Assertions.assertThat; + +public class TestFairScheduler +{ + @Test + public void testBasic() + throws ExecutionException, InterruptedException + { + try (FairScheduler scheduler = FairScheduler.newInstance(1)) { + Group group = scheduler.createGroup("G1"); + + AtomicBoolean ran = new AtomicBoolean(); + ListenableFuture done = scheduler.submit(group, 1, context -> ran.set(true)); + + done.get(); + assertThat(ran.get()) + .describedAs("Ran task") + .isTrue(); + } + } + + @Test + @Timeout(5) + public void testYield() + throws ExecutionException, InterruptedException + { + TestingTicker ticker = new TestingTicker(); + try (FairScheduler scheduler = FairScheduler.newInstance(1, ticker)) { + Group group = scheduler.createGroup("G"); + + CountDownLatch task1Started = new CountDownLatch(1); + AtomicBoolean task2Ran = new AtomicBoolean(); + + ListenableFuture task1 = scheduler.submit(group, 1, context -> { + task1Started.countDown(); + while (!task2Ran.get()) { + if (!context.maybeYield()) { + return; + } + } + }); + + task1Started.await(); + + ListenableFuture task2 = scheduler.submit(group, 2, context -> { + task2Ran.set(true); + }); + + while (!task2.isDone()) { + ticker.increment(FairScheduler.QUANTUM_NANOS * 2, TimeUnit.NANOSECONDS); + } + + task1.get(); + } + } + + @Test + public void testBlocking() + throws InterruptedException, ExecutionException + { + try (FairScheduler scheduler = FairScheduler.newInstance(1)) { + Group group = scheduler.createGroup("G"); + + CountDownLatch task1Started = new CountDownLatch(1); + CountDownLatch task2Submitted = new CountDownLatch(1); + CountDownLatch task2Started = new CountDownLatch(1); + AtomicBoolean task2Ran = new AtomicBoolean(); + + SettableFuture task1Blocked = SettableFuture.create(); + + ListenableFuture task1 = scheduler.submit(group, 1, context -> { + try { + task1Started.countDown(); + task2Submitted.await(); + + assertThat(task2Ran.get()) + .describedAs("Task 2 run") + .isFalse(); + + context.block(task1Blocked); + + assertThat(task2Ran.get()) + .describedAs("Task 2 run") + .isTrue(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + }); + + task1Started.await(); + + ListenableFuture task2 = scheduler.submit(group, 2, context -> { + task2Started.countDown(); + task2Ran.set(true); + }); + + task2Submitted.countDown(); + task2Started.await(); + + // unblock task 1 + task1Blocked.set(null); + + task1.get(); + task2.get(); + } + } + + @Test + public void testCancelWhileYielding() + throws InterruptedException, ExecutionException + { + TestingTicker ticker = new TestingTicker(); + try (FairScheduler scheduler = FairScheduler.newInstance(1, ticker)) { + Group group = scheduler.createGroup("G"); + + CountDownLatch task1Started = new CountDownLatch(1); + CountDownLatch task1TimeAdvanced = new CountDownLatch(1); + + ListenableFuture task1 = scheduler.submit(group, 1, context -> { + try { + task1Started.countDown(); + task1TimeAdvanced.await(); + + assertThat(context.maybeYield()) + .describedAs("Cancelled while yielding") + .isFalse(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + }); + + task1Started.await(); + scheduler.pause(); // prevent rescheduling after yield + + ticker.increment(FairScheduler.QUANTUM_NANOS * 2, TimeUnit.NANOSECONDS); + task1TimeAdvanced.countDown(); + + scheduler.removeGroup(group); + task1.get(); + } + } + + @Test + public void testCancelWhileBlocking() + throws InterruptedException, ExecutionException + { + TestingTicker ticker = new TestingTicker(); + try (FairScheduler scheduler = FairScheduler.newInstance(1, ticker)) { + Group group = scheduler.createGroup("G"); + + CountDownLatch task1Started = new CountDownLatch(1); + TestFuture task1Blocked = new TestFuture(); + + ListenableFuture task1 = scheduler.submit(group, 1, context -> { + task1Started.countDown(); + + assertThat(context.block(task1Blocked)) + .describedAs("Cancelled while blocking") + .isFalse(); + }); + + task1Started.await(); + + task1Blocked.awaitListenerAdded(); // When the listener is added, we know the task is blocked + + scheduler.removeGroup(group); + task1.get(); + } + } + + private static class TestFuture + extends AbstractFuture + { + private final CountDownLatch listenerAdded = new CountDownLatch(1); + + @Override + public void addListener(Runnable listener, Executor executor) + { + super.addListener(listener, executor); + listenerAdded.countDown(); + } + + @Override + public boolean set(Void value) + { + return super.set(value); + } + + public void awaitListenerAdded() + throws InterruptedException + { + listenerAdded.await(); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestPriorityQueue.java b/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestPriorityQueue.java new file mode 100644 index 000000000000..ab21233f86a0 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestPriorityQueue.java @@ -0,0 +1,196 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2.scheduler; + +import com.google.common.collect.ImmutableSet; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestPriorityQueue +{ + @Test + public void testEmpty() + { + PriorityQueue queue = new PriorityQueue<>(); + + assertThat(queue.poll()).isNull(); + assertThat(queue.isEmpty()).isTrue(); + } + + @Test + public void testNotEmpty() + { + PriorityQueue queue = new PriorityQueue<>(); + + queue.add("hello", 1); + assertThat(queue.isEmpty()).isFalse(); + } + + @Test + public void testDuplicate() + { + PriorityQueue queue = new PriorityQueue<>(); + + queue.add("hello", 1); + assertThatThrownBy(() -> queue.add("hello", 2)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void testOrder() + { + PriorityQueue queue = new PriorityQueue<>(); + + queue.add("jumps", 5); + queue.add("fox", 4); + queue.add("over", 6); + queue.add("brown", 3); + queue.add("dog", 8); + queue.add("the", 1); + queue.add("lazy", 7); + queue.add("quick", 2); + + assertThat(queue.poll()).isEqualTo("the"); + assertThat(queue.poll()).isEqualTo("quick"); + assertThat(queue.poll()).isEqualTo("brown"); + assertThat(queue.poll()).isEqualTo("fox"); + assertThat(queue.poll()).isEqualTo("jumps"); + assertThat(queue.poll()).isEqualTo("over"); + assertThat(queue.poll()).isEqualTo("lazy"); + assertThat(queue.poll()).isEqualTo("dog"); + assertThat(queue.poll()).isNull(); + } + + @Test + public void testInterleaved() + { + PriorityQueue queue = new PriorityQueue<>(); + + queue.add("jumps", 5); + queue.add("over", 6); + queue.add("fox", 4); + + assertThat(queue.poll()).isEqualTo("fox"); + assertThat(queue.poll()).isEqualTo("jumps"); + + queue.add("brown", 3); + queue.add("dog", 8); + queue.add("the", 1); + + assertThat(queue.poll()).isEqualTo("the"); + assertThat(queue.poll()).isEqualTo("brown"); + assertThat(queue.poll()).isEqualTo("over"); + + queue.add("lazy", 7); + queue.add("quick", 2); + + assertThat(queue.poll()).isEqualTo("quick"); + assertThat(queue.poll()).isEqualTo("lazy"); + assertThat(queue.poll()).isEqualTo("dog"); + assertThat(queue.poll()).isNull(); + } + + @Test + public void testRemove() + { + PriorityQueue queue = new PriorityQueue<>(); + + queue.add("fox", 4); + queue.add("brown", 3); + queue.add("the", 1); + queue.add("quick", 2); + + queue.remove("brown"); + + assertThat(queue.poll()).isEqualTo("the"); + assertThat(queue.poll()).isEqualTo("quick"); + assertThat(queue.poll()).isEqualTo("fox"); + assertThat(queue.poll()).isNull(); + } + + @Test + public void testRemoveMissing() + { + PriorityQueue queue = new PriorityQueue<>(); + + queue.add("the", 1); + queue.add("quick", 2); + queue.add("brown", 3); + + assertThatThrownBy(() -> queue.remove("fox")) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void testContains() + { + PriorityQueue queue = new PriorityQueue<>(); + + queue.add("the", 1); + queue.add("quick", 2); + queue.add("brown", 3); + + assertThat(queue.contains("quick")).isTrue(); + assertThat(queue.contains("fox")).isFalse(); + } + + @Test + public void testRecycle() + { + PriorityQueue queue = new PriorityQueue<>(); + + queue.add("hello", 1); + assertThat(queue.poll()).isEqualTo("hello"); + + queue.add("hello", 2); + assertThat(queue.poll()).isEqualTo("hello"); + } + + @Test + public void testValues() + { + PriorityQueue queue = new PriorityQueue<>(); + + assertThat(queue.values()).isEmpty(); + + queue.add("hello", 1); + queue.add("world", 2); + + assertThat(queue.values()) + .isEqualTo(ImmutableSet.of("hello", "world")); + } + + @Test + public void testNextPriority() + { + PriorityQueue queue = new PriorityQueue<>(); + + assertThatThrownBy(queue::nextPriority) + .isInstanceOf(IllegalStateException.class); + + queue.add("hello", 10); + queue.add("world", 20); + + assertThat(queue.nextPriority()).isEqualTo(10); + + queue.poll(); + assertThat(queue.nextPriority()).isEqualTo(20); + + queue.poll(); + assertThatThrownBy(queue::nextPriority) + .isInstanceOf(IllegalStateException.class); + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestSchedulingQueue.java b/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestSchedulingQueue.java new file mode 100644 index 000000000000..4fdfc7570934 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/executor2/scheduler/TestSchedulingQueue.java @@ -0,0 +1,286 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor2.scheduler; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +public class TestSchedulingQueue +{ + @Test + public void testEmpty() + { + SchedulingQueue queue = new SchedulingQueue<>(); + + assertThat(queue.dequeue(1)).isNull(); + } + + @Test + public void testSingleGroup() + { + SchedulingQueue queue = new SchedulingQueue<>(); + + queue.startGroup("G1"); + + queue.enqueue("G1", "T1", 1); + queue.enqueue("G1", "T2", 3); + queue.enqueue("G1", "T3", 5); + queue.enqueue("G1", "T4", 7); + + assertThat(queue.dequeue(1)).isEqualTo("T1"); + assertThat(queue.dequeue(1)).isEqualTo("T2"); + assertThat(queue.dequeue(1)).isEqualTo("T3"); + assertThat(queue.dequeue(1)).isEqualTo("T4"); + + queue.enqueue("G1", "T1", 10); + queue.enqueue("G1", "T2", 10); + queue.enqueue("G1", "T3", 10); + queue.enqueue("G1", "T4", 10); + + assertThat(queue.dequeue(1)).isEqualTo("T1"); + assertThat(queue.dequeue(1)).isEqualTo("T2"); + assertThat(queue.dequeue(1)).isEqualTo("T3"); + assertThat(queue.dequeue(1)).isEqualTo("T4"); + + queue.enqueue("G1", "T1", 16); + queue.enqueue("G1", "T2", 12); + queue.enqueue("G1", "T3", 8); + queue.enqueue("G1", "T4", 4); + + assertThat(queue.dequeue(1)).isEqualTo("T4"); + assertThat(queue.dequeue(1)).isEqualTo("T3"); + assertThat(queue.dequeue(1)).isEqualTo("T2"); + assertThat(queue.dequeue(1)).isEqualTo("T1"); + + queue.finish("G1", "T1"); + queue.finish("G1", "T2"); + queue.finish("G1", "T3"); + queue.finish("G1", "T4"); + + assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.BLOCKED); + } + + @Test + public void testBasic() + { + SchedulingQueue queue = new SchedulingQueue<>(); + + queue.startGroup("G1"); + queue.startGroup("G2"); + + queue.enqueue("G1", "T1.0", 1); + queue.enqueue("G1", "T1.1", 2); + queue.enqueue("G2", "T2.0", 3); + queue.enqueue("G2", "T2.1", 4); + + assertThat(queue.dequeue(1)).isEqualTo("T1.0"); + assertThat(queue.dequeue(1)).isEqualTo("T1.1"); + assertThat(queue.dequeue(1)).isEqualTo("T2.0"); + assertThat(queue.dequeue(1)).isEqualTo("T2.1"); + + queue.enqueue("G1", "T1.0", 10); + queue.enqueue("G1", "T1.1", 20); + queue.enqueue("G2", "T2.0", 15); + queue.enqueue("G2", "T2.1", 5); + + assertThat(queue.dequeue(1)).isEqualTo("T2.1"); + assertThat(queue.dequeue(1)).isEqualTo("T2.0"); + assertThat(queue.dequeue(1)).isEqualTo("T1.0"); + assertThat(queue.dequeue(1)).isEqualTo("T1.1"); + + queue.enqueue("G1", "T1.0", 100); + queue.enqueue("G2", "T2.0", 90); + assertThat(queue.dequeue(1)).isEqualTo("T2.0"); + assertThat(queue.dequeue(1)).isEqualTo("T1.0"); + } + + @Test + public void testSomeEmptyGroups() + { + SchedulingQueue queue = new SchedulingQueue<>(); + + queue.startGroup("G1"); + queue.startGroup("G2"); + + queue.enqueue("G2", "T1", 0); + + assertThat(queue.dequeue(1)).isEqualTo("T1"); + } + + @Test + public void testDelayedCreation() + { + SchedulingQueue queue = new SchedulingQueue<>(); + + queue.startGroup("G1"); + queue.startGroup("G2"); + + queue.enqueue("G1", "T1.0", 100); + queue.enqueue("G2", "T2.0", 200); + + queue.startGroup("G3"); // new group gets a priority baseline equal to the minimum current priority + queue.enqueue("G3", "T3.0", 50); + + assertThat(queue.dequeue(1)).isEqualTo("T1.0"); + assertThat(queue.dequeue(1)).isEqualTo("T3.0"); + assertThat(queue.dequeue(1)).isEqualTo("T2.0"); + } + + @Test + public void testDelayedCreationWhileAllRunning() + { + SchedulingQueue queue = new SchedulingQueue<>(); + + queue.startGroup("G1"); + queue.startGroup("G2"); + + queue.enqueue("G1", "T1.0", 0); + + queue.enqueue("G2", "T2.0", 100); + queue.dequeue(50); + queue.dequeue(50); + + queue.startGroup("G3"); // new group gets a priority baseline equal to the minimum current priority + queue.enqueue("G3", "T3.0", 10); + + queue.enqueue("G1", "T1.0", 50); + queue.enqueue("G2", "T2.0", 50); + + assertThat(queue.dequeue(1)).isEqualTo("T1.0"); + assertThat(queue.dequeue(1)).isEqualTo("T3.0"); + assertThat(queue.dequeue(1)).isEqualTo("T2.0"); + } + + @Test + public void testGroupState() + { + SchedulingQueue queue = new SchedulingQueue<>(); + + // initial state with no tasks + queue.startGroup("G1"); + assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.BLOCKED); + + // after adding a task, it should be runnable + queue.enqueue("G1", "T1", 0); + assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.RUNNABLE); + queue.enqueue("G1", "T2", 0); + assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.RUNNABLE); + + // after dequeueing, still runnable if there's at least one runnable task + queue.dequeue(1); + assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.RUNNABLE); + + // after all tasks are dequeued, it should be running + queue.dequeue(1); + assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.RUNNING); + + // still running while at least one task is running and there are no runnable tasks + queue.block("G1", "T1", 1); + assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.RUNNING); + + // blocked when all tasks are blocked + queue.block("G1", "T2", 1); + assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.BLOCKED); + + // back to runnable after unblocking + queue.enqueue("G1", "T1", 1); + assertThat(queue.state("G1")).isEqualTo(SchedulingQueue.State.RUNNABLE); + } + + @Test + public void testNonGreedyDeque() + { + SchedulingQueue queue = new SchedulingQueue<>(); + + queue.startGroup("G1"); + queue.startGroup("G2"); + + queue.enqueue("G1", "T1.0", 0); + queue.enqueue("G2", "T2.0", 1); + + queue.enqueue("G1", "T1.1", 2); + queue.enqueue("G1", "T1.2", 3); + + queue.enqueue("G2", "T2.1", 2); + queue.enqueue("G2", "T2.2", 3); + + assertThat(queue.dequeue(2)).isEqualTo("T1.0"); + assertThat(queue.dequeue(2)).isEqualTo("T2.0"); + assertThat(queue.dequeue(2)).isEqualTo("T1.1"); + assertThat(queue.dequeue(2)).isEqualTo("T2.1"); + assertThat(queue.dequeue(2)).isEqualTo("T1.2"); + assertThat(queue.dequeue(2)).isEqualTo("T2.2"); + assertThat(queue.dequeue(2)).isNull(); + } + + @Test + public void testFinishTask() + { + SchedulingQueue queue = new SchedulingQueue<>(); + + queue.startGroup("G1"); + queue.enqueue("G1", "T1", 0); + queue.enqueue("G1", "T2", 1); + queue.enqueue("G1", "T3", 2); + + assertThat(queue.peek()).isEqualTo("T1"); + queue.finish("G1", "T1"); + assertThat(queue.peek()).isEqualTo("T2"); + + // check that the group becomes not-runnable + queue.finish("G1", "T2"); + queue.finish("G1", "T3"); + assertThat(queue.peek()).isNull(); + + // check that the group becomes runnable again + queue.enqueue("G1", "T4", 0); + assertThat(queue.peek()).isEqualTo("T4"); + } + + @Test + public void testFinishGroup() + { + SchedulingQueue queue = new SchedulingQueue<>(); + + queue.startGroup("G1"); + queue.enqueue("G1", "T1.1", 0); + assertThat(queue.peek()).isEqualTo("T1.1"); + + queue.startGroup("G2"); + queue.enqueue("G2", "T2.1", 1); + assertThat(queue.peek()).isEqualTo("T1.1"); + + queue.finishGroup("G1"); + assertThat(queue.peek()).isEqualTo("T2.1"); + } + + record Group(int id) + { + @Override + public String toString() + { + return "G" + id; + } + } + + record Task(Group group, int id) + { + @Override + public String toString() + { + return "T" + group + "." + id; + } + } +} diff --git a/pom.xml b/pom.xml index 4bbf755b3bed..961fbd5d31ec 100644 --- a/pom.xml +++ b/pom.xml @@ -2411,7 +2411,8 @@ -Xep:StreamResourceLeak:ERROR \ -Xep:UnnecessaryMethodReference:ERROR \ -Xep:UnnecessaryOptionalGet:ERROR \ - -Xep:UnusedVariable:ERROR \ + + -Xep:UseEnumSwitch:ERROR \ -XepExcludedPaths:.*/target/generated-(|test-)sources/.* diff --git a/testing/trino-server-dev/etc/config.properties b/testing/trino-server-dev/etc/config.properties index b786657e8d2b..6692959ccb50 100644 --- a/testing/trino-server-dev/etc/config.properties +++ b/testing/trino-server-dev/etc/config.properties @@ -60,3 +60,6 @@ plugin.bundles=\ ../../plugin/trino-mysql-event-listener/pom.xml node-scheduler.include-coordinator=true + +tracing.enabled=true +tracing.exporter.endpoint=http://localhost:4317