Skip to content

Commit

Permalink
Control max parallelism on a dedicated thread (#875)
Browse files Browse the repository at this point in the history
Addresses: #867
pivovarit authored May 1, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 24f0891 commit 22a84b7
Showing 4 changed files with 166 additions and 29 deletions.
13 changes: 11 additions & 2 deletions src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java
Original file line number Diff line number Diff line change
@@ -57,12 +57,21 @@ public BinaryOperator<List<CompletableFuture<R>>> combiner() {

@Override
public BiConsumer<List<CompletableFuture<R>>, T> accumulator() {
return (acc, e) -> acc.add(dispatcher.enqueue(() -> mapper.apply(e)));
return (acc, e) -> {
if (!dispatcher.isRunning()) {
dispatcher.start();
}
acc.add(dispatcher.enqueue(() -> mapper.apply(e)));
};
}

@Override
public Function<List<CompletableFuture<R>>, CompletableFuture<C>> finisher() {
return futures -> combine(futures).thenApply(processor);
return futures -> {
dispatcher.stop();

return combine(futures).thenApply(processor);
};
}

@Override
109 changes: 84 additions & 25 deletions src/main/java/com/pivovarit/collectors/Dispatcher.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package com.pivovarit.collectors;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.function.BiConsumer;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.Supplier;

import static java.lang.Runtime.getRuntime;
@@ -14,10 +19,18 @@
*/
final class Dispatcher<T> {

private static final Runnable POISON_PILL = () -> System.out.println("Why so serious?");

private final CompletableFuture<Void> completionSignaller = new CompletableFuture<>();
private final BlockingQueue<Runnable> workingQueue = new LinkedBlockingQueue<>();
private final ExecutorService dispatcher = Executors.newSingleThreadExecutor();
private final Executor executor;
private final Semaphore limiter;

private final AtomicBoolean started = new AtomicBoolean(false);

private volatile boolean shortCircuited = false;

private Dispatcher(Executor executor, int permits) {
this.executor = executor;
this.limiter = new Semaphore(permits);
@@ -27,34 +40,81 @@ static <T> Dispatcher<T> from(Executor executor, int permits) {
return new Dispatcher<>(executor, permits);
}

CompletableFuture<T> enqueue(Supplier<T> supplier) {
InterruptibleCompletableFuture<T> future = new InterruptibleCompletableFuture<>();
completionSignaller.whenComplete(shortcircuit(future));
void start() {
if (!started.getAndSet(true)) {
dispatcher.execute(() -> {
try {
while (true) {
try {
if (limiter != null) {
limiter.acquire();
}
} catch (InterruptedException e) {
handle(e);
}
Runnable task;
if ((task = workingQueue.take()) != POISON_PILL) {
executor.execute(() -> {
try {
task.run();
} finally {
if (limiter != null) {
limiter.release();
}
}
});
} else {
break;
}
}
} catch (Throwable e) {
handle(e);
}
});
}
}

void stop() {
try {
executor.execute(completionTask(supplier, future));
} catch (Throwable e) {
workingQueue.put(POISON_PILL);
} catch (InterruptedException e) {
completionSignaller.completeExceptionally(e);
CompletableFuture<T> result = new CompletableFuture<>();
result.completeExceptionally(e);
return result;
} finally {
dispatcher.shutdown();
}
}

boolean isRunning() {
return started.get();
}

CompletableFuture<T> enqueue(Supplier<T> supplier) {
InterruptibleCompletableFuture<T> future = new InterruptibleCompletableFuture<>();
workingQueue.add(completionTask(supplier, future));
completionSignaller.exceptionally(shortcircuit(future));
return future;
}

private FutureTask<T> completionTask(Supplier<T> supplier, InterruptibleCompletableFuture<T> future) {
FutureTask<T> task = new FutureTask<>(() -> {
if (!completionSignaller.isCompletedExceptionally()) {
try {
withLimiter(supplier, future);
} catch (Throwable e) {
completionSignaller.completeExceptionally(e);
private FutureTask<Void> completionTask(Supplier<T> supplier, InterruptibleCompletableFuture<T> future) {
FutureTask<Void> task = new FutureTask<>(() -> {
try {
if (!shortCircuited) {
future.complete(supplier.get());
}
} catch (Throwable e) {
handle(e);
}
}, null);
future.completedBy(task);
return task;
}

private void handle(Throwable e) {
shortCircuited = true;
completionSignaller.completeExceptionally(e);
dispatcher.shutdownNow();
}

private void withLimiter(Supplier<T> supplier, InterruptibleCompletableFuture<T> future) throws InterruptedException {
try {
limiter.acquire();
@@ -64,12 +124,11 @@ private void withLimiter(Supplier<T> supplier, InterruptibleCompletableFuture<T>
}
}

private static <T> BiConsumer<T, Throwable> shortcircuit(InterruptibleCompletableFuture<?> future) {
return (__, throwable) -> {
if (throwable != null) {
future.completeExceptionally(throwable);
future.cancel(true);
}
private static Function<Throwable, Void> shortcircuit(InterruptibleCompletableFuture<?> future) {
return throwable -> {
future.completeExceptionally(throwable);
future.cancel(true);
return null;
};
}

@@ -79,15 +138,15 @@ static int getDefaultParallelism() {

static final class InterruptibleCompletableFuture<T> extends CompletableFuture<T> {

private volatile FutureTask<T> backingTask;
private volatile FutureTask<?> backingTask;

private void completedBy(FutureTask<T> task) {
private void completedBy(FutureTask<?> task) {
backingTask = task;
}

@Override
public boolean cancel(boolean mayInterruptIfRunning) {
FutureTask<T> task = backingTask;
FutureTask<?> task = backingTask;
if (task != null) {
task.cancel(mayInterruptIfRunning);
}
Original file line number Diff line number Diff line change
@@ -58,7 +58,10 @@ public Supplier<List<CompletableFuture<R>>> supplier() {

@Override
public BiConsumer<List<CompletableFuture<R>>, T> accumulator() {
return (acc, e) -> acc.add(dispatcher.enqueue(() -> function.apply(e)));
return (acc, e) -> {
dispatcher.start();
acc.add(dispatcher.enqueue(() -> function.apply(e)));
};
}

@Override
@@ -71,7 +74,10 @@ public BinaryOperator<List<CompletableFuture<R>>> combiner() {

@Override
public Function<List<CompletableFuture<R>>, Stream<R>> finisher() {
return completionStrategy;
return acc -> {
dispatcher.stop();
return completionStrategy.apply(acc);
};
}

@Override
63 changes: 63 additions & 0 deletions src/test/java/com/pivovarit/collectors/ExecutorPollutionTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package com.pivovarit.collectors;

import org.junit.jupiter.api.DynamicTest;
import org.junit.jupiter.api.TestFactory;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Stream;

import static java.util.stream.Collectors.toList;
import static java.util.stream.Stream.of;

class ExecutorPollutionTest {

@TestFactory
Stream<DynamicTest> shouldStartProcessingElementsTests() {
return of(
shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors::parallel, "parallel#1"),
shouldNotSubmitMoreTasksThanParallelism((f, e, p) -> ParallelCollectors.parallel(f, toList(), e, p), "parallel#2"),
shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors::parallelToStream, "parallelToStream"),
shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors::parallelToOrderedStream, "parallelToOrderedStream"),
shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors.Batching::parallel, "parallel#1 (batching)"),
shouldNotSubmitMoreTasksThanParallelism((f, e, p) -> ParallelCollectors.Batching.parallel(f, toList(), e, p), "parallel#2 (batching)"),
shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors.Batching::parallelToStream, "parallelToStream (batching)"),
shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors.Batching::parallelToOrderedStream, "parallelToOrderedStream (batching)")
);
}

private static DynamicTest shouldNotSubmitMoreTasksThanParallelism(CollectorFactory<Integer> collector, String name) {
return DynamicTest.dynamicTest(name, () -> {
ExecutorService e = warmedUp(new ThreadPoolExecutor(2, 2, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(2)));

Object result = Stream.generate(() -> 42)
.limit(1000)
.collect(collector.apply(i -> i, e, 2));

if (result instanceof CompletableFuture<?>) {
((CompletableFuture<?>) result).join();
} else if (result instanceof Stream<?>) {
((Stream<?>) result).forEach((__) -> {});
} else {
throw new IllegalStateException("can't happen");
}
});
}

interface CollectorFactory<T> {
Collector<T, ?, ?> apply(Function<T, ?> function, Executor executorService, int parallelism);
}

private static ThreadPoolExecutor warmedUp(ThreadPoolExecutor e) {
for (int i = 0; i < e.getCorePoolSize(); i++) {
e.submit(() -> {});
}
return e;
}
}

0 comments on commit 22a84b7

Please sign in to comment.