diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisher.java b/core/src/main/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisher.java index 1a14e3f702ec..f6d070de1e75 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisher.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisher.java @@ -65,14 +65,14 @@ static final class SurroundingSubscriber implements Subscriber, Subscripti AtomicLongFieldUpdater.newUpdater(SurroundingSubscriber.class, "needToPublish"); enum State { - SENDING_HEAD, - SENDING_BODY, - SENDING_TAIL, - SENDING_COMPLETE, + REQUIRE_HEAD, + REQUIRE_BODY, + REQUIRE_TAIL, + REQUIRE_COMPLETE, DONE, } - private volatile State state = State.SENDING_HEAD; + private volatile State state; @Nullable private final T head; @@ -85,12 +85,14 @@ enum State { private volatile long requested; private volatile long needToPublish; + private volatile boolean subscribed; private volatile boolean cancelled; SurroundingSubscriber(@Nullable T head, Publisher publisher, @Nullable T tail, Subscriber downstream) { requireNonNull(publisher, "publisher"); requireNonNull(downstream, "downstream"); + state = head != null ? State.REQUIRE_HEAD : State.REQUIRE_BODY; this.head = head; this.publisher = publisher; this.tail = tail; @@ -117,42 +119,56 @@ public void request(long n) { } } - switch (state) { - case SENDING_HEAD: { - setState(State.SENDING_HEAD, State.SENDING_BODY); - if (head != null) { - downstream.onNext(head); - requestedUpdater.decrementAndGet(this); - } - publisher.subscribe(this); - return; - } - case SENDING_BODY: { - requestUpstream(upstream); + publish(); + } + + private void publish() { + for (;;) { + if (requested <= 0) { return; } - case SENDING_TAIL: { - sendTail(); - if (n > 1) { + switch (state) { + case REQUIRE_HEAD: { + sendHead(); + continue; + } + case REQUIRE_BODY: { + if (!subscribed) { + subscribed = true; + publisher.subscribe(this); + return; + } + if (upstream != null) { + requestUpstream(upstream); + return; + } + } + case REQUIRE_TAIL: { + sendTail(); + continue; + } + case REQUIRE_COMPLETE: { sendComplete(); + return; + } + case DONE: { + upstream.cancel(); + return; } - return; - } - case SENDING_COMPLETE: { - sendComplete(); - return; - } - case DONE: { - upstream.cancel(); - return; } } } + private void sendHead() { + setState(State.REQUIRE_HEAD, State.REQUIRE_BODY); + requestedUpdater.decrementAndGet(this); + downstream.onNext(head); + } + private void sendTail() { - setState(State.SENDING_TAIL, State.SENDING_COMPLETE); + setState(State.REQUIRE_TAIL, State.REQUIRE_COMPLETE); + requestedUpdater.decrementAndGet(this); if (tail != null) { - requestedUpdater.decrementAndGet(this); downstream.onNext(tail); } else { sendComplete(); @@ -160,10 +176,8 @@ private void sendTail() { } private void sendComplete() { - if (state == State.DONE) { - return; - } - setState(State.SENDING_COMPLETE, State.DONE); + setState(State.REQUIRE_COMPLETE, State.DONE); + requestedUpdater.decrementAndGet(this); downstream.onComplete(); } @@ -173,18 +187,13 @@ private void requestUpstream(Subscription subscription) { if (requested == 0) { return; } - if (state == State.SENDING_TAIL) { - sendTail(); - if (requested > 1) { - sendComplete(); - } - return; - } if (requestedUpdater.compareAndSet(this, requested, 0)) { - if (needToPublishUpdater.get(this) > Long.MAX_VALUE - requested) { - needToPublishUpdater.set(this, Long.MAX_VALUE); - } else { - needToPublishUpdater.addAndGet(this, requested); + for (;;) { + final long oldNeedToPublish = needToPublish; + final long newNeedToPublish = LongMath.saturatedAdd(oldNeedToPublish, requested); + if (needToPublishUpdater.compareAndSet(this, oldNeedToPublish, newNeedToPublish)) { + break; + } } subscription.request(requested); return; @@ -222,15 +231,10 @@ public void onError(Throwable cause) { @Override public void onComplete() { - final long needToPublish = needToPublishUpdater.get(this); - setState(State.SENDING_BODY, State.SENDING_TAIL); - if (needToPublish == 0) { - return; - } - - sendTail(); - if (needToPublish > 1) { - sendComplete(); + setState(State.REQUIRE_BODY, State.REQUIRE_TAIL); + if (needToPublish > 0) { + requestedUpdater.addAndGet(this, needToPublish); + publish(); } } @@ -248,7 +252,7 @@ public void cancel() { } private boolean setState(State oldState, State newState) { - assert newState != State.SENDING_HEAD : "oldState: " + oldState + ", newState: " + newState; + assert newState != State.REQUIRE_HEAD : "oldState: " + oldState + ", newState: " + newState; return stateUpdater.compareAndSet(this, oldState, newState); } } diff --git a/core/src/test/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisherTckTest.java b/core/src/test/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisherTckTest.java index 3796e7795356..e865db76dccf 100644 --- a/core/src/test/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisherTckTest.java +++ b/core/src/test/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisherTckTest.java @@ -45,7 +45,7 @@ public Publisher createPublisher(long elements) { return new SurroundingPublisher<>("head", Mono.empty(), null); } if (elements == 2) { - return new SurroundingPublisher<>("head", Mono.empty(), "tail"); + return new SurroundingPublisher<>(null, Mono.just(1), "tail"); } return new SurroundingPublisher<>("head", Flux.fromStream(LongStream.range(0, elements - 2).boxed()),