diff --git a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java index 0179a142..77b6ab23 100644 --- a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java +++ b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java @@ -5,7 +5,6 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; -import java.util.concurrent.Semaphore; import java.util.function.BiConsumer; import java.util.function.BinaryOperator; import java.util.function.Function; @@ -34,19 +33,23 @@ class ParallelStreamCollector implements Collector function; private final CompletionStrategy completionStrategy; private final Set characteristics; - private final Semaphore limiter; - private final Executor executor; + private final Dispatcher dispatcher; private ParallelStreamCollector( Function function, CompletionStrategy completionStrategy, Set characteristics, - Executor executor, int parallelism) { + Dispatcher dispatcher) { this.completionStrategy = completionStrategy; this.characteristics = characteristics; - this.limiter = new Semaphore(parallelism); + this.dispatcher = dispatcher; this.function = function; - this.executor = executor; + } + + private void startConsuming() { + if (!dispatcher.isRunning()) { + dispatcher.start(); + } } @Override @@ -57,19 +60,8 @@ public Supplier>> supplier() { @Override public BiConsumer>, T> accumulator() { return (acc, e) -> { - try { - limiter.acquire(); - acc.add(CompletableFuture.supplyAsync(() -> { - try { - return function.apply(e); - } finally { - limiter.release(); - } - }, executor)); - } catch (InterruptedException interruptedException) { - Thread.currentThread().interrupt(); - throw new RuntimeException(interruptedException); - } + startConsuming(); + acc.add(dispatcher.enqueue(() -> function.apply(e))); }; } @@ -82,7 +74,10 @@ public BinaryOperator>> combiner() { @Override public Function>, Stream> finisher() { - return acc -> completionStrategy.apply(acc.build()); + return acc -> { + dispatcher.stop(); + return completionStrategy.apply(acc.build()); + }; } @Override @@ -99,9 +94,7 @@ public Set characteristics() { requireNonNull(mapper, "mapper can't be null"); requireValidParallelism(parallelism); - return parallelism == 1 - ? BatchingCollectors.syncCollector(mapper) - : new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, executor, parallelism); + return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.of(executor, parallelism)); } static Collector> streamingOrdered(Function mapper, Executor executor) { @@ -113,9 +106,7 @@ public Set characteristics() { requireNonNull(mapper, "mapper can't be null"); requireValidParallelism(parallelism); - return parallelism == 1 - ? BatchingCollectors.syncCollector(mapper) - : new ParallelStreamCollector<>(mapper, ordered(), emptySet(), executor, parallelism); + return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.of(executor, parallelism)); } static final class BatchingCollectors { @@ -153,16 +144,14 @@ private BatchingCollectors() { mapper, ordered(), emptySet(), - executor, - parallelism)); + Dispatcher.of(executor, parallelism))); } else { return partitioned(list, parallelism) .collect(collectingAndThen(new ParallelStreamCollector<>( batching(mapper), ordered(), emptySet(), - executor, - parallelism), + Dispatcher.of(executor, parallelism)), s -> s.flatMap(Collection::stream))); } }); diff --git a/src/test/java/com/pivovarit/collectors/FunctionalTest.java b/src/test/java/com/pivovarit/collectors/FunctionalTest.java index fd7943cb..214b1ff0 100644 --- a/src/test/java/com/pivovarit/collectors/FunctionalTest.java +++ b/src/test/java/com/pivovarit/collectors/FunctionalTest.java @@ -1,6 +1,7 @@ package com.pivovarit.collectors; import com.pivovarit.collectors.ParallelCollectors.Batching; +import org.awaitility.Awaitility; import org.junit.jupiter.api.DynamicTest; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestFactory; @@ -24,6 +25,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.LongAdder; @@ -111,6 +113,26 @@ void shouldCollectInCompletionOrder() { assertThat(result).isSorted(); } + @Test + void shouldCollectEagerlyInCompletionOrder() { + // given + executor = threadPoolExecutor(4); + AtomicBoolean result = new AtomicBoolean(false); + CompletableFuture.runAsync(() -> { + Stream.of(1, 10000, 1, 0) + .collect(parallelToStream(i -> returnWithDelay(i, ofMillis(i)), executor, 2)) + .forEach(i -> { + if (i == 1) { + result.set(true); + } + }); + }); + + await() + .atMost(1, SECONDS) + .until(result::get); + } + @Test void shouldExecuteEagerlyOnProvidedThreadPool() { ExecutorService executor = Executors.newFixedThreadPool(2);