Skip to content

Commit

Permalink
Dispatcher to use caller thread instead of dedicated scheduler thread (
Browse files Browse the repository at this point in the history
…#822)

Remove the internal single-thread scheduler and rely on the caller
thread to submit all relevant tasks to a given thread pool. This not
only simplified the solution, but also:
- helped avoid context propagation issues when execution switches
between multiple threads
- made the tool more Loom-friendly since instances of
`ParallelCollectors` do not create their own threads

----
backport:
c48c915
  • Loading branch information
pivovarit authored Jan 27, 2024
1 parent bd169ad commit ca27d4b
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 149 deletions.
21 changes: 6 additions & 15 deletions src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,12 @@ public BinaryOperator<List<CompletableFuture<R>>> combiner() {

@Override
public BiConsumer<List<CompletableFuture<R>>, T> accumulator() {
return (acc, e) -> {
if (!dispatcher.isRunning()) {
dispatcher.start();
}
acc.add(dispatcher.enqueue(() -> mapper.apply(e)));
};
return (acc, e) -> acc.add(dispatcher.enqueue(() -> mapper.apply(e)));
}

@Override
public Function<List<CompletableFuture<R>>, CompletableFuture<C>> finisher() {
return futures -> {
dispatcher.stop();

return combine(futures).thenApply(processor);
};
return futures -> combine(futures).thenApply(processor);
}

@Override
Expand Down Expand Up @@ -105,7 +96,7 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T

return parallelism == 1
? asyncCollector(mapper, executor, i -> i)
: new AsyncParallelCollector<>(mapper, Dispatcher.of(executor, parallelism), t -> t);
: new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), t -> t);
}

static <T, R, RR> Collector<T, ?, CompletableFuture<RR>> collectingWithCollector(Collector<R, ?, RR> collector, Function<T, R> mapper, Executor executor) {
Expand All @@ -120,7 +111,7 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T

return parallelism == 1
? asyncCollector(mapper, executor, s -> s.collect(collector))
: new AsyncParallelCollector<>(mapper, Dispatcher.of(executor, parallelism), s -> s.collect(collector));
: new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), s -> s.collect(collector));
}

