diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DirectStreamObserver.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DirectStreamObserver.java index 05421758ea59..c0c1e4e8e57d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DirectStreamObserver.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DirectStreamObserver.java @@ -53,7 +53,7 @@ public final class DirectStreamObserver implements StreamObserver { private final int maxMessagesBeforeCheck; private final Object lock = new Object(); - private int numMessages = -1; + private int numMessages; public DirectStreamObserver(Phaser phaser, CallStreamObserver outboundObserver) { this(phaser, outboundObserver, DEFAULT_MAX_MESSAGES_BEFORE_CHECK); @@ -69,7 +69,7 @@ public DirectStreamObserver(Phaser phaser, CallStreamObserver outboundObserve @Override public void onNext(T value) { synchronized (lock) { - if (++numMessages >= maxMessagesBeforeCheck) { + if (numMessages >= maxMessagesBeforeCheck) { numMessages = 0; int waitSeconds = 1; int totalSecondsWaited = 0; @@ -114,6 +114,7 @@ public void onNext(T value) { } } outboundObserver.onNext(value); + numMessages += 1; } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/stream/DirectStreamObserverTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/stream/DirectStreamObserverTest.java index 1d442c5f84e0..5d1c601884d5 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/stream/DirectStreamObserverTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/stream/DirectStreamObserverTest.java @@ -36,6 +36,7 @@ import org.apache.beam.sdk.fn.test.TestExecutors; import org.apache.beam.sdk.fn.test.TestExecutors.TestExecutorService; import org.apache.beam.sdk.fn.test.TestStreams; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ArrayListMultimap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; @@ -224,17 +225,19 @@ public void testIsReadyCheckDoesntBlockIfPhaserCallbackNeverHappens() throws Exc public void testMessageCheckInterval() throws Exception { final AtomicInteger index = new AtomicInteger(); ArrayListMultimap values = ArrayListMultimap.create(); + + // An observer that is always ready but puts items into a new bucket each time it is queried + CallStreamObserver bucketingObserver = + TestStreams.withOnNext((String t) -> assertTrue(values.put(index.get(), t))) + .withIsReady( + () -> { + index.incrementAndGet(); + return true; + }) + .build(); + final DirectStreamObserver streamObserver = - new DirectStreamObserver<>( - new AdvancingPhaser(1), - TestStreams.withOnNext((String t) -> assertTrue(values.put(index.get(), t))) - .withIsReady( - () -> { - index.incrementAndGet(); - return true; - }) - .build(), - 10); + new DirectStreamObserver<>(new AdvancingPhaser(1), bucketingObserver, 10); List prefixes = ImmutableList.of("0", "1", "2", "3", "4"); List> results = new ArrayList<>();