From 1835c37718cf2a468688e3e85075d11a7492c13a Mon Sep 17 00:00:00 2001 From: Grzegorz Piwowarek Date: Sat, 26 Sep 2020 18:59:16 +0200 Subject: [PATCH] Use a calling thread when using blocking streaming instead of creating a new one (#515) --- pom.xml | 2 +- .../collectors/ParallelStreamCollector.java | 43 +++++++------- .../collectors/blackbox/FunctionalTest.java | 56 ++++++++++++++++--- 3 files changed, 72 insertions(+), 29 deletions(-) diff --git a/pom.xml b/pom.xml index 7dd01f4c..07d756dd 100644 --- a/pom.xml +++ b/pom.xml @@ -19,7 +19,7 @@ com.pivovarit parallel-collectors - 2.3.4-SNAPSHOT + 2.4.0-SNAPSHOT jar diff --git a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java index 33d697dc..52f8e051 100644 --- a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java +++ b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java @@ -6,6 +6,7 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; +import java.util.concurrent.Semaphore; import java.util.function.BiConsumer; import java.util.function.BinaryOperator; import java.util.function.Function; @@ -31,26 +32,22 @@ class ParallelStreamCollector implements Collector UNORDERED = EnumSet.of(Characteristics.UNORDERED); - private final Dispatcher dispatcher; private final Function function; private final CompletionStrategy completionStrategy; private final Set characteristics; + private final Semaphore limiter; + private final Executor executor; private ParallelStreamCollector( Function function, CompletionStrategy completionStrategy, Set characteristics, - Dispatcher dispatcher) { + Executor executor, int parallelism) { this.completionStrategy = completionStrategy; this.characteristics = characteristics; - this.dispatcher = dispatcher; + this.limiter = new Semaphore(parallelism); this.function = function; - } - - private void startConsuming() { - if (!dispatcher.isRunning()) { - dispatcher.start(); - } + this.executor = executor; } @Override @@ -61,8 +58,19 @@ public Supplier>> supplier() { @Override public BiConsumer>, T> accumulator() { return (acc, e) -> { - startConsuming(); - acc.add(dispatcher.enqueue(() -> function.apply(e))); + try { + limiter.acquire(); + acc.add(CompletableFuture.supplyAsync(() -> { + try { + return function.apply(e); + } finally { + limiter.release(); + } + }, executor)); + } catch (InterruptedException interruptedException) { + Thread.currentThread().interrupt(); + throw new RuntimeException(interruptedException); + } }; } @@ -75,10 +83,7 @@ public BinaryOperator>> combiner() { @Override public Function>, Stream> finisher() { - return acc -> { - dispatcher.stop(); - return completionStrategy.apply(acc.build()); - }; + return acc -> completionStrategy.apply(acc.build()); } @Override @@ -97,7 +102,7 @@ public Set characteristics() { return parallelism == 1 ? BatchingCollectors.syncCollector(mapper) - : new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.of(executor, parallelism)); + : new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, executor, parallelism); } static Collector> streamingOrdered(Function mapper, Executor executor) { @@ -111,7 +116,7 @@ public Set characteristics() { return parallelism == 1 ? BatchingCollectors.syncCollector(mapper) - : new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.of(executor, parallelism)); + : new ParallelStreamCollector<>(mapper, ordered(), emptySet(), executor, parallelism); } static final class BatchingCollectors { @@ -125,7 +130,7 @@ private BatchingCollectors() { return parallelism == 1 ? syncCollector(mapper) - : batched(new ParallelStreamCollector<>(batching(mapper), unordered(), UNORDERED, Dispatcher.of(executor, parallelism)), parallelism); + : batched(new ParallelStreamCollector<>(batching(mapper), unordered(), UNORDERED, executor, parallelism), parallelism); } static Collector> streamingOrdered(Function mapper, Executor executor, int parallelism) { @@ -135,7 +140,7 @@ private BatchingCollectors() { return parallelism == 1 ? syncCollector(mapper) - : batched(new ParallelStreamCollector<>(batching(mapper), ordered(), emptySet(), Dispatcher.of(executor, parallelism)), parallelism); + : batched(new ParallelStreamCollector<>(batching(mapper), ordered(), emptySet(), executor, parallelism), parallelism); } private static Collector> batched(ParallelStreamCollector, List> downstream, int parallelism) { diff --git a/src/test/java/com/pivovarit/collectors/blackbox/FunctionalTest.java b/src/test/java/com/pivovarit/collectors/blackbox/FunctionalTest.java index 53f14fb8..96197d43 100644 --- a/src/test/java/com/pivovarit/collectors/blackbox/FunctionalTest.java +++ b/src/test/java/com/pivovarit/collectors/blackbox/FunctionalTest.java @@ -68,9 +68,7 @@ Stream collectors() { tests((m, e, p) -> parallel(m, toList(), e, p), format("ParallelCollectors.parallel(toList(), p=%d)", PARALLELISM), true), tests((m, e, p) -> parallel(m, toSet(), e, p), format("ParallelCollectors.parallel(toSet(), p=%d)", PARALLELISM), false), tests((m, e, p) -> parallel(m, toCollection(LinkedList::new), e, p), format("ParallelCollectors.parallel(toCollection(), p=%d)", PARALLELISM), true), - tests((m, e, p) -> adapt(parallel(m, e, p)), format("ParallelCollectors.parallel(p=%d)", PARALLELISM), true), - tests((m, e, p) -> adaptAsync(parallelToStream(m, e, p)), format("ParallelCollectors.parallelToStream(p=%d)", PARALLELISM), false), - tests((m, e, p) -> adaptAsync(parallelToOrderedStream(m, e, p)), format("ParallelCollectors.parallelToOrderedStream(p=%d)", PARALLELISM), true) + tests((m, e, p) -> adapt(parallel(m, e, p)), format("ParallelCollectors.parallel(p=%d)", PARALLELISM), true) ).flatMap(identity()); } @@ -80,9 +78,23 @@ Stream batching_collectors() { batchTests((m, e, p) -> Batching.parallel(m, toList(), e, p), format("ParallelCollectors.Batching.parallel(toList(), p=%d)", PARALLELISM), true), batchTests((m, e, p) -> Batching.parallel(m, toSet(), e, p), format("ParallelCollectors.Batching.parallel(toSet(), p=%d)", PARALLELISM), false), batchTests((m, e, p) -> Batching.parallel(m, toCollection(LinkedList::new), e, p), format("ParallelCollectors.Batching.parallel(toCollection(), p=%d)", PARALLELISM), true), - batchTests((m, e, p) -> adapt(Batching.parallel(m, e, p)), format("ParallelCollectors.Batching.parallel(p=%d)", PARALLELISM), true), - batchTests((m, e, p) -> adaptAsync(Batching.parallelToStream(m, e, p)), format("ParallelCollectors.Batching.parallelToStream(p=%d)", PARALLELISM), false), - batchTests((m, e, p) -> adaptAsync(Batching.parallelToOrderedStream(m, e, p)), format("ParallelCollectors.Batching.parallelToOrderedStream(p=%d)", PARALLELISM), true) + batchTests((m, e, p) -> adapt(Batching.parallel(m, e, p)), format("ParallelCollectors.Batching.parallel(p=%d)", PARALLELISM), true) + ).flatMap(identity()); + } + + @TestFactory + Stream streaming_collectors() { + return of( + streamingTests((m, e, p) -> adaptAsync(parallelToStream(m, e, p)), format("ParallelCollectors.parallelToStream(p=%d)", PARALLELISM), false), + streamingTests((m, e, p) -> adaptAsync(parallelToOrderedStream(m, e, p)), format("ParallelCollectors.parallelToOrderedStream(p=%d)", PARALLELISM), true) + ).flatMap(identity()); + } + + @TestFactory + Stream streaming_batching_collectors() { + return of( + batchStreamingTests((m, e, p) -> adaptAsync(Batching.parallelToStream(m, e, p)), format("ParallelCollectors.Batching.parallelToStream(p=%d)", PARALLELISM), false), + batchStreamingTests((m, e, p) -> adaptAsync(Batching.parallelToOrderedStream(m, e, p)), format("ParallelCollectors.Batching.parallelToOrderedStream(p=%d)", PARALLELISM), true) ).flatMap(identity()); } @@ -139,12 +151,34 @@ private static > Stream tests(Collect ); } + private static > Stream streamingTests(CollectorSupplier, Executor, Integer, Collector>> collector, String name, boolean maintainsOrder) { + return of( + shouldCollect(collector, name, 1), + shouldCollect(collector, name, PARALLELISM), + shouldCollectToEmpty(collector, name), + shouldStartConsumingImmediately(collector, name), + shouldNotBlockTheCallingThread(collector, name), + shouldMaintainOrder(collector, name, maintainsOrder), + shouldRespectParallelism(collector, name), + shouldHandleThrowable(collector, name), + shouldShortCircuitOnException(collector, name), + shouldHandleRejectedExecutionException(collector, name), + shouldRemainConsistent(collector, name) + ); + } + private static > Stream batchTests(CollectorSupplier, Executor, Integer, Collector>> collector, String name, boolean maintainsOrder) { return Stream.concat( tests(collector, name, maintainsOrder), of(shouldProcessOnNThreadsETParallelism(collector, name))); } + private static > Stream batchStreamingTests(CollectorSupplier, Executor, Integer, Collector>> collector, String name, boolean maintainsOrder) { + return Stream.concat( + streamingTests(collector, name, maintainsOrder), + of(shouldProcessOnNThreadsETParallelism(collector, name))); + } + private static > DynamicTest shouldNotBlockTheCallingThread(CollectorSupplier, Executor, Integer, Collector>> c, String name) { return dynamicTest(format("%s: should not block when returning future", name), () -> { assertTimeoutPreemptively(ofMillis(100), () -> @@ -258,15 +292,19 @@ private static > DynamicTest shouldHandleThrowable } private static > DynamicTest shouldHandleRejectedExecutionException(CollectorSupplier, Executor, Integer, Collector>> collector, String name) { - return dynamicTest(format("%s: should survive rejected execution exception", name), () -> { + return dynamicTest(format("%s: should propagate rejected execution exception", name), () -> { Executor executor = command -> { throw new RejectedExecutionException(); }; List elements = IntStream.range(0, 1000).boxed().collect(toList()); assertThatThrownBy(() -> elements.stream() .collect(collector.apply(i -> returnWithDelay(i, ofMillis(10000)), executor, PARALLELISM)) .join()) - .isInstanceOf(CompletionException.class) - .hasCauseExactlyInstanceOf(RejectedExecutionException.class); + .isInstanceOfAny(RejectedExecutionException.class, CompletionException.class) + .matches(ex -> { + if (ex instanceof CompletionException) { + return ex.getCause() instanceof RejectedExecutionException; + } else return true; + }); }); }