static void requireValidParallelism(int parallelism) {
Expand Down Expand Up @@ -176,13 +167,13 @@ private BatchingCollectors() {
return list.stream()
.collect(new AsyncParallelCollector<>(
mapper,
Dispatcher.of(executor, parallelism),
Dispatcher.from(executor, parallelism),
finisher));
} else {
return partitioned(list, parallelism)
.collect(new AsyncParallelCollector<>(
batching(mapper),
Dispatcher.of(executor, parallelism),
Dispatcher.from(executor, parallelism),
listStream -> finisher.apply(listStream.flatMap(Collection::stream))));
}
});
Expand Down
131 changes: 35 additions & 96 deletions src/main/java/com/pivovarit/collectors/Dispatcher.java
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
package com.pivovarit.collectors;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import static java.lang.Runtime.getRuntime;
Expand All @@ -22,104 +14,61 @@
*/
final class Dispatcher<T> {

private static final Runnable POISON_PILL = () -> System.out.println("Why so serious?");

private final CompletableFuture<Void> completionSignaller = new CompletableFuture<>();

private final BlockingQueue<Runnable> workingQueue = new LinkedBlockingQueue<>();

private final ExecutorService dispatcher = newLazySingleThreadExecutor();
private final Executor executor;
private final Semaphore limiter;

private final AtomicBoolean started = new AtomicBoolean(false);

private volatile boolean shortCircuited = false;

private Dispatcher(Executor executor, int permits) {
this.executor = executor;
this.limiter = new Semaphore(permits);
}

static <T> Dispatcher<T> of(Executor executor, int permits) {
static <T> Dispatcher<T> from(Executor executor, int permits) {
return new Dispatcher<>(executor, permits);
}

void start() {
if (!started.getAndSet(true)) {
dispatcher.execute(() -> {
try {
while (true) {
Runnable task;
if ((task = workingQueue.take()) != POISON_PILL) {
limiter.acquire();
executor.execute(withFinally(task, limiter::release));
} else {
break;
}
}
} catch (Throwable e) {
handle(e);
}
});
}
}

void stop() {
CompletableFuture<T> enqueue(Supplier<T> supplier) {
InterruptibleCompletableFuture<T> future = new InterruptibleCompletableFuture<>();
completionSignaller.whenComplete(shortcircuit(future));
try {
workingQueue.put(POISON_PILL);
} catch (InterruptedException e) {
executor.execute(completionTask(supplier, future));
} catch (Throwable e) {
completionSignaller.completeExceptionally(e);
} finally {
dispatcher.shutdown();
CompletableFuture<T> result = new CompletableFuture<>();
result.completeExceptionally(e);
return result;
}
}

boolean isRunning() {
return started.get();
}

CompletableFuture<T> enqueue(Supplier<T> supplier) {
InterruptibleCompletableFuture<T> future = new InterruptibleCompletableFuture<>();
workingQueue.add(completionTask(supplier, future));
completionSignaller.exceptionally(shortcircuit(future));
return future;
}

private FutureTask<Void> completionTask(Supplier<T> supplier, InterruptibleCompletableFuture<T> future) {
FutureTask<Void> task = new FutureTask<>(() -> {
try {
if (!shortCircuited) {
future.complete(supplier.get());
private FutureTask<T> completionTask(Supplier<T> supplier, InterruptibleCompletableFuture<T> future) {
FutureTask<T> task = new FutureTask<>(() -> {
if (!completionSignaller.isCompletedExceptionally()) {
try {
withLimiter(supplier, future);
} catch (Throwable e) {
completionSignaller.completeExceptionally(e);
}
} catch (Throwable e) {
handle(e);
}
}, null);
future.completedBy(task);
return task;
}

private void handle(Throwable e) {
shortCircuited = true;
completionSignaller.completeExceptionally(e);
dispatcher.shutdownNow();
}

private static Function<Throwable, Void> shortcircuit(InterruptibleCompletableFuture<?> future) {
return throwable -> {
future.completeExceptionally(throwable);
future.cancel(true);
return null;
};
private void withLimiter(Supplier<T> supplier, InterruptibleCompletableFuture<T> future) throws InterruptedException {
try {
limiter.acquire();
future.complete(supplier.get());
} finally {
limiter.release();
}
}

private static Runnable withFinally(Runnable task, Runnable finisher) {
return () -> {
try {
task.run();
} finally {
finisher.run();
private static <T> BiConsumer<T, Throwable> shortcircuit(InterruptibleCompletableFuture<?> future) {
return (__, throwable) -> {
if (throwable != null) {
future.completeExceptionally(throwable);
future.cancel(true);
}
};
}
Expand All @@ -128,29 +77,19 @@ static int getDefaultParallelism() {
return Math.max(getRuntime().availableProcessors() - 1, 4);
}

private static ThreadPoolExecutor newLazySingleThreadExecutor() {
return new ThreadPoolExecutor(0, 1,
0L, TimeUnit.MILLISECONDS,
new SynchronousQueue<>(),
task -> {
Thread thread = Executors.defaultThreadFactory().newThread(task);
thread.setName("parallel-collector-" + thread.getName());
thread.setDaemon(false);
return thread;
});
}

static final class InterruptibleCompletableFuture<T> extends CompletableFuture<T> {
private volatile FutureTask<?> backingTask;

private void completedBy(FutureTask<Void> task) {
private volatile FutureTask<T> backingTask;

private void completedBy(FutureTask<T> task) {
backingTask = task;
}

@Override
public boolean cancel(boolean mayInterruptIfRunning) {
if (backingTask != null) {
backingTask.cancel(mayInterruptIfRunning);
FutureTask<T> task = backingTask;
if (task != null) {
task.cancel(mayInterruptIfRunning);
}
return super.cancel(mayInterruptIfRunning);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,7 @@ public Supplier<List<CompletableFuture<R>>> supplier() {

@Override
public BiConsumer<List<CompletableFuture<R>>, T> accumulator() {
return (acc, e) -> {
dispatcher.start();
acc.add(dispatcher.enqueue(() -> function.apply(e)));
};
return (acc, e) -> acc.add(dispatcher.enqueue(() -> function.apply(e)));
}

@Override
Expand All @@ -74,10 +71,7 @@ public BinaryOperator<List<CompletableFuture<R>>> combiner() {

@Override
public Function<List<CompletableFuture<R>>, Stream<R>> finisher() {
return acc -> {
dispatcher.stop();
return completionStrategy.apply(acc);
};
return completionStrategy;
}

@Override
Expand All @@ -94,7 +88,7 @@ public Set<Characteristics> characteristics() {
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);

return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.of(executor, parallelism));
return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(executor, parallelism));
}

static <T, R> Collector<T, ?, Stream<R>> streamingOrdered(Function<T, R> mapper, Executor executor) {
Expand All @@ -107,7 +101,7 @@ public Set<Characteristics> characteristics() {
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);

return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.of(executor, parallelism));
return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(executor, parallelism));
}

static final class BatchingCollectors {
Expand Down Expand Up @@ -149,15 +143,15 @@ private BatchingCollectors() {
mapper,
ordered(),
emptySet(),
Dispatcher.of(executor, parallelism)));
Dispatcher.from(executor, parallelism)));
}
else {
return partitioned(list, parallelism)
.collect(collectingAndThen(new ParallelStreamCollector<>(
batching(mapper),
ordered(),
emptySet(),
Dispatcher.of(executor, parallelism)),
Dispatcher.from(executor, parallelism)),
s -> s.flatMap(Collection::stream)));
}
});
Expand Down
26 changes: 0 additions & 26 deletions src/test/java/com/pivovarit/collectors/FunctionalTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ private static <R extends Collection<Integer>> Stream<DynamicTest> tests(Collect
shouldCollectNElementsWithNParallelism(collector, name, PARALLELISM),
shouldCollectToEmpty(collector, name),
shouldStartConsumingImmediately(collector, name),
shouldTerminateAfterConsumingAllElements(collector, name),
shouldNotBlockTheCallingThread(collector, name),
shouldRespectParallelism(collector, name),
shouldHandleThrowable(collector, name),
Expand All @@ -184,7 +183,6 @@ private static <R extends Collection<Integer>> Stream<DynamicTest> streamingTest
shouldCollect(collector, name, PARALLELISM),
shouldCollectToEmpty(collector, name),
shouldStartConsumingImmediately(collector, name),
shouldTerminateAfterConsumingAllElements(collector, name),
shouldNotBlockTheCallingThread(collector, name),
shouldRespectParallelism(collector, name),
shouldHandleThrowable(collector, name),
Expand Down Expand Up @@ -286,30 +284,6 @@ private static <R extends Collection<Integer>> DynamicTest shouldCollectNElement
});
}

private static <R extends Collection<Integer>> DynamicTest shouldTerminateAfterConsumingAllElements(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> factory, String name) {
return dynamicTest(format("%s: should terminate after consuming all elements", name), () -> {
List<Integer> elements = IntStream.range(0, 10).boxed().collect(toList());
Collector<Integer, ?, CompletableFuture<R>> ctor = factory.apply(i -> i, executor, 10);
Collection<Integer> result = elements.stream().collect(ctor)
.join();

assertThat(result).hasSameElementsAs(elements);

if (ctor instanceof AsyncParallelCollector) {
Field dispatcherField = AsyncParallelCollector.class.getDeclaredField("dispatcher");
dispatcherField.setAccessible(true);
Dispatcher<?> dispatcher = (Dispatcher<?>) dispatcherField.get(ctor);
Field innerDispatcherField = Dispatcher.class.getDeclaredField("dispatcher");
innerDispatcherField.setAccessible(true);
ExecutorService executor = (ExecutorService) innerDispatcherField.get(dispatcher);

await()
.atMost(Duration.ofSeconds(2))
.until(executor::isTerminated);
}
});
}

private static <R extends Collection<Integer>> DynamicTest shouldMaintainOrder(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name) {
return dynamicTest(format("%s: should maintain order", name), () -> {
int parallelism = 4;
Expand Down

0 comments on commit ca27d4b

Please sign in to comment.