From e03876ac7eecdefec52be106491a999af2f17c3f Mon Sep 17 00:00:00 2001 From: Grzegorz Piwowarek Date: Wed, 1 May 2024 09:53:19 +0200 Subject: [PATCH] Control max parallelism on a dedicated thread --- .../collectors/AsyncParallelCollector.java | 13 ++- .../com/pivovarit/collectors/Dispatcher.java | 109 ++++++++++++++---- .../collectors/ParallelStreamCollector.java | 10 +- .../collectors/ExecutorPollutionTest.java | 63 ++++++++++ 4 files changed, 166 insertions(+), 29 deletions(-) create mode 100644 src/test/java/com/pivovarit/collectors/ExecutorPollutionTest.java diff --git a/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java b/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java index d2a71b42..d9a2a012 100644 --- a/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java +++ b/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java @@ -57,12 +57,21 @@ public BinaryOperator>> combiner() { @Override public BiConsumer>, T> accumulator() { - return (acc, e) -> acc.add(dispatcher.enqueue(() -> mapper.apply(e))); + return (acc, e) -> { + if (!dispatcher.isRunning()) { + dispatcher.start(); + } + acc.add(dispatcher.enqueue(() -> mapper.apply(e))); + }; } @Override public Function>, CompletableFuture> finisher() { - return futures -> combine(futures).thenApply(processor); + return futures -> { + dispatcher.stop(); + + return combine(futures).thenApply(processor); + }; } @Override diff --git a/src/main/java/com/pivovarit/collectors/Dispatcher.java b/src/main/java/com/pivovarit/collectors/Dispatcher.java index 5109bb28..8c08e6c8 100644 --- a/src/main/java/com/pivovarit/collectors/Dispatcher.java +++ b/src/main/java/com/pivovarit/collectors/Dispatcher.java @@ -1,10 +1,15 @@ package com.pivovarit.collectors; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.FutureTask; +import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; -import java.util.function.BiConsumer; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; import java.util.function.Supplier; import static java.lang.Runtime.getRuntime; @@ -14,10 +19,18 @@ */ final class Dispatcher { + private static final Runnable POISON_PILL = () -> System.out.println("Why so serious?"); + private final CompletableFuture completionSignaller = new CompletableFuture<>(); + private final BlockingQueue workingQueue = new LinkedBlockingQueue<>(); + private final ExecutorService dispatcher = Executors.newSingleThreadExecutor(); private final Executor executor; private final Semaphore limiter; + private final AtomicBoolean started = new AtomicBoolean(false); + + private volatile boolean shortCircuited = false; + private Dispatcher(Executor executor, int permits) { this.executor = executor; this.limiter = new Semaphore(permits); @@ -27,34 +40,81 @@ static Dispatcher from(Executor executor, int permits) { return new Dispatcher<>(executor, permits); } - CompletableFuture enqueue(Supplier supplier) { - InterruptibleCompletableFuture future = new InterruptibleCompletableFuture<>(); - completionSignaller.whenComplete(shortcircuit(future)); + void start() { + if (!started.getAndSet(true)) { + dispatcher.execute(() -> { + try { + while (true) { + try { + if (limiter != null) { + limiter.acquire(); + } + } catch (InterruptedException e) { + handle(e); + } + Runnable task; + if ((task = workingQueue.take()) != POISON_PILL) { + executor.execute(() -> { + try { + task.run(); + } finally { + if (limiter != null) { + limiter.release(); + } + } + }); + } else { + break; + } + } + } catch (Throwable e) { + handle(e); + } + }); + } + } + + void stop() { try { - executor.execute(completionTask(supplier, future)); - } catch (Throwable e) { + workingQueue.put(POISON_PILL); + } catch (InterruptedException e) { completionSignaller.completeExceptionally(e); - CompletableFuture result = new CompletableFuture<>(); - result.completeExceptionally(e); - return result; + } finally { + dispatcher.shutdown(); } + } + + boolean isRunning() { + return started.get(); + } + + CompletableFuture enqueue(Supplier supplier) { + InterruptibleCompletableFuture future = new InterruptibleCompletableFuture<>(); + workingQueue.add(completionTask(supplier, future)); + completionSignaller.exceptionally(shortcircuit(future)); return future; } - private FutureTask completionTask(Supplier supplier, InterruptibleCompletableFuture future) { - FutureTask task = new FutureTask<>(() -> { - if (!completionSignaller.isCompletedExceptionally()) { - try { - withLimiter(supplier, future); - } catch (Throwable e) { - completionSignaller.completeExceptionally(e); + private FutureTask completionTask(Supplier supplier, InterruptibleCompletableFuture future) { + FutureTask task = new FutureTask<>(() -> { + try { + if (!shortCircuited) { + future.complete(supplier.get()); } + } catch (Throwable e) { + handle(e); } }, null); future.completedBy(task); return task; } + private void handle(Throwable e) { + shortCircuited = true; + completionSignaller.completeExceptionally(e); + dispatcher.shutdownNow(); + } + private void withLimiter(Supplier supplier, InterruptibleCompletableFuture future) throws InterruptedException { try { limiter.acquire(); @@ -64,12 +124,11 @@ private void withLimiter(Supplier supplier, InterruptibleCompletableFuture } } - private static BiConsumer shortcircuit(InterruptibleCompletableFuture future) { - return (__, throwable) -> { - if (throwable != null) { - future.completeExceptionally(throwable); - future.cancel(true); - } + private static Function shortcircuit(InterruptibleCompletableFuture future) { + return throwable -> { + future.completeExceptionally(throwable); + future.cancel(true); + return null; }; } @@ -79,15 +138,15 @@ static int getDefaultParallelism() { static final class InterruptibleCompletableFuture extends CompletableFuture { - private volatile FutureTask backingTask; + private volatile FutureTask backingTask; - private void completedBy(FutureTask task) { + private void completedBy(FutureTask task) { backingTask = task; } @Override public boolean cancel(boolean mayInterruptIfRunning) { - FutureTask task = backingTask; + FutureTask task = backingTask; if (task != null) { task.cancel(mayInterruptIfRunning); } diff --git a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java index 8c1ac0c4..a5fa3c0e 100644 --- a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java +++ b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java @@ -58,7 +58,10 @@ public Supplier>> supplier() { @Override public BiConsumer>, T> accumulator() { - return (acc, e) -> acc.add(dispatcher.enqueue(() -> function.apply(e))); + return (acc, e) -> { + dispatcher.start(); + acc.add(dispatcher.enqueue(() -> function.apply(e))); + }; } @Override @@ -71,7 +74,10 @@ public BinaryOperator>> combiner() { @Override public Function>, Stream> finisher() { - return completionStrategy; + return acc -> { + dispatcher.stop(); + return completionStrategy.apply(acc); + }; } @Override diff --git a/src/test/java/com/pivovarit/collectors/ExecutorPollutionTest.java b/src/test/java/com/pivovarit/collectors/ExecutorPollutionTest.java new file mode 100644 index 00000000..9b0f91fc --- /dev/null +++ b/src/test/java/com/pivovarit/collectors/ExecutorPollutionTest.java @@ -0,0 +1,63 @@ +package com.pivovarit.collectors; + +import org.junit.jupiter.api.DynamicTest; +import org.junit.jupiter.api.TestFactory; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.stream.Collector; +import java.util.stream.Stream; + +import static java.util.stream.Collectors.toList; +import static java.util.stream.Stream.of; + +class ExecutorPollutionTest { + + @TestFactory + Stream shouldStartProcessingElementsTests() { + return of( + shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors::parallel, "parallel#1"), + shouldNotSubmitMoreTasksThanParallelism((f, e, p) -> ParallelCollectors.parallel(f, toList(), e, p), "parallel#2"), + shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors::parallelToStream, "parallelToStream"), + shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors::parallelToOrderedStream, "parallelToOrderedStream"), + shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors.Batching::parallel, "parallel#1 (batching)"), + shouldNotSubmitMoreTasksThanParallelism((f, e, p) -> ParallelCollectors.Batching.parallel(f, toList(), e, p), "parallel#2 (batching)"), + shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors.Batching::parallelToStream, "parallelToStream (batching)"), + shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors.Batching::parallelToOrderedStream, "parallelToOrderedStream (batching)") + ); + } + + private static DynamicTest shouldNotSubmitMoreTasksThanParallelism(CollectorFactory collector, String name) { + return DynamicTest.dynamicTest(name, () -> { + ExecutorService e = warmedUp(new ThreadPoolExecutor(2, 2, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(2))); + + Object result = Stream.generate(() -> 42) + .limit(1000) + .collect(collector.apply(i -> i, e, 2)); + + if (result instanceof CompletableFuture) { + ((CompletableFuture) result).join(); + } else if (result instanceof Stream) { + ((Stream) result).forEach((__) -> {}); + } else { + throw new IllegalStateException("can't happen"); + } + }); + } + + interface CollectorFactory { + Collector apply(Function function, Executor executorService, int parallelism); + } + + private static ThreadPoolExecutor warmedUp(ThreadPoolExecutor e) { + for (int i = 0; i < e.getCorePoolSize(); i++) { + e.submit(() -> {}); + } + return e; + } +}