diff --git a/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java b/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java index 76af2d0a..d2a71b42 100644 --- a/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java +++ b/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java @@ -57,21 +57,12 @@ public BinaryOperator>> combiner() { @Override public BiConsumer>, T> accumulator() { - return (acc, e) -> { - if (!dispatcher.isRunning()) { - dispatcher.start(); - } - acc.add(dispatcher.enqueue(() -> mapper.apply(e))); - }; + return (acc, e) -> acc.add(dispatcher.enqueue(() -> mapper.apply(e))); } @Override public Function>, CompletableFuture> finisher() { - return futures -> { - dispatcher.stop(); - - return combine(futures).thenApply(processor); - }; + return futures -> combine(futures).thenApply(processor); } @Override @@ -105,7 +96,7 @@ private static CompletableFuture> combine(List i) - : new AsyncParallelCollector<>(mapper, Dispatcher.of(executor, parallelism), t -> t); + : new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), t -> t); } static Collector> collectingWithCollector(Collector collector, Function mapper, Executor executor) { @@ -120,7 +111,7 @@ private static CompletableFuture> combine(List s.collect(collector)) - : new AsyncParallelCollector<>(mapper, Dispatcher.of(executor, parallelism), s -> s.collect(collector)); + : new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), s -> s.collect(collector)); } static void requireValidParallelism(int parallelism) { @@ -176,13 +167,13 @@ private BatchingCollectors() { return list.stream() .collect(new AsyncParallelCollector<>( mapper, - Dispatcher.of(executor, parallelism), + Dispatcher.from(executor, parallelism), finisher)); } else { return partitioned(list, parallelism) .collect(new AsyncParallelCollector<>( batching(mapper), - Dispatcher.of(executor, parallelism), + Dispatcher.from(executor, parallelism), listStream -> finisher.apply(listStream.flatMap(Collection::stream)))); } }); diff --git a/src/main/java/com/pivovarit/collectors/Dispatcher.java b/src/main/java/com/pivovarit/collectors/Dispatcher.java index 13379817..5109bb28 100644 --- a/src/main/java/com/pivovarit/collectors/Dispatcher.java +++ b/src/main/java/com/pivovarit/collectors/Dispatcher.java @@ -1,18 +1,10 @@ package com.pivovarit.collectors; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; import java.util.concurrent.FutureTask; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; -import java.util.concurrent.SynchronousQueue; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; +import java.util.function.BiConsumer; import java.util.function.Supplier; import static java.lang.Runtime.getRuntime; @@ -22,104 +14,61 @@ */ final class Dispatcher { - private static final Runnable POISON_PILL = () -> System.out.println("Why so serious?"); - private final CompletableFuture completionSignaller = new CompletableFuture<>(); - - private final BlockingQueue workingQueue = new LinkedBlockingQueue<>(); - - private final ExecutorService dispatcher = newLazySingleThreadExecutor(); private final Executor executor; private final Semaphore limiter; - private final AtomicBoolean started = new AtomicBoolean(false); - - private volatile boolean shortCircuited = false; - private Dispatcher(Executor executor, int permits) { this.executor = executor; this.limiter = new Semaphore(permits); } - static Dispatcher of(Executor executor, int permits) { + static Dispatcher from(Executor executor, int permits) { return new Dispatcher<>(executor, permits); } - void start() { - if (!started.getAndSet(true)) { - dispatcher.execute(() -> { - try { - while (true) { - Runnable task; - if ((task = workingQueue.take()) != POISON_PILL) { - limiter.acquire(); - executor.execute(withFinally(task, limiter::release)); - } else { - break; - } - } - } catch (Throwable e) { - handle(e); - } - }); - } - } - - void stop() { + CompletableFuture enqueue(Supplier supplier) { + InterruptibleCompletableFuture future = new InterruptibleCompletableFuture<>(); + completionSignaller.whenComplete(shortcircuit(future)); try { - workingQueue.put(POISON_PILL); - } catch (InterruptedException e) { + executor.execute(completionTask(supplier, future)); + } catch (Throwable e) { completionSignaller.completeExceptionally(e); - } finally { - dispatcher.shutdown(); + CompletableFuture result = new CompletableFuture<>(); + result.completeExceptionally(e); + return result; } - } - - boolean isRunning() { - return started.get(); - } - - CompletableFuture enqueue(Supplier supplier) { - InterruptibleCompletableFuture future = new InterruptibleCompletableFuture<>(); - workingQueue.add(completionTask(supplier, future)); - completionSignaller.exceptionally(shortcircuit(future)); return future; } - private FutureTask completionTask(Supplier supplier, InterruptibleCompletableFuture future) { - FutureTask task = new FutureTask<>(() -> { - try { - if (!shortCircuited) { - future.complete(supplier.get()); + private FutureTask completionTask(Supplier supplier, InterruptibleCompletableFuture future) { + FutureTask task = new FutureTask<>(() -> { + if (!completionSignaller.isCompletedExceptionally()) { + try { + withLimiter(supplier, future); + } catch (Throwable e) { + completionSignaller.completeExceptionally(e); } - } catch (Throwable e) { - handle(e); } }, null); future.completedBy(task); return task; } - private void handle(Throwable e) { - shortCircuited = true; - completionSignaller.completeExceptionally(e); - dispatcher.shutdownNow(); - } - - private static Function shortcircuit(InterruptibleCompletableFuture future) { - return throwable -> { - future.completeExceptionally(throwable); - future.cancel(true); - return null; - }; + private void withLimiter(Supplier supplier, InterruptibleCompletableFuture future) throws InterruptedException { + try { + limiter.acquire(); + future.complete(supplier.get()); + } finally { + limiter.release(); + } } - private static Runnable withFinally(Runnable task, Runnable finisher) { - return () -> { - try { - task.run(); - } finally { - finisher.run(); + private static BiConsumer shortcircuit(InterruptibleCompletableFuture future) { + return (__, throwable) -> { + if (throwable != null) { + future.completeExceptionally(throwable); + future.cancel(true); } }; } @@ -128,29 +77,19 @@ static int getDefaultParallelism() { return Math.max(getRuntime().availableProcessors() - 1, 4); } - private static ThreadPoolExecutor newLazySingleThreadExecutor() { - return new ThreadPoolExecutor(0, 1, - 0L, TimeUnit.MILLISECONDS, - new SynchronousQueue<>(), - task -> { - Thread thread = Executors.defaultThreadFactory().newThread(task); - thread.setName("parallel-collector-" + thread.getName()); - thread.setDaemon(false); - return thread; - }); - } - static final class InterruptibleCompletableFuture extends CompletableFuture { - private volatile FutureTask backingTask; - private void completedBy(FutureTask task) { + private volatile FutureTask backingTask; + + private void completedBy(FutureTask task) { backingTask = task; } @Override public boolean cancel(boolean mayInterruptIfRunning) { - if (backingTask != null) { - backingTask.cancel(mayInterruptIfRunning); + FutureTask task = backingTask; + if (task != null) { + task.cancel(mayInterruptIfRunning); } return super.cancel(mayInterruptIfRunning); } diff --git a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java index d133f65f..8c1ac0c4 100644 --- a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java +++ b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java @@ -58,10 +58,7 @@ public Supplier>> supplier() { @Override public BiConsumer>, T> accumulator() { - return (acc, e) -> { - dispatcher.start(); - acc.add(dispatcher.enqueue(() -> function.apply(e))); - }; + return (acc, e) -> acc.add(dispatcher.enqueue(() -> function.apply(e))); } @Override @@ -74,10 +71,7 @@ public BinaryOperator>> combiner() { @Override public Function>, Stream> finisher() { - return acc -> { - dispatcher.stop(); - return completionStrategy.apply(acc); - }; + return completionStrategy; } @Override @@ -94,7 +88,7 @@ public Set characteristics() { requireNonNull(mapper, "mapper can't be null"); requireValidParallelism(parallelism); - return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.of(executor, parallelism)); + return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(executor, parallelism)); } static Collector> streamingOrdered(Function mapper, Executor executor) { @@ -107,7 +101,7 @@ public Set characteristics() { requireNonNull(mapper, "mapper can't be null"); requireValidParallelism(parallelism); - return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.of(executor, parallelism)); + return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(executor, parallelism)); } static final class BatchingCollectors { @@ -149,7 +143,7 @@ private BatchingCollectors() { mapper, ordered(), emptySet(), - Dispatcher.of(executor, parallelism))); + Dispatcher.from(executor, parallelism))); } else { return partitioned(list, parallelism) @@ -157,7 +151,7 @@ private BatchingCollectors() { batching(mapper), ordered(), emptySet(), - Dispatcher.of(executor, parallelism)), + Dispatcher.from(executor, parallelism)), s -> s.flatMap(Collection::stream))); } }); diff --git a/src/test/java/com/pivovarit/collectors/FunctionalTest.java b/src/test/java/com/pivovarit/collectors/FunctionalTest.java index 17f27c70..9f1fb88f 100644 --- a/src/test/java/com/pivovarit/collectors/FunctionalTest.java +++ b/src/test/java/com/pivovarit/collectors/FunctionalTest.java @@ -162,7 +162,6 @@ private static > Stream tests(Collect shouldCollectNElementsWithNParallelism(collector, name, PARALLELISM), shouldCollectToEmpty(collector, name), shouldStartConsumingImmediately(collector, name), - shouldTerminateAfterConsumingAllElements(collector, name), shouldNotBlockTheCallingThread(collector, name), shouldRespectParallelism(collector, name), shouldHandleThrowable(collector, name), @@ -184,7 +183,6 @@ private static > Stream streamingTest shouldCollect(collector, name, PARALLELISM), shouldCollectToEmpty(collector, name), shouldStartConsumingImmediately(collector, name), - shouldTerminateAfterConsumingAllElements(collector, name), shouldNotBlockTheCallingThread(collector, name), shouldRespectParallelism(collector, name), shouldHandleThrowable(collector, name), @@ -286,30 +284,6 @@ private static > DynamicTest shouldCollectNElement }); } - private static > DynamicTest shouldTerminateAfterConsumingAllElements(CollectorSupplier, Executor, Integer, Collector>> factory, String name) { - return dynamicTest(format("%s: should terminate after consuming all elements", name), () -> { - List elements = IntStream.range(0, 10).boxed().collect(toList()); - Collector> ctor = factory.apply(i -> i, executor, 10); - Collection result = elements.stream().collect(ctor) - .join(); - - assertThat(result).hasSameElementsAs(elements); - - if (ctor instanceof AsyncParallelCollector) { - Field dispatcherField = AsyncParallelCollector.class.getDeclaredField("dispatcher"); - dispatcherField.setAccessible(true); - Dispatcher dispatcher = (Dispatcher) dispatcherField.get(ctor); - Field innerDispatcherField = Dispatcher.class.getDeclaredField("dispatcher"); - innerDispatcherField.setAccessible(true); - ExecutorService executor = (ExecutorService) innerDispatcherField.get(dispatcher); - - await() - .atMost(Duration.ofSeconds(2)) - .until(executor::isTerminated); - } - }); - } - private static > DynamicTest shouldMaintainOrder(CollectorSupplier, Executor, Integer, Collector>> collector, String name) { return dynamicTest(format("%s: should maintain order", name), () -> { int parallelism = 4;