Use a calling thread when using blocking streaming
pivovarit committed Sep 26, 2020
1 parent c365e96 commit 2af520b
Showing 2 changed files with 71 additions and 28 deletions.
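
This commit changes how the blocking streaming collectors hand work to the executor: instead of enqueuing elements on an internal Dispatcher thread, each element is submitted from the calling thread, and a Semaphore caps the number of in-flight tasks at the requested parallelism, so the calling thread blocks once that limit is reached. A rough usage sketch of the resulting behaviour, assuming parallel-collectors 2.x on the classpath; the executor size, element count, and fetchUser stand-in are illustrative only and not taken from the repository:

import com.pivovarit.collectors.ParallelCollectors;

import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

class StreamingBackpressureExample {

    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(10);
        try {
            List<String> users = IntStream.range(0, 1_000).boxed()
                // parallelism = 4: once 4 tasks are in flight, the calling thread
                // blocks on the Semaphore inside the collector, so elements are
                // pulled from the source no faster than they can be processed
                .collect(ParallelCollectors.parallelToStream(
                    StreamingBackpressureExample::fetchUser, executor, 4))
                .collect(Collectors.toList());
            System.out.println(users.size());
        } finally {
            executor.shutdown();
        }
    }

    // illustrative stand-in for a blocking call such as an HTTP request
    private static String fetchUser(int id) {
        return "user-" + id;
    }
}
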
43 changes: 24 additions & 19 deletions src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java
@@ -6,6 +6,7 @@
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Semaphore;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
@@ -31,26 +32,22 @@ class ParallelStreamCollector<T, R> implements Collector<T, Stream.Builder<Compl

private static final EnumSet<Characteristics> UNORDERED = EnumSet.of(Characteristics.UNORDERED);

private final Dispatcher<R> dispatcher;
private final Function<T, R> function;
private final CompletionStrategy<R> completionStrategy;
private final Set<Characteristics> characteristics;
private final Semaphore limiter;
private final Executor executor;

private ParallelStreamCollector(
Function<T, R> function,
CompletionStrategy<R> completionStrategy,
Set<Characteristics> characteristics,
Dispatcher<R> dispatcher) {
Executor executor, int permits) {
this.completionStrategy = completionStrategy;
this.characteristics = characteristics;
this.dispatcher = dispatcher;
this.limiter = new Semaphore(permits);
this.function = function;
}

private void startConsuming() {
if (!dispatcher.isRunning()) {
dispatcher.start();
}
this.executor = executor;
}

@Override
@@ -61,8 +58,19 @@ public Supplier<Stream.Builder<CompletableFuture<R>>> supplier() {
@Override
public BiConsumer<Stream.Builder<CompletableFuture<R>>, T> accumulator() {
return (acc, e) -> {
startConsuming();
acc.add(dispatcher.enqueue(() -> function.apply(e)));
try {
limiter.acquire();
acc.add(CompletableFuture.supplyAsync(() -> {
try {
return function.apply(e);
} finally {
limiter.release();
}
}, executor));
} catch (InterruptedException interruptedException) {
Thread.currentThread().interrupt();
throw new RuntimeException(interruptedException);
}
};
}

@@ -75,10 +83,7 @@ public BinaryOperator<Stream.Builder<CompletableFuture<R>>> combiner() {

@Override
public Function<Stream.Builder<CompletableFuture<R>>, Stream<R>> finisher() {
return acc -> {
dispatcher.stop();
return completionStrategy.apply(acc.build());
};
return acc -> completionStrategy.apply(acc.build());
}

@Override
@@ -97,7 +102,7 @@ public Set<Characteristics> characteristics() {

return parallelism == 1
? BatchingCollectors.syncCollector(mapper)
: new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.of(executor, parallelism));
: new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, executor, parallelism);
}

static <T, R> Collector<T, ?, Stream<R>> streamingOrdered(Function<T, R> mapper, Executor executor) {
@@ -111,7 +116,7 @@ public Set<Characteristics> characteristics() {

return parallelism == 1
? BatchingCollectors.syncCollector(mapper)
: new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.of(executor, parallelism));
: new ParallelStreamCollector<>(mapper, ordered(), emptySet(), executor, parallelism);
}

static final class BatchingCollectors {
@@ -125,7 +130,7 @@ private BatchingCollectors() {

return parallelism == 1
? syncCollector(mapper)
: batched(new ParallelStreamCollector<>(batching(mapper), unordered(), UNORDERED, Dispatcher.of(executor, parallelism)), parallelism);
: batched(new ParallelStreamCollector<>(batching(mapper), unordered(), UNORDERED, executor, parallelism), parallelism);
}

static <T, R> Collector<T, ?, Stream<R>> streamingOrdered(Function<T, R> mapper, Executor executor, int parallelism) {
@@ -135,7 +140,7 @@ private BatchingCollectors() {

return parallelism == 1
? syncCollector(mapper)
: batched(new ParallelStreamCollector<>(batching(mapper), ordered(), emptySet(), Dispatcher.of(executor, parallelism)), parallelism);
: batched(new ParallelStreamCollector<>(batching(mapper), ordered(), emptySet(), executor, parallelism), parallelism);
}

private static <T, R> Collector<T, ?, Stream<R>> batched(ParallelStreamCollector<List<T>, List<R>> downstream, int parallelism) {
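
The heart of the change is the accumulator above: it acquires a Semaphore permit on the calling thread before handing the task to the executor, and releases the permit only when the task finishes, which turns the Semaphore into a genuine parallelism limit rather than a submission-rate limit. A distilled, self-contained version of that pattern follows; it is a sketch, not library code, and the ThrottledSubmitter name is invented for illustration.

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Semaphore;
import java.util.function.Supplier;

final class ThrottledSubmitter {
    private final Executor executor;
    private final Semaphore limiter;

    ThrottledSubmitter(Executor executor, int maxParallelism) {
        this.executor = executor;
        this.limiter = new Semaphore(maxParallelism);
    }

    <R> CompletableFuture<R> submit(Supplier<R> task) {
        try {
            limiter.acquire();              // blocks the calling thread at the limit
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        return CompletableFuture.supplyAsync(() -> {
            try {
                return task.get();
            } finally {
                limiter.release();          // free a slot once the task completes
            }
        }, executor);
    }
}

The test changes that follow split the streaming collectors out of the shared collectors() and batching_collectors() factories into dedicated streaming_collectors() and streaming_batching_collectors() factories, backed by the new streamingTests and batchStreamingTests helpers.
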
@@ -68,9 +68,7 @@ Stream<DynamicTest> collectors() {
tests((m, e, p) -> parallel(m, toList(), e, p), format("ParallelCollectors.parallel(toList(), p=%d)", PARALLELISM), true),
tests((m, e, p) -> parallel(m, toSet(), e, p), format("ParallelCollectors.parallel(toSet(), p=%d)", PARALLELISM), false),
tests((m, e, p) -> parallel(m, toCollection(LinkedList::new), e, p), format("ParallelCollectors.parallel(toCollection(), p=%d)", PARALLELISM), true),
tests((m, e, p) -> adapt(parallel(m, e, p)), format("ParallelCollectors.parallel(p=%d)", PARALLELISM), true),
tests((m, e, p) -> adaptAsync(parallelToStream(m, e, p)), format("ParallelCollectors.parallelToStream(p=%d)", PARALLELISM), false),
tests((m, e, p) -> adaptAsync(parallelToOrderedStream(m, e, p)), format("ParallelCollectors.parallelToOrderedStream(p=%d)", PARALLELISM), true)
tests((m, e, p) -> adapt(parallel(m, e, p)), format("ParallelCollectors.parallel(p=%d)", PARALLELISM), true)
).flatMap(identity());
}

@@ -80,9 +78,23 @@ Stream<DynamicTest> batching_collectors() {
batchTests((m, e, p) -> Batching.parallel(m, toList(), e, p), format("ParallelCollectors.Batching.parallel(toList(), p=%d)", PARALLELISM), true),
batchTests((m, e, p) -> Batching.parallel(m, toSet(), e, p), format("ParallelCollectors.Batching.parallel(toSet(), p=%d)", PARALLELISM), false),
batchTests((m, e, p) -> Batching.parallel(m, toCollection(LinkedList::new), e, p), format("ParallelCollectors.Batching.parallel(toCollection(), p=%d)", PARALLELISM), true),
batchTests((m, e, p) -> adapt(Batching.parallel(m, e, p)), format("ParallelCollectors.Batching.parallel(p=%d)", PARALLELISM), true),
batchTests((m, e, p) -> adaptAsync(Batching.parallelToStream(m, e, p)), format("ParallelCollectors.Batching.parallelToStream(p=%d)", PARALLELISM), false),
batchTests((m, e, p) -> adaptAsync(Batching.parallelToOrderedStream(m, e, p)), format("ParallelCollectors.Batching.parallelToOrderedStream(p=%d)", PARALLELISM), true)
batchTests((m, e, p) -> adapt(Batching.parallel(m, e, p)), format("ParallelCollectors.Batching.parallel(p=%d)", PARALLELISM), true)
).flatMap(identity());
}

@TestFactory
Stream<DynamicTest> streaming_collectors() {
return of(
streamingTests((m, e, p) -> adaptAsync(parallelToStream(m, e, p)), format("ParallelCollectors.parallelToStream(p=%d)", PARALLELISM), false),
streamingTests((m, e, p) -> adaptAsync(parallelToOrderedStream(m, e, p)), format("ParallelCollectors.parallelToOrderedStream(p=%d)", PARALLELISM), true)
).flatMap(identity());
}

@TestFactory
Stream<DynamicTest> streaming_batching_collectors() {
return of(
batchStreamingTests((m, e, p) -> adaptAsync(Batching.parallelToStream(m, e, p)), format("ParallelCollectors.Batching.parallelToStream(p=%d)", PARALLELISM), false),
batchStreamingTests((m, e, p) -> adaptAsync(Batching.parallelToOrderedStream(m, e, p)), format("ParallelCollectors.Batching.parallelToOrderedStream(p=%d)", PARALLELISM), true)
).flatMap(identity());
}

@@ -139,12 +151,34 @@ private static <R extends Collection<Integer>> Stream<DynamicTest> tests(Collect
);
}

private static <R extends Collection<Integer>> Stream<DynamicTest> streamingTests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name, boolean maintainsOrder) {
return of(
shouldCollect(collector, name, 1),
shouldCollect(collector, name, PARALLELISM),
shouldCollectToEmpty(collector, name),
shouldStartConsumingImmediately(collector, name),
shouldNotBlockTheCallingThread(collector, name),
shouldMaintainOrder(collector, name, maintainsOrder),
shouldRespectParallelism(collector, name),
shouldHandleThrowable(collector, name),
shouldShortCircuitOnException(collector, name),
shouldHandleRejectedExecutionException(collector, name),
shouldRemainConsistent(collector, name)
);
}

private static <R extends Collection<Integer>> Stream<DynamicTest> batchTests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name, boolean maintainsOrder) {
return Stream.concat(
tests(collector, name, maintainsOrder),
of(shouldProcessOnNThreadsETParallelism(collector, name)));
}

private static <R extends Collection<Integer>> Stream<DynamicTest> batchStreamingTests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name, boolean maintainsOrder) {
return Stream.concat(
streamingTests(collector, name, maintainsOrder),
of(shouldProcessOnNThreadsETParallelism(collector, name)));
}

private static <R extends Collection<Integer>> DynamicTest shouldNotBlockTheCallingThread(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> c, String name) {
return dynamicTest(format("%s: should not block when returning future", name), () -> {
assertTimeoutPreemptively(ofMillis(100), () ->
@@ -258,15 +292,19 @@ private static <R extends Collection<Integer>> DynamicTest shouldHandleThrowable
}

private static <R extends Collection<Integer>> DynamicTest shouldHandleRejectedExecutionException(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name) {
return dynamicTest(format("%s: should survive rejected execution exception", name), () -> {
return dynamicTest(format("%s: should propagate rejected execution exception", name), () -> {
Executor executor = command -> { throw new RejectedExecutionException(); };
List<Integer> elements = IntStream.range(0, 1000).boxed().collect(toList());

assertThatThrownBy(() -> elements.stream()
.collect(collector.apply(i -> returnWithDelay(i, ofMillis(10000)), executor, PARALLELISM))
.join())
.isInstanceOf(CompletionException.class)
.hasCauseExactlyInstanceOf(RejectedExecutionException.class);
.isInstanceOfAny(RejectedExecutionException.class, CompletionException.class)
.matches(ex -> {
if (ex instanceof CompletionException) {
return ex.getCause() instanceof RejectedExecutionException;
} else return true;
});
});
}
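
The assertion above is loosened because submission now happens on the calling thread: CompletableFuture.supplyAsync hands the task to the executor immediately, so a rejecting executor makes the submitting call itself throw RejectedExecutionException instead of the rejection only surfacing as the cause of a CompletionException on join(), as the previous Dispatcher-based implementation guaranteed. A minimal standalone demonstration of that behaviour (not test code from the repository):

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException;

class RejectionDemo {
    public static void main(String[] args) {
        Executor rejecting = command -> { throw new RejectedExecutionException(); };
        try {
            // supplyAsync submits eagerly, so the rejection is thrown right here,
            // on the calling thread, before any future is returned
            CompletableFuture.supplyAsync(() -> 42, rejecting);
        } catch (RejectedExecutionException e) {
            System.out.println("rejected synchronously: " + e);
        }
    }
}
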
