Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Control max parallelism on a dedicated thread #875

Merged
merged 1 commit into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,21 @@ public BinaryOperator<List<CompletableFuture<R>>> combiner() {

@Override
public BiConsumer<List<CompletableFuture<R>>, T> accumulator() {
return (acc, e) -> acc.add(dispatcher.enqueue(() -> mapper.apply(e)));
return (acc, e) -> {
if (!dispatcher.isRunning()) {
dispatcher.start();
}
acc.add(dispatcher.enqueue(() -> mapper.apply(e)));
};
}

@Override
public Function<List<CompletableFuture<R>>, CompletableFuture<C>> finisher() {
return futures -> combine(futures).thenApply(processor);
return futures -> {
dispatcher.stop();

return combine(futures).thenApply(processor);
};
}

@Override
Expand Down
109 changes: 84 additions & 25 deletions src/main/java/com/pivovarit/collectors/Dispatcher.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package com.pivovarit.collectors;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.function.BiConsumer;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.Supplier;

import static java.lang.Runtime.getRuntime;
Expand All @@ -14,10 +19,18 @@
*/
final class Dispatcher<T> {

private static final Runnable POISON_PILL = () -> System.out.println("Why so serious?");

private final CompletableFuture<Void> completionSignaller = new CompletableFuture<>();
private final BlockingQueue<Runnable> workingQueue = new LinkedBlockingQueue<>();
private final ExecutorService dispatcher = Executors.newSingleThreadExecutor();
private final Executor executor;
private final Semaphore limiter;

private final AtomicBoolean started = new AtomicBoolean(false);

private volatile boolean shortCircuited = false;

private Dispatcher(Executor executor, int permits) {
this.executor = executor;
this.limiter = new Semaphore(permits);
Expand All @@ -27,34 +40,81 @@ static <T> Dispatcher<T> from(Executor executor, int permits) {
return new Dispatcher<>(executor, permits);
}

CompletableFuture<T> enqueue(Supplier<T> supplier) {
InterruptibleCompletableFuture<T> future = new InterruptibleCompletableFuture<>();
completionSignaller.whenComplete(shortcircuit(future));
void start() {
if (!started.getAndSet(true)) {
dispatcher.execute(() -> {
try {
while (true) {
try {
if (limiter != null) {
limiter.acquire();
}
} catch (InterruptedException e) {
handle(e);
}
Runnable task;
if ((task = workingQueue.take()) != POISON_PILL) {
executor.execute(() -> {
try {
task.run();
} finally {
if (limiter != null) {
limiter.release();
}
}
});
} else {
break;
}
}
} catch (Throwable e) {
handle(e);
}
});
}
}

void stop() {
try {
executor.execute(completionTask(supplier, future));
} catch (Throwable e) {
workingQueue.put(POISON_PILL);
} catch (InterruptedException e) {
completionSignaller.completeExceptionally(e);
CompletableFuture<T> result = new CompletableFuture<>();
result.completeExceptionally(e);
return result;
} finally {
dispatcher.shutdown();
}
}

boolean isRunning() {
return started.get();
}

CompletableFuture<T> enqueue(Supplier<T> supplier) {
InterruptibleCompletableFuture<T> future = new InterruptibleCompletableFuture<>();
workingQueue.add(completionTask(supplier, future));
completionSignaller.exceptionally(shortcircuit(future));
return future;
}

private FutureTask<T> completionTask(Supplier<T> supplier, InterruptibleCompletableFuture<T> future) {
FutureTask<T> task = new FutureTask<>(() -> {
if (!completionSignaller.isCompletedExceptionally()) {
try {
withLimiter(supplier, future);
} catch (Throwable e) {
completionSignaller.completeExceptionally(e);
private FutureTask<Void> completionTask(Supplier<T> supplier, InterruptibleCompletableFuture<T> future) {
FutureTask<Void> task = new FutureTask<>(() -> {
try {
if (!shortCircuited) {
future.complete(supplier.get());
}
} catch (Throwable e) {
handle(e);
}
}, null);
future.completedBy(task);
return task;
}

private void handle(Throwable e) {
shortCircuited = true;
completionSignaller.completeExceptionally(e);
dispatcher.shutdownNow();
}

private void withLimiter(Supplier<T> supplier, InterruptibleCompletableFuture<T> future) throws InterruptedException {
try {
limiter.acquire();
Expand All @@ -64,12 +124,11 @@ private void withLimiter(Supplier<T> supplier, InterruptibleCompletableFuture<T>
}
}

private static <T> BiConsumer<T, Throwable> shortcircuit(InterruptibleCompletableFuture<?> future) {
return (__, throwable) -> {
if (throwable != null) {
future.completeExceptionally(throwable);
future.cancel(true);
}
private static Function<Throwable, Void> shortcircuit(InterruptibleCompletableFuture<?> future) {
return throwable -> {
future.completeExceptionally(throwable);
future.cancel(true);
return null;
};
}

Expand All @@ -79,15 +138,15 @@ static int getDefaultParallelism() {

static final class InterruptibleCompletableFuture<T> extends CompletableFuture<T> {

private volatile FutureTask<T> backingTask;
private volatile FutureTask<?> backingTask;

private void completedBy(FutureTask<T> task) {
private void completedBy(FutureTask<?> task) {
backingTask = task;
}

@Override
public boolean cancel(boolean mayInterruptIfRunning) {
FutureTask<T> task = backingTask;
FutureTask<?> task = backingTask;
if (task != null) {
task.cancel(mayInterruptIfRunning);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ public Supplier<List<CompletableFuture<R>>> supplier() {

@Override
public BiConsumer<List<CompletableFuture<R>>, T> accumulator() {
return (acc, e) -> acc.add(dispatcher.enqueue(() -> function.apply(e)));
return (acc, e) -> {
dispatcher.start();
acc.add(dispatcher.enqueue(() -> function.apply(e)));
};
}

@Override
Expand All @@ -71,7 +74,10 @@ public BinaryOperator<List<CompletableFuture<R>>> combiner() {

@Override
public Function<List<CompletableFuture<R>>, Stream<R>> finisher() {
return completionStrategy;
return acc -> {
dispatcher.stop();
return completionStrategy.apply(acc);
};
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package com.pivovarit.collectors;

import org.junit.jupiter.api.DynamicTest;
import org.junit.jupiter.api.TestFactory;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Stream;

import static java.util.stream.Collectors.toList;
import static java.util.stream.Stream.of;

class ExecutorPollutionTest {

@TestFactory
Stream<DynamicTest> shouldStartProcessingElementsTests() {
return of(
shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors::parallel, "parallel#1"),
shouldNotSubmitMoreTasksThanParallelism((f, e, p) -> ParallelCollectors.parallel(f, toList(), e, p), "parallel#2"),
shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors::parallelToStream, "parallelToStream"),
shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors::parallelToOrderedStream, "parallelToOrderedStream"),
shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors.Batching::parallel, "parallel#1 (batching)"),
shouldNotSubmitMoreTasksThanParallelism((f, e, p) -> ParallelCollectors.Batching.parallel(f, toList(), e, p), "parallel#2 (batching)"),
shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors.Batching::parallelToStream, "parallelToStream (batching)"),
shouldNotSubmitMoreTasksThanParallelism(ParallelCollectors.Batching::parallelToOrderedStream, "parallelToOrderedStream (batching)")
);
}

private static DynamicTest shouldNotSubmitMoreTasksThanParallelism(CollectorFactory<Integer> collector, String name) {
return DynamicTest.dynamicTest(name, () -> {
ExecutorService e = warmedUp(new ThreadPoolExecutor(2, 2, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(2)));

Object result = Stream.generate(() -> 42)
.limit(1000)
.collect(collector.apply(i -> i, e, 2));

if (result instanceof CompletableFuture<?>) {
((CompletableFuture<?>) result).join();
} else if (result instanceof Stream<?>) {
((Stream<?>) result).forEach((__) -> {});
} else {
throw new IllegalStateException("can't happen");
}
});
}

interface CollectorFactory<T> {
Collector<T, ?, ?> apply(Function<T, ?> function, Executor executorService, int parallelism);
}

private static ThreadPoolExecutor warmedUp(ThreadPoolExecutor e) {
for (int i = 0; i < e.getCorePoolSize(); i++) {
e.submit(() -> {});
}
return e;
}
}