diff --git a/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java b/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java index bb7c44d1..bd6a782b 100644 --- a/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java +++ b/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java @@ -7,6 +7,7 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import java.util.function.BiConsumer; import java.util.function.BinaryOperator; import java.util.function.Function; @@ -83,6 +84,12 @@ private static CompletableFuture> combine(List Collector>> collectingToStream(Function mapper) { + requireNonNull(mapper, "mapper can't be null"); + + return new AsyncParallelCollector<>(mapper, Dispatcher.virtual(), Function.identity()); + } + static Collector>> collectingToStream(Function mapper, Executor executor, int parallelism) { requireNonNull(executor, "executor can't be null"); requireNonNull(mapper, "mapper can't be null"); @@ -93,6 +100,13 @@ private static CompletableFuture> combine(List(mapper, Dispatcher.from(executor, parallelism), Function.identity()); } + static Collector> collectingWithCollector(Collector collector, Function mapper) { + requireNonNull(collector, "collector can't be null"); + requireNonNull(mapper, "mapper can't be null"); + + return new AsyncParallelCollector<>(mapper, Dispatcher.virtual(),s -> s.collect(collector)); + } + static Collector> collectingWithCollector(Collector collector, Function mapper, Executor executor, int parallelism) { requireNonNull(collector, "collector can't be null"); requireNonNull(executor, "executor can't be null"); diff --git a/src/main/java/com/pivovarit/collectors/Dispatcher.java b/src/main/java/com/pivovarit/collectors/Dispatcher.java index 23a9aad6..f9d1f0d1 100644 --- a/src/main/java/com/pivovarit/collectors/Dispatcher.java +++ b/src/main/java/com/pivovarit/collectors/Dispatcher.java @@ -17,9 +17,9 @@ final class Dispatcher { private final Executor executor; private final Semaphore limiter; - private Dispatcher(int permits) { + private Dispatcher() { this.executor = Executors.newVirtualThreadPerTaskExecutor(); - this.limiter = new Semaphore(permits); + this.limiter = null; } private Dispatcher(Executor executor, int permits) { @@ -31,8 +31,8 @@ static Dispatcher from(Executor executor, int permits) { return new Dispatcher<>(executor, permits); } - static Dispatcher virtual(int permits) { - return new Dispatcher<>(permits); + static Dispatcher virtual() { + return new Dispatcher<>(); } CompletableFuture enqueue(Supplier supplier) { @@ -51,7 +51,11 @@ private FutureTask completionTask(Supplier supplier, InterruptibleCompleta FutureTask task = new FutureTask<>(() -> { if (!completionSignaller.isCompletedExceptionally()) { try { - withLimiter(supplier, future); + if (limiter == null) { + future.complete(supplier.get()); + } else { + withLimiter(supplier, future); + } } catch (Throwable e) { completionSignaller.completeExceptionally(e); } diff --git a/src/main/java/com/pivovarit/collectors/ParallelCollectors.java b/src/main/java/com/pivovarit/collectors/ParallelCollectors.java index d611b974..e4ab1546 100644 --- a/src/main/java/com/pivovarit/collectors/ParallelCollectors.java +++ b/src/main/java/com/pivovarit/collectors/ParallelCollectors.java @@ -17,6 +17,32 @@ public final class ParallelCollectors { private ParallelCollectors() { } + + /** + * A convenience {@link Collector} used for executing parallel computations using Virtual Threads + * and returning them as a {@link CompletableFuture} containing a result of the application of the user-provided {@link Collector}. + * + *
+ * Example: + *
{@code
+     * CompletableFuture> result = Stream.of(1, 2, 3)
+     *   .collect(parallel(i -> foo(i), toList()));
+     * }
+ * + * @param mapper a transformation to be performed in parallel + * @param collector the {@code Collector} describing the reduction + * @param the type of the collected elements + * @param the result returned by {@code mapper} + * @param the reduction result {@code collector} + * + * @return a {@code Collector} which collects all processed elements into a user-provided mutable {@code Collection} in parallel + * + * @since 3.0.0 + */ + public static Collector> parallel(Function mapper, Collector collector) { + return AsyncParallelCollector.collectingWithCollector(collector, mapper); + } + /** * A convenience {@link Collector} used for executing parallel computations on a custom {@link Executor} * and returning them as a {@link CompletableFuture} containing a result of the application of the user-provided {@link Collector}. @@ -44,6 +70,32 @@ private ParallelCollectors() { return AsyncParallelCollector.collectingWithCollector(collector, mapper, executor, parallelism); } + /** + * A convenience {@link Collector} used for executing parallel computations using Virtual Threads + * and returning them as {@link CompletableFuture} containing a {@link Stream} of these elements. + * + *

+ * The collector maintains the order of processed {@link Stream}. Instances should not be reused. + * + *
+ * Example: + *
{@code
+     * CompletableFuture> result = Stream.of(1, 2, 3)
+     *   .collect(parallel(i -> foo()));
+     * }
+ * + * @param mapper a transformation to be performed in parallel + * @param the type of the collected elements + * @param the result returned by {@code mapper} + * + * @return a {@code Collector} which collects all processed elements into a {@code Stream} in parallel + * + * @since 3.0.0 + */ + public static Collector>> parallel(Function mapper) { + return AsyncParallelCollector.collectingToStream(mapper); + } + /** * A convenience {@link Collector} used for executing parallel computations on a custom {@link Executor} * and returning them as {@link CompletableFuture} containing a {@link Stream} of these elements. @@ -72,6 +124,32 @@ private ParallelCollectors() { return AsyncParallelCollector.collectingToStream(mapper, executor, parallelism); } + /** + * A convenience {@link Collector} used for executing parallel computations using Virtual Threads + * and returning a {@link Stream} instance returning results as they arrive. + *

+ * For the parallelism of 1, the stream is executed by the calling thread. + * + *
+ * Example: + *

{@code
+     * Stream.of(1, 2, 3)
+     *   .collect(parallelToStream(i -> foo()))
+     *   .forEach(System.out::println);
+     * }
+ * + * @param mapper a transformation to be performed in parallel + * @param the type of the collected elements + * @param the result returned by {@code mapper} + * + * @return a {@code Collector} which collects all processed elements into a {@code Stream} in parallel + * + * @since 3.0.0 + */ + public static Collector> parallelToStream(Function mapper) { + return ParallelStreamCollector.streaming(mapper); + } + /** * A convenience {@link Collector} used for executing parallel computations on a custom {@link Executor} * and returning a {@link Stream} instance returning results as they arrive. @@ -100,6 +178,32 @@ private ParallelCollectors() { return ParallelStreamCollector.streaming(mapper, executor, parallelism); } + /** + * A convenience {@link Collector} used for executing parallel computations on a custom {@link Executor} + * and returning a {@link Stream} instance returning results as they arrive while maintaining the initial order. + *

+ * For the parallelism of 1, the stream is executed by the calling thread. + * + *
+ * Example: + *

{@code
+     * Stream.of(1, 2, 3)
+     *   .collect(parallelToOrderedStream(i -> foo()))
+     *   .forEach(System.out::println);
+     * }
+ * + * @param mapper a transformation to be performed in parallel + * @param the type of the collected elements + * @param the result returned by {@code mapper} + * + * @return a {@code Collector} which collects all processed elements into a {@code Stream} in parallel + * + * @since 3.0.0 + */ + public static Collector> parallelToOrderedStream(Function mapper) { + return ParallelStreamCollector.streamingOrdered(mapper); + } + /** * A convenience {@link Collector} used for executing parallel computations on a custom {@link Executor} * and returning a {@link Stream} instance returning results as they arrive while maintaining the initial order. diff --git a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java index 4902946d..acd2e4b6 100644 --- a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java +++ b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java @@ -78,6 +78,12 @@ public Set characteristics() { return characteristics; } + static Collector> streaming(Function mapper) { + requireNonNull(mapper, "mapper can't be null"); + + return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.virtual()); + } + static Collector> streaming(Function mapper, Executor executor, int parallelism) { requireNonNull(executor, "executor can't be null"); requireNonNull(mapper, "mapper can't be null"); @@ -86,6 +92,12 @@ public Set characteristics() { return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(executor, parallelism)); } + static Collector> streamingOrdered(Function mapper) { + requireNonNull(mapper, "mapper can't be null"); + + return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.virtual()); + } + static Collector> streamingOrdered(Function mapper, Executor executor, int parallelism) { requireNonNull(executor, "executor can't be null"); diff --git a/src/test/java/com/pivovarit/collectors/CompletionOrderSpliteratorTest.java b/src/test/java/com/pivovarit/collectors/CompletionOrderSpliteratorTest.java index 7a04dd3a..0dd4ea93 100644 --- a/src/test/java/com/pivovarit/collectors/CompletionOrderSpliteratorTest.java +++ b/src/test/java/com/pivovarit/collectors/CompletionOrderSpliteratorTest.java @@ -1,6 +1,5 @@ package com.pivovarit.collectors; -import org.junit.jupiter.api.RepeatedTest; import org.junit.jupiter.api.Test; import java.util.Arrays; @@ -11,10 +10,8 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.function.Consumer; -import java.util.stream.Collectors; import java.util.stream.StreamSupport; -import static java.time.Duration.ofMillis; import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; diff --git a/src/test/java/com/pivovarit/collectors/FunctionalTest.java b/src/test/java/com/pivovarit/collectors/FunctionalTest.java index 0e2b1164..f3a318bd 100644 --- a/src/test/java/com/pivovarit/collectors/FunctionalTest.java +++ b/src/test/java/com/pivovarit/collectors/FunctionalTest.java @@ -7,7 +7,6 @@ import java.time.Duration; import java.time.LocalTime; -import java.util.Arrays; import java.util.Collection; import java.util.HashSet; import java.util.LinkedList; @@ -64,6 +63,12 @@ class FunctionalTest { @TestFactory Stream collectors() { return of( + // virtual threads + virtualThreadsTests((m, e, p) -> parallel(m, toList()), "ParallelCollectors.parallel(toList()) [virtual]", true), + virtualThreadsTests((m, e, p) -> parallel(m, toSet()), "ParallelCollectors.parallel(toSet()) [virtual]", false), + virtualThreadsTests((m, e, p) -> parallel(m, toCollection(LinkedList::new)), "ParallelCollectors.parallel(toCollection()) [virtual]", true), + virtualThreadsTests((m, e, p) -> adapt(parallel(m)), "ParallelCollectors.parallel() [virtual]", true), + // platform threads tests((m, e, p) -> parallel(m, toList(), e, p), format("ParallelCollectors.parallel(toList(), p=%d)", PARALLELISM), true), tests((m, e, p) -> parallel(m, toSet(), e, p), format("ParallelCollectors.parallel(toSet(), p=%d)", PARALLELISM), false), tests((m, e, p) -> parallel(m, toCollection(LinkedList::new), e, p), format("ParallelCollectors.parallel(toCollection(), p=%d)", PARALLELISM), true), @@ -84,6 +89,10 @@ Stream batching_collectors() { @TestFactory Stream streaming_collectors() { return of( + // virtual threads + virtualThreadsStreamingTests((m, e, p) -> adaptAsync(parallelToStream(m)), "ParallelCollectors.parallelToStream() [virtual]", false), + virtualThreadsStreamingTests((m, e, p) -> adaptAsync(parallelToOrderedStream(m)), "ParallelCollectors.parallelToOrderedStream() [virtual]", true), + // platform threads streamingTests((m, e, p) -> adaptAsync(parallelToStream(m, e, p)), format("ParallelCollectors.parallelToStream(p=%d)", PARALLELISM), false), streamingTests((m, e, p) -> adaptAsync(parallelToOrderedStream(m, e, p)), format("ParallelCollectors.parallelToOrderedStream(p=%d)", PARALLELISM), true) ).flatMap(i -> i); @@ -92,6 +101,10 @@ Stream streaming_collectors() { @TestFactory Stream streaming_batching_collectors() { return of( + // virtual threads + batchStreamingTests((m, e, p) -> adaptAsync(Batching.parallelToStream(m, e, p)), "ParallelCollectors.Batching.parallelToStream() [virtual]", false), + batchStreamingTests((m, e, p) -> adaptAsync(Batching.parallelToOrderedStream(m, e, p)), "ParallelCollectors.Batching.parallelToOrderedStream(p=%d) [virtual]", true), + // platform threads batchStreamingTests((m, e, p) -> adaptAsync(Batching.parallelToStream(m, e, p)), format("ParallelCollectors.Batching.parallelToStream(p=%d)", PARALLELISM), false), batchStreamingTests((m, e, p) -> adaptAsync(Batching.parallelToOrderedStream(m, e, p)), format("ParallelCollectors.Batching.parallelToOrderedStream(p=%d)", PARALLELISM), true) ).flatMap(i -> i); @@ -150,6 +163,26 @@ void shouldExecuteEagerlyOnProvidedThreadPool() { } } + private static > Stream virtualThreadsTests(CollectorSupplier, Executor, Integer, Collector>> collector, String name, boolean maintainsOrder) { + var tests = of( + shouldCollect(collector, name, 1), + shouldCollect(collector, name, PARALLELISM), + shouldCollectNElementsWithNParallelism(collector, name, 1), + shouldCollectNElementsWithNParallelism(collector, name, PARALLELISM), + shouldCollectToEmpty(collector, name), + shouldStartConsumingImmediately(collector, name), + shouldNotBlockTheCallingThread(collector, name), + shouldHandleThrowable(collector, name), + shouldShortCircuitOnException(collector, name), + shouldInterruptOnException(collector, name), + shouldRemainConsistent(collector, name) + ); + + return maintainsOrder + ? Stream.concat(tests, of(shouldMaintainOrder(collector, name))) + : tests; + } + private static > Stream tests(CollectorSupplier, Executor, Integer, Collector>> collector, String name, boolean maintainsOrder) { var tests = of( shouldCollect(collector, name, 1), @@ -174,6 +207,23 @@ private static > Stream tests(Collect : tests; } + private static > Stream virtualThreadsStreamingTests(CollectorSupplier, Executor, Integer, Collector>> collector, String name, boolean maintainsOrder) { + var tests = of( + shouldCollect(collector, name, 1), + shouldCollect(collector, name, PARALLELISM), + shouldCollectToEmpty(collector, name), + shouldStartConsumingImmediately(collector, name), + shouldNotBlockTheCallingThread(collector, name), + shouldHandleThrowable(collector, name), + shouldShortCircuitOnException(collector, name), + shouldRemainConsistent(collector, name) + ); + + return maintainsOrder + ? Stream.concat(tests, of(shouldMaintainOrder(collector, name))) + : tests; + } + private static > Stream streamingTests(CollectorSupplier, Executor, Integer, Collector>> collector, String name, boolean maintainsOrder) { var tests = of( shouldCollect(collector, name, 1), @@ -306,7 +356,7 @@ private static > DynamicTest shouldShortCircuitOnE int size = 4; runWithExecutor(e -> { - LongAdder counter = new LongAdder(); + AtomicInteger counter = new AtomicInteger(); assertThatThrownBy(elements.stream() .collect(collector.apply(i -> incrementAndThrow(counter), e, PARALLELISM))::join) diff --git a/src/test/java/com/pivovarit/collectors/TestUtils.java b/src/test/java/com/pivovarit/collectors/TestUtils.java index 0105fa1b..2d790b48 100644 --- a/src/test/java/com/pivovarit/collectors/TestUtils.java +++ b/src/test/java/com/pivovarit/collectors/TestUtils.java @@ -4,7 +4,7 @@ import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.atomic.LongAdder; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; public final class TestUtils { @@ -36,14 +36,12 @@ public static T returnWithDelay(T value, Duration duration) { return value; } - public static Integer incrementAndThrow(LongAdder counter) { - try { - Thread.sleep(100); - } catch (InterruptedException e) { - // ignore purposefully + public static Integer incrementAndThrow(AtomicInteger counter) { + if (counter.incrementAndGet() == 10) { + throw new IllegalArgumentException(); } - counter.increment(); - throw new IllegalArgumentException(); + + return counter.intValue(); } public static void runWithExecutor(Consumer consumer, int size) {