diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java b/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java index 72de35fef45..6aa86c2ee2b 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java @@ -50,6 +50,7 @@ import com.linecorp.armeria.common.stream.SubscriptionOption; import com.linecorp.armeria.internal.common.DefaultHttpRequest; import com.linecorp.armeria.internal.common.DefaultSplitHttpRequest; +import com.linecorp.armeria.internal.common.stream.SurroundingPublisher; import com.linecorp.armeria.unsafe.PooledObjects; import io.netty.buffer.ByteBufAllocator; @@ -282,6 +283,26 @@ static HttpRequest of(RequestHeaders headers, Publisher pu } } + /** + * Creates a new instance from an existing {@link RequestHeaders}, {@link Publisher} and trailers. + * + *

Note that the {@link HttpData}s in the {@link Publisher} are not released when + * {@link Subscription#cancel()} or {@link #abort()} is called. You should add a hook in order to + * release the elements. See {@link PublisherBasedStreamMessage} for more information. + */ + @UnstableApi + static HttpRequest of(RequestHeaders headers, + Publisher publisher, + HttpHeaders trailers) { + requireNonNull(headers, "headers"); + requireNonNull(publisher, "publisher"); + requireNonNull(trailers, "trailers"); + if (trailers.isEmpty()) { + return of(headers, publisher); + } + return of(headers, new SurroundingPublisher<>(null, publisher, trailers)); + } + /** * Creates a new HTTP request whose {@link Publisher} is produced by the specified * {@link CompletionStage}. If the specified {@link CompletionStage} fails, the returned request will be diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java b/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java index 23cd7249311..ecc8fe74685 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java @@ -488,6 +488,23 @@ static HttpResponse of(ResponseHeaders headers, Publisher return PublisherBasedHttpResponse.from(headers, publisher); } + /** + * Creates a new HTTP response with the specified headers and trailers + * whose stream is produced from an existing {@link Publisher}. + * + *

Note that the {@link HttpData}s in the {@link Publisher} are not released when + * {@link Subscription#cancel()} or {@link #abort()} is called. You should add a hook in order to + * release the elements. See {@link PublisherBasedStreamMessage} for more information. + */ + static HttpResponse of(ResponseHeaders headers, + Publisher publisher, + HttpHeaders trailers) { + requireNonNull(headers, "headers"); + requireNonNull(publisher, "publisher"); + requireNonNull(trailers, "trailers"); + return PublisherBasedHttpResponse.from(headers, publisher, trailers); + } + /** * Creates a new HTTP response that delegates to the {@link HttpResponse} produced by the specified * {@link CompletionStage}. If the specified {@link CompletionStage} fails, the returned response will be diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpResponseBuilder.java b/core/src/main/java/com/linecorp/armeria/common/HttpResponseBuilder.java index dd7b795f33b..98d918906aa 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpResponseBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpResponseBuilder.java @@ -28,7 +28,6 @@ import com.google.errorprone.annotations.FormatString; import com.linecorp.armeria.common.annotation.UnstableApi; -import com.linecorp.armeria.common.stream.StreamMessage; /** * Builds a new {@link HttpResponse}. @@ -299,8 +298,7 @@ public HttpResponse build() { if (trailers == null) { return HttpResponse.of(responseHeaders, publisher); } else { - return HttpResponse.of(responseHeaders, - StreamMessage.concat(publisher, StreamMessage.of(trailers.build()))); + return HttpResponse.of(responseHeaders, publisher, trailers.build()); } } } diff --git a/core/src/main/java/com/linecorp/armeria/common/PublisherBasedHttpResponse.java b/core/src/main/java/com/linecorp/armeria/common/PublisherBasedHttpResponse.java index aab05a2c4e8..d1fd409e0b7 100644 --- a/core/src/main/java/com/linecorp/armeria/common/PublisherBasedHttpResponse.java +++ b/core/src/main/java/com/linecorp/armeria/common/PublisherBasedHttpResponse.java @@ -21,12 +21,21 @@ import org.reactivestreams.Publisher; import com.linecorp.armeria.common.stream.PublisherBasedStreamMessage; -import com.linecorp.armeria.internal.common.stream.PrependingPublisher; +import com.linecorp.armeria.internal.common.stream.SurroundingPublisher; final class PublisherBasedHttpResponse extends PublisherBasedStreamMessage implements HttpResponse { static PublisherBasedHttpResponse from(ResponseHeaders headers, Publisher publisher) { - return new PublisherBasedHttpResponse(new PrependingPublisher<>(headers, publisher)); + return new PublisherBasedHttpResponse(new SurroundingPublisher<>(headers, publisher, null)); + } + + static PublisherBasedHttpResponse from(ResponseHeaders headers, + Publisher publisher, + HttpHeaders trailers) { + if (trailers.isEmpty()) { + return from(headers, publisher); + } + return new PublisherBasedHttpResponse(new SurroundingPublisher<>(headers, publisher, trailers)); } PublisherBasedHttpResponse(Publisher publisher) { diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/stream/PrependingPublisher.java b/core/src/main/java/com/linecorp/armeria/internal/common/stream/PrependingPublisher.java deleted file mode 100644 index 4abd85bfc0e..00000000000 --- a/core/src/main/java/com/linecorp/armeria/internal/common/stream/PrependingPublisher.java +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Copyright 2020 LINE Corporation - * - * LINE Corporation licenses this file to you under the Apache License, - * version 2.0 (the "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at: - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - */ - -package com.linecorp.armeria.internal.common.stream; - -import static java.util.Objects.requireNonNull; - -import java.util.concurrent.atomic.AtomicLongFieldUpdater; - -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -import com.google.common.math.LongMath; - -import com.linecorp.armeria.common.annotation.Nullable; -import com.linecorp.armeria.common.stream.NoopSubscriber; - -public final class PrependingPublisher implements Publisher { - - private final T first; - private final Publisher rest; - - public PrependingPublisher(T first, Publisher rest) { - this.first = first; - this.rest = rest; - } - - @Override - public void subscribe(Subscriber subscriber) { - requireNonNull(subscriber, "subscriber"); - final RestSubscriber restSubscriber = new RestSubscriber<>(first, rest, subscriber); - subscriber.onSubscribe(restSubscriber); - } - - static final class RestSubscriber implements Subscriber, Subscription { - - @SuppressWarnings("rawtypes") - private static final AtomicLongFieldUpdater demandUpdater = - AtomicLongFieldUpdater.newUpdater(RestSubscriber.class, "demand"); - - private final T first; - private final Publisher rest; - private Subscriber downstream; - @Nullable - private volatile Subscription upstream; - private volatile long demand; - private boolean firstSent; - private boolean subscribed; - private volatile boolean cancelled; - - RestSubscriber(T first, Publisher rest, Subscriber downstream) { - this.first = first; - this.rest = rest; - this.downstream = downstream; - } - - @Override - public void request(long n) { - if (n <= 0) { - downstream.onError(new IllegalArgumentException("non-positive request signals are illegal")); - return; - } - if (cancelled) { - return; - } - for (;;) { - final long demand = this.demand; - final long newDemand = LongMath.saturatedAdd(demand, n); - if (demandUpdater.compareAndSet(this, demand, newDemand)) { - if (demand > 0) { - return; - } - break; - } - } - if (!firstSent) { - firstSent = true; - downstream.onNext(first); - if (demand != Long.MAX_VALUE) { - demandUpdater.decrementAndGet(this); - } - } - if (!subscribed) { - subscribed = true; - rest.subscribe(this); - } - if (demand == 0) { - return; - } - final Subscription upstream = this.upstream; - if (upstream != null) { - final long demand = this.demand; - if (demand > 0) { - if (demandUpdater.compareAndSet(this, demand, 0)) { - upstream.request(demand); - } - } - } - } - - @Override - public void cancel() { - if (cancelled) { - return; - } - cancelled = true; - downstream = NoopSubscriber.get(); - final Subscription upstream = this.upstream; - if (upstream != null) { - upstream.cancel(); - } - } - - @Override - public void onSubscribe(Subscription subscription) { - if (cancelled) { - subscription.cancel(); - return; - } - upstream = subscription; - for (;;) { - final long demand = this.demand; - if (demand == 0) { - break; - } - if (demandUpdater.compareAndSet(this, demand, 0)) { - subscription.request(demand); - } - } - } - - @Override - public void onNext(T t) { - requireNonNull(t, "element"); - downstream.onNext(t); - } - - @Override - public void onError(Throwable t) { - requireNonNull(t, "throwable"); - downstream.onError(t); - } - - @Override - public void onComplete() { - downstream.onComplete(); - } - } -} diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisher.java b/core/src/main/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisher.java new file mode 100644 index 00000000000..8d003f69107 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisher.java @@ -0,0 +1,457 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.internal.common.stream; + +import static com.linecorp.armeria.internal.common.stream.InternalStreamMessageUtil.containsNotifyCancellation; +import static java.util.Objects.requireNonNull; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.google.common.math.LongMath; + +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.stream.AbortedStreamException; +import com.linecorp.armeria.common.stream.CancelledSubscriptionException; +import com.linecorp.armeria.common.stream.NoopSubscriber; +import com.linecorp.armeria.common.stream.PublisherBasedStreamMessage; +import com.linecorp.armeria.common.stream.StreamMessage; +import com.linecorp.armeria.common.stream.SubscriptionOption; +import com.linecorp.armeria.common.util.EventLoopCheckingFuture; + +import io.netty.util.concurrent.EventExecutor; + +public final class SurroundingPublisher implements StreamMessage { + + @SuppressWarnings("rawtypes") + private static final AtomicIntegerFieldUpdater subscribedUpdater = + AtomicIntegerFieldUpdater.newUpdater(SurroundingPublisher.class, "subscribed"); + + @Nullable + private final T head; + private final StreamMessage publisher; + @Nullable + private final T tail; + + private volatile int subscribed; + private final CompletableFuture completionFuture = new EventLoopCheckingFuture<>(); + + @Nullable + private volatile SurroundingSubscriber surroundingSubscriber; + + @SuppressWarnings("unchecked") + public SurroundingPublisher(@Nullable T head, Publisher publisher, @Nullable T tail) { + requireNonNull(publisher, "publisher"); + this.head = head; + if (publisher instanceof StreamMessage) { + this.publisher = (StreamMessage) publisher; + } else { + this.publisher = new PublisherBasedStreamMessage<>(publisher); + } + this.tail = tail; + } + + @Override + public boolean isOpen() { + return !completionFuture.isDone(); + } + + @Override + public boolean isEmpty() { + if (isOpen()) { + return false; + } + final SurroundingSubscriber surroundingSubscriber = this.surroundingSubscriber; + return surroundingSubscriber == null || !surroundingSubscriber.publishedAny; + } + + @Override + public long demand() { + final SurroundingSubscriber surroundingSubscriber = this.surroundingSubscriber; + if (surroundingSubscriber != null) { + return surroundingSubscriber.requested; + } else { + return 0; + } + } + + @Override + public CompletableFuture whenComplete() { + return completionFuture; + } + + @Override + public void subscribe(Subscriber subscriber, EventExecutor executor, + SubscriptionOption... options) { + requireNonNull(subscriber, "subscriber"); + requireNonNull(executor, "executor"); + requireNonNull(options, "options"); + + if (!subscribedUpdater.compareAndSet(this, 0, 1)) { + subscriber.onSubscribe(NoopSubscription.get()); + if (completionFuture.isCompletedExceptionally()) { + completionFuture.exceptionally(cause -> { + subscriber.onError(cause); + return null; + }); + } else { + subscriber.onError(new IllegalStateException("Only single subscriber is allowed!")); + } + return; + } + + if (executor.inEventLoop()) { + subscribe0(subscriber, executor, options); + } else { + executor.execute(() -> subscribe0(subscriber, executor, options)); + } + } + + private void subscribe0(Subscriber subscriber, EventExecutor executor, + SubscriptionOption... options) { + + final SurroundingSubscriber surroundingSubscriber = new SurroundingSubscriber<>( + head, publisher, tail, subscriber, executor, completionFuture, options); + this.surroundingSubscriber = surroundingSubscriber; + subscriber.onSubscribe(surroundingSubscriber); + + // To make sure to close the SurroundingSubscriber when this is aborted. + if (completionFuture.isCompletedExceptionally()) { + completionFuture.exceptionally(cause -> { + surroundingSubscriber.close(cause); + return null; + }); + } + } + + @Override + public void abort() { + abort(AbortedStreamException.get()); + } + + @Override + public void abort(Throwable cause) { + requireNonNull(cause, "cause"); + + // `completionFuture` should be set before `SurroundingSubscriber` publishes data + // to guarantee the visibility of the abortion `cause` after + // SurroundingSubscriber is set in `subscriber0()`. + completionFuture.completeExceptionally(cause); + + if (subscribedUpdater.compareAndSet(this, 0, 1)) { + publisher.abort(cause); + if (head != null) { + StreamMessageUtil.closeOrAbort(head, cause); + } + if (tail != null) { + StreamMessageUtil.closeOrAbort(tail, cause); + } + return; + } + + final SurroundingSubscriber surroundingSubscriber = this.surroundingSubscriber; + if (surroundingSubscriber != null) { + surroundingSubscriber.close(cause); + } + } + + private static final class SurroundingSubscriber implements Subscriber, Subscription { + + enum State { + REQUIRE_HEAD, + REQUIRE_BODY, + REQUIRE_TAIL, + DONE, + } + + private State state; + + @Nullable + private T head; + private final StreamMessage publisher; + @Nullable + private T tail; + + private Subscriber downstream; + private final EventExecutor executor; + @Nullable + private volatile Subscription upstream; + + private long requested; + private long upstreamRequested; + private boolean subscribed; + private volatile boolean publishedAny; + + private final CompletableFuture completionFuture; + private final SubscriptionOption[] options; + + SurroundingSubscriber(@Nullable T head, StreamMessage publisher, @Nullable T tail, + Subscriber downstream, EventExecutor executor, + CompletableFuture completionFuture, SubscriptionOption... options) { + requireNonNull(publisher, "publisher"); + requireNonNull(downstream, "downstream"); + requireNonNull(executor, "executor"); + state = head != null ? State.REQUIRE_HEAD : State.REQUIRE_BODY; + this.head = head; + this.publisher = publisher; + this.tail = tail; + this.downstream = downstream; + this.executor = executor; + this.completionFuture = completionFuture; + this.options = options; + } + + @Override + public void request(long n) { + if (n <= 0) { + close(new IllegalArgumentException("non-positive request signals are illegal")); + return; + } + if (executor.inEventLoop()) { + request0(n); + } else { + executor.execute(() -> request0(n)); + } + } + + private void request0(long n) { + if (state == State.DONE) { + return; + } + + final long oldRequested = requested; + if (oldRequested == Long.MAX_VALUE) { + return; + } + if (n == Long.MAX_VALUE) { + requested = Long.MAX_VALUE; + } else { + requested = LongMath.saturatedAdd(oldRequested, n); + } + + if (oldRequested > 0) { + // SurroundingSubscriber is publishing data. + // New requests will be handled by 'publishDownstream(item)'. + return; + } + + publish(); + } + + private void publish() { + if (state == State.DONE || requested <= 0 && upstreamRequested <= 0) { + return; + } + + switch (state) { + case REQUIRE_HEAD: { + sendHead(); + break; + } + case REQUIRE_BODY: { + if (!subscribed) { + subscribed = true; + publisher.subscribe(this, executor, options); + return; + } + if (upstreamRequested > 0) { + return; + } + final Subscription upstream = this.upstream; + if (upstream != null) { + requestUpstream(upstream); + } + break; + } + case REQUIRE_TAIL: { + sendTail(); + break; + } + } + } + + private void sendHead() { + setState(State.REQUIRE_HEAD, State.REQUIRE_BODY); + assert head != null; + final T head = this.head; + this.head = null; + publishDownstream(head, true); + } + + private void sendTail() { + assert state == State.REQUIRE_TAIL; + if (tail != null) { + final T tail = this.tail; + this.tail = null; + downstream.onNext(tail); + } + close0(null); + } + + private void requestUpstream(Subscription subscription) { + if (requested <= 0) { + return; + } + assert upstreamRequested == 0; + upstreamRequested = requested; + if (requested < Long.MAX_VALUE) { + requested = 0; + } + subscription.request(upstreamRequested); + } + + private void publishDownstream(T item, boolean head) { + requireNonNull(item, "item"); + if (state == State.DONE) { + StreamMessageUtil.closeOrAbort(item); + return; + } + downstream.onNext(item); + + if (head) { + if (requested < Long.MAX_VALUE) { + requested--; + } + subscribed = true; + publisher.subscribe(this, executor, options); + } else { + assert upstreamRequested > 0; + if (upstreamRequested < Long.MAX_VALUE) { + upstreamRequested--; + } + } + + if (!publishedAny) { + publishedAny = true; + } + + publish(); + } + + @Override + public void onSubscribe(Subscription subscription) { + requireNonNull(subscription, "subscription"); + if (state == State.DONE) { + subscription.cancel(); + return; + } + upstream = subscription; + requestUpstream(subscription); + } + + @Override + public void onNext(T item) { + requireNonNull(item, "item"); + publishDownstream(item, false); + } + + @Override + public void onError(Throwable cause) { + requireNonNull(cause, "cause"); + close(cause); + } + + @Override + public void onComplete() { + if (state == State.DONE) { + return; + } + setState(State.REQUIRE_BODY, State.REQUIRE_TAIL); + if (tail != null) { + publish(); + } else { + close0(null); + } + } + + @Override + public void cancel() { + if (executor.inEventLoop()) { + cancel0(); + } else { + executor.execute(this::cancel0); + } + } + + private void cancel0() { + if (state == State.DONE) { + return; + } + state = State.DONE; + + final Subscription upstream = this.upstream; + if (upstream != null) { + upstream.cancel(); + } + final CancelledSubscriptionException cause = CancelledSubscriptionException.get(); + if (containsNotifyCancellation(options)) { + downstream.onError(cause); + } + downstream = NoopSubscriber.get(); + completionFuture.completeExceptionally(cause); + release(null); + } + + private void close(@Nullable Throwable cause) { + if (executor.inEventLoop()) { + close0(cause); + } else { + executor.execute(() -> close0(cause)); + } + } + + private void close0(@Nullable Throwable cause) { + if (state == State.DONE) { + return; + } + state = State.DONE; + + if (cause == null) { + downstream.onComplete(); + completionFuture.complete(null); + } else { + final Subscription upstream = this.upstream; + if (upstream != null) { + upstream.cancel(); + } + downstream.onError(cause); + completionFuture.completeExceptionally(cause); + } + release(cause); + } + + private void release(@Nullable Throwable cause) { + if (head != null) { + StreamMessageUtil.closeOrAbort(head, cause); + } + if (tail != null) { + StreamMessageUtil.closeOrAbort(tail, cause); + } + } + + private void setState(State oldState, State newState) { + assert state == oldState + : "curState: " + state + ", oldState: " + oldState + ", newState: " + newState; + assert newState != State.REQUIRE_HEAD : "oldState: " + oldState + ", newState: " + newState; + state = newState; + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/common/HttpResponseBuilderTest.java b/core/src/test/java/com/linecorp/armeria/common/HttpResponseBuilderTest.java index 4c7eedfeea9..1bcd166db74 100644 --- a/core/src/test/java/com/linecorp/armeria/common/HttpResponseBuilderTest.java +++ b/core/src/test/java/com/linecorp/armeria/common/HttpResponseBuilderTest.java @@ -261,6 +261,39 @@ void buildComplex() { assertThat(aggregatedRes.trailers().get("trailer-name")).isEqualTo("trailer-value"); } + @Test + void buildWithHeadersAndPublisherContentAndTrailers() { + final HttpResponse res = HttpResponse.builder() + .ok() + .headers(HttpHeaders.of("header-1", + "header-value1", + "header-2", + "header-value2")) + .content(MediaType.PLAIN_TEXT_UTF_8, + StreamMessage.of( + HttpData.ofUtf8( + "Armeriaはいろんな使い方がアルメリア" + ) + )) + .trailers(HttpHeaders.of("trailer-1", + "trailer-value1", + "trailer-2", + "trailer-value2")) + .build(); + final AggregatedHttpResponse aggregatedRes = res.aggregate().join(); + assertThat(aggregatedRes.status()).isEqualTo(HttpStatus.OK); + assertThat(aggregatedRes.headers().contains("header-1")).isTrue(); + assertThat(aggregatedRes.headers().contains("header-2")).isTrue(); + assertThat(aggregatedRes.headers().get("header-1")).isEqualTo("header-value1"); + assertThat(aggregatedRes.headers().get("header-2")).isEqualTo("header-value2"); + assertThat(aggregatedRes.contentUtf8()).isEqualTo("Armeriaはいろんな使い方がアルメリア"); + assertThat(aggregatedRes.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8); + assertThat(aggregatedRes.trailers().contains("trailer-1")).isTrue(); + assertThat(aggregatedRes.trailers().contains("trailer-2")).isTrue(); + assertThat(aggregatedRes.trailers().get("trailer-1")).isEqualTo("trailer-value1"); + assertThat(aggregatedRes.trailers().get("trailer-2")).isEqualTo("trailer-value2"); + } + static class SampleObject { private final int id; diff --git a/core/src/test/java/com/linecorp/armeria/internal/common/stream/PrependingPublisherTckTest.java b/core/src/test/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisherTckTest.java similarity index 69% rename from core/src/test/java/com/linecorp/armeria/internal/common/stream/PrependingPublisherTckTest.java rename to core/src/test/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisherTckTest.java index 3f73de8e279..9269b986dfe 100644 --- a/core/src/test/java/com/linecorp/armeria/internal/common/stream/PrependingPublisherTckTest.java +++ b/core/src/test/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisherTckTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 LINE Corporation + * Copyright 2023 LINE Corporation * * LINE Corporation licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance @@ -16,48 +16,80 @@ package com.linecorp.armeria.internal.common.stream; +import java.util.concurrent.atomic.AtomicLong; import java.util.stream.LongStream; -import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; -import org.reactivestreams.tck.PublisherVerification; import org.reactivestreams.tck.TestEnvironment; import org.reactivestreams.tck.flow.support.PublisherVerificationRules; -import org.testng.Assert; import org.testng.SkipException; import org.testng.annotations.Test; +import com.google.common.math.LongMath; + +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.stream.StreamMessage; +import com.linecorp.armeria.common.stream.StreamMessageVerification; + import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @SuppressWarnings("checkstyle:LineLength") -public class PrependingPublisherTckTest extends PublisherVerification { +public class SurroundingPublisherTckTest extends StreamMessageVerification { - public PrependingPublisherTckTest() { + public SurroundingPublisherTckTest() { super(new TestEnvironment(200)); } @Override - public Publisher createPublisher(long elements) { + public StreamMessage createPublisher(long elements) { if (elements == 0) { - return Mono.empty(); + final StreamMessage publisher = new SurroundingPublisher<>(null, Mono.empty(), null); + // `SurroundingPublisher` doesn't check head, tail, publisher's availability before subscribed so manually set complete. + publisher.whenComplete().complete(null); + return publisher; + } + if (elements == 1) { + return new SurroundingPublisher<>("head", Mono.empty(), null); } - return new PrependingPublisher<>("Hello", Flux.fromStream(LongStream.range(0, elements - 1).boxed())); + if (elements == 2) { + return new SurroundingPublisher<>(null, Mono.just(1), "tail"); + } + return new SurroundingPublisher<>("head", Flux.fromStream(LongStream.range(0, elements - 2).boxed()), "tail"); } /** * Rule 1.4 and 1.9 ensure a Publisher's ability to signal error to the Subscriber, however the * implementation expects such error to occur immediately after subscribing, i.e. {@code onError()} is - * called after {@code onSubscribe()}. The {@link PrependingPublisher} however always serves at least one - * element before failing, therefore for the error to be signaled, we must make requests first. + * called after {@code onSubscribe()}. + * The {@link SurroundingPublisher} however subscribes publisher and signals error after the first request, + * therefore for the error to be signaled, we must make requests first. * * {@link PublisherVerificationRules#optional_spec104_mustSignalOnErrorWhenFails()} and * {@link PublisherVerificationRules#required_spec109_mayRejectCallsToSubscribeIfPublisherIsUnableOrUnwillingToServeThemRejectionMustTriggerOnErrorAfterOnSubscribe()} * are overridden below to call {@link Subscription#request(long)} after subscribing. */ @Override - public Publisher createFailedPublisher() { - return new PrependingPublisher<>("Hello", Mono.error(new RuntimeException())); + public StreamMessage createFailedPublisher() { + return new SurroundingPublisher<>(null, Mono.error(new RuntimeException()), "tail"); + } + + @Override + public @Nullable StreamMessage createAbortedPublisher(long elements) { + if (elements == 0) { + final StreamMessage publisher = new SurroundingPublisher<>(null, Mono.empty(), null); + publisher.abort(); + return publisher; + } + + final StreamMessage publisher = createPublisher(LongMath.saturatedAdd(elements, 1)); + final AtomicLong produced = new AtomicLong(); + return publisher + .peek(item -> { + if (produced.getAndIncrement() >= elements) { + publisher.abort(); + } + }); } @Test @@ -66,7 +98,6 @@ public void optional_spec104_mustSignalOnErrorWhenFails() { try { final TestEnvironment env = new TestEnvironment(200); whenHasErrorPublisherTest(pub -> { - final TestEnvironment.Latch onNextLatch = new TestEnvironment.Latch(env); final TestEnvironment.Latch onErrorLatch = new TestEnvironment.Latch(env); final TestEnvironment.Latch onSubscribeLatch = new TestEnvironment.Latch(env); pub.subscribe(new TestEnvironment.TestSubscriber(env) { @@ -77,23 +108,14 @@ public void onSubscribe(Subscription subs) { subs.request(Long.MAX_VALUE); } - @Override - public void onNext(Object element) { - onSubscribeLatch.assertClosed("onSubscribe should be called prior to onNext always"); - Assert.assertEquals(element, "Hello"); - onNextLatch.close(); - } - @Override public void onError(Throwable cause) { onSubscribeLatch.assertClosed("onSubscribe should be called prior to onError always"); - onNextLatch.assertClosed("onNext should already be called"); onErrorLatch.assertOpen(String.format("Error-state Publisher %s called `onError` twice on new Subscriber", pub)); onErrorLatch.close(); } }); onSubscribeLatch.expectClose("Should have received onSubscribe"); - onNextLatch.expectClose("Should have received onNext"); onErrorLatch.expectClose(String.format("Error-state Publisher %s did not call `onError` on new Subscriber", pub)); env.verifyNoAsyncErrors(); @@ -120,11 +142,6 @@ public void onSubscribe(Subscription subs) { subs.request(Long.MAX_VALUE); } - @Override - public void onNext(Object e) { - onSubscribeLatch.assertClosed("onSubscribe should be called prior to onNext always"); - } - @Override public void onError(Throwable cause) { onSubscribeLatch.assertClosed("onSubscribe should be called prior to onError always"); diff --git a/core/src/test/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisherTest.java b/core/src/test/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisherTest.java new file mode 100644 index 00000000000..0a561ff5dc9 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/internal/common/stream/SurroundingPublisherTest.java @@ -0,0 +1,309 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.internal.common.stream; + +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; + +import com.linecorp.armeria.common.stream.StreamMessage; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +class SurroundingPublisherTest { + + @Test + void zeroElementSurroundingPublisher() { + // given + final StreamMessage zeroElement = new SurroundingPublisher<>(null, Mono.empty(), null); + + // when & then + StepVerifier.create(zeroElement, 1) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(OneElementSurroundingPublisherProvider.class) + void oneElementSurroundingPublisher_request1(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher, 1) + .expectNext(1) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(OneElementSurroundingPublisherProvider.class) + void oneElementSurroundingPublisher_requestAll(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher) + .expectNext(1) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(TwoElementsSurroundingPublisherProvider.class) + void twoElementsSurroundingPublisher_request1(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher, 1) + .expectNext(1) + .thenRequest(1) + .expectNext(2) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(TwoElementsSurroundingPublisherProvider.class) + void twoElementsSurroundingPublisher_request1AndAll(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher, 1) + .expectNext(1) + .thenRequest(Long.MAX_VALUE) + .expectNext(2) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(TwoElementsSurroundingPublisherProvider.class) + void twoElementSurroundingPublisher_request2(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher, 2) + .expectNext(1, 2) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(TwoElementsSurroundingPublisherProvider.class) + void twoElementSurroundingPublisher_requestAll(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher) + .expectNext(1, 2) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(ThreeElementsSurroundingPublisherProvider.class) + void threeElementsSurroundingPublisher_request1(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher, 1) + .expectNext(1) + .thenRequest(1) + .expectNext(2) + .thenRequest(1) + .expectNext(3) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(ThreeElementsSurroundingPublisherProvider.class) + void threeElementsSurroundingPublisher_request1And2(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher, 1) + .expectNext(1) + .thenRequest(2) + .expectNext(2, 3) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(ThreeElementsSurroundingPublisherProvider.class) + void threeElementsSurroundingPublisher_request3(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher, 3) + .expectNext(1, 2, 3) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(ThreeElementsSurroundingPublisherProvider.class) + void threeElementsSurroundingPublisher_request1AndAll(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher, 1) + .expectNext(1) + .thenRequest(Long.MAX_VALUE) + .expectNext(2, 3) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(ThreeElementsSurroundingPublisherProvider.class) + void threeElementsSurroundingPublisher_request2(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher, 2) + .expectNext(1, 2) + .thenRequest(1) + .expectNext(3) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(ThreeElementsSurroundingPublisherProvider.class) + void threeElementsSurroundingPublisher_requestAll(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher) + .expectNext(1, 2, 3) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(FiveElementsSurroundingPublisherProvider.class) + void fiveElementsSurroundingPublisher_request1(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher, 1) + .expectNext(1) + .thenRequest(1) + .expectNext(2) + .thenRequest(1) + .expectNext(3) + .thenRequest(1) + .expectNext(4) + .thenRequest(1) + .expectNext(5) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(FiveElementsSurroundingPublisherProvider.class) + void fiveElementsSurroundingPublisher_request1And3And1(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher, 1) + .expectNext(1) + .thenRequest(3) + .expectNext(2, 3, 4) + .thenRequest(1) + .expectNext(5) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(FiveElementsSurroundingPublisherProvider.class) + void fiveElementsSurroundingPublisher_request1AndAll(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher, 1) + .expectNext(1) + .thenRequest(Long.MAX_VALUE) + .expectNext(2, 3, 4, 5) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(FiveElementsSurroundingPublisherProvider.class) + void fiveElementsSurroundingPublisher_request2And3(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher, 2) + .expectNext(1, 2) + .thenRequest(3) + .expectNext(3, 4, 5) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(FiveElementsSurroundingPublisherProvider.class) + void fiveElementsSurroundingPublisher_request5(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher, 5) + .expectNext(1, 2, 3, 4, 5) + .expectComplete() + .verify(); + } + + @ParameterizedTest + @ArgumentsSource(FiveElementsSurroundingPublisherProvider.class) + void fiveElementsSurroundingPublisher_requestAll(StreamMessage surroundingPublisher) { + // when & then + StepVerifier.create(surroundingPublisher) + .expectNext(1, 2, 3, 4, 5) + .expectComplete() + .verify(); + } + + private static class OneElementSurroundingPublisherProvider implements ArgumentsProvider { + + @Override + public Stream provideArguments(ExtensionContext context) throws Exception { + return Stream.of(new SurroundingPublisher<>(1, Mono.empty(), null), + new SurroundingPublisher<>(null, Mono.just(1), null), + new SurroundingPublisher<>(null, Mono.empty(), 1)) + .map(Arguments::of); + } + } + + private static class TwoElementsSurroundingPublisherProvider implements ArgumentsProvider { + + @Override + public Stream provideArguments(ExtensionContext context) throws Exception { + return Stream.of(new SurroundingPublisher<>(1, Mono.just(2), null), + new SurroundingPublisher<>(1, Mono.empty(), 2), + new SurroundingPublisher<>(null, Mono.just(1), 2), + new SurroundingPublisher<>(null, Flux.just(1, 2), null)) + .map(Arguments::of); + } + } + + private static class ThreeElementsSurroundingPublisherProvider implements ArgumentsProvider { + + @Override + public Stream provideArguments(ExtensionContext context) throws Exception { + return Stream.of(new SurroundingPublisher<>(1, Flux.just(2, 3), null), + new SurroundingPublisher<>(1, Mono.just(2), 3), + new SurroundingPublisher<>(null, Flux.just(1, 2), 3), + new SurroundingPublisher<>(null, Flux.just(1, 2, 3), null)) + .map(Arguments::of); + } + } + + private static class FiveElementsSurroundingPublisherProvider implements ArgumentsProvider { + + @Override + public Stream provideArguments(ExtensionContext context) throws Exception { + return Stream.of( + new SurroundingPublisher<>( + 1, Flux.fromStream(IntStream.range(2, 6).boxed()), null), + new SurroundingPublisher<>( + 1, Flux.fromStream(IntStream.range(2, 5).boxed()), 5), + new SurroundingPublisher<>( + null, Flux.fromStream(IntStream.range(1, 5).boxed()), 5), + new SurroundingPublisher<>( + null, Flux.fromStream(IntStream.range(1, 6).boxed()), null)) + .map(Arguments::of); + } + } +} diff --git a/spring/boot3-webflux-autoconfigure/src/main/java/com/linecorp/armeria/spring/web/reactive/AbstractServerHttpResponse.java b/spring/boot3-webflux-autoconfigure/src/main/java/com/linecorp/armeria/spring/web/reactive/AbstractServerHttpResponse.java index 7074c23fc8a..e0978e5847c 100644 --- a/spring/boot3-webflux-autoconfigure/src/main/java/com/linecorp/armeria/spring/web/reactive/AbstractServerHttpResponse.java +++ b/spring/boot3-webflux-autoconfigure/src/main/java/com/linecorp/armeria/spring/web/reactive/AbstractServerHttpResponse.java @@ -45,7 +45,6 @@ import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseCookie; -import org.springframework.http.server.reactive.ChannelSendOperator; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.lang.Nullable; import org.springframework.util.CollectionUtils; diff --git a/spring/boot3-webflux-autoconfigure/src/main/java/com/linecorp/armeria/spring/web/reactive/ArmeriaServerHttpResponse.java b/spring/boot3-webflux-autoconfigure/src/main/java/com/linecorp/armeria/spring/web/reactive/ArmeriaServerHttpResponse.java index 51430886e16..2043a875cfd 100644 --- a/spring/boot3-webflux-autoconfigure/src/main/java/com/linecorp/armeria/spring/web/reactive/ArmeriaServerHttpResponse.java +++ b/spring/boot3-webflux-autoconfigure/src/main/java/com/linecorp/armeria/spring/web/reactive/ArmeriaServerHttpResponse.java @@ -87,7 +87,24 @@ private Mono write(Flux publisher) { HttpResponse.of(armeriaHeaders.build(), publisher.map(factoryWrapper::toHttpData) .contextWrite(contextView) - .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release)); + .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release) + .doOnCancel(() -> { + logger.debug("{} Response stream cancelled", ctx, + new RuntimeException()); + }) + .doOnError(cause -> { + logger.debug("{} Response stream aborted. cause: {}", ctx, + cause, new RuntimeException()); + }) + .doOnComplete(() -> { + logger.debug("{} Response stream completed", ctx, + new RuntimeException()); + }) + .doFinally(signalType -> { + logger.debug("{} Response stream has been finished", ctx, + new RuntimeException()); + }) + ); future.complete(response); return Mono.fromFuture(response.whenComplete()) .onErrorResume(cause -> cause instanceof CancelledSubscriptionException || diff --git a/spring/boot3-webflux-autoconfigure/src/main/java/com/linecorp/armeria/spring/web/reactive/ChannelSendOperator.java b/spring/boot3-webflux-autoconfigure/src/main/java/com/linecorp/armeria/spring/web/reactive/ChannelSendOperator.java new file mode 100644 index 00000000000..b965ee473fd --- /dev/null +++ b/spring/boot3-webflux-autoconfigure/src/main/java/com/linecorp/armeria/spring/web/reactive/ChannelSendOperator.java @@ -0,0 +1,438 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.linecorp.armeria.spring.web.reactive; + +import java.util.function.Function; + +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + +/** + * Given a write function that accepts a source {@code Publisher} to write + * with and returns {@code Publisher} for the result, this operator helps + * to defer the invocation of the write function, until we know if the source + * publisher will begin publishing without an error. If the first emission is + * an error, the write function is bypassed, and the error is sent directly + * through the result publisher. Otherwise the write function is invoked. + * + * @author Rossen Stoyanchev + * @author Stephane Maldini + * @param the type of element signaled + * @since 5.0 + */ +final class ChannelSendOperator extends Mono implements Scannable { + + // Forked from https://github.com/spring-projects/spring-framework/blob/1e3099759e2d823b6dd1c0c43895abcbe3e02a12/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java + // and modified at L370 for not publishing the cache item before receiving request(n) from the subscriber. + private final Function, Publisher> writeFunction; + + private final Flux source; + + ChannelSendOperator(Publisher source, + Function, Publisher> writeFunction) { + this.source = Flux.from(source); + this.writeFunction = writeFunction; + } + + @Override + @Nullable + @SuppressWarnings("rawtypes") + public Object scanUnsafe(Attr key) { + if (key == Attr.PREFETCH) { + return Integer.MAX_VALUE; + } + if (key == Attr.PARENT) { + return source; + } + return null; + } + + @Override + public void subscribe(CoreSubscriber actual) { + source.subscribe(new WriteBarrier(actual)); + } + + private enum State { + + /** No emissions from the upstream source yet. */ + NEW, + + /** + * At least one signal of any kind has been received; we're ready to + * call the write function and proceed with actual writing. + */ + FIRST_SIGNAL_RECEIVED, + + /** + * The write subscriber has subscribed and requested; we're going to + * emit the cached signals. + */ + EMITTING_CACHED_SIGNALS, + + /** + * The write subscriber has subscribed, and cached signals have been + * emitted to it; we're ready to switch to a simple pass-through mode + * for all remaining signals. + **/ + READY_TO_WRITE + } + + /** + * A barrier inserted between the write source and the write subscriber + * (i.e. the HTTP server adapter) that pre-fetches and waits for the first + * signal before deciding whether to hook in to the write subscriber. + * + *

Acts as: + *

    + *
  • Subscriber to the write source. + *
  • Subscription to the write subscriber. + *
  • Publisher to the write subscriber. + *
+ * + *

Also uses {@link WriteCompletionBarrier} to communicate completion + * and detect cancel signals from the completion subscriber. + */ + private class WriteBarrier implements CoreSubscriber, Subscription, Publisher { + + /* Bridges signals to and from the completionSubscriber */ + private final WriteCompletionBarrier writeCompletionBarrier; + + /* Upstream write source subscription */ + @Nullable + private Subscription subscription; + + /** Cached data item before readyToWrite. */ + @Nullable + private T item; + + /** Cached error signal before readyToWrite. */ + @Nullable + private Throwable error; + + /** Cached onComplete signal before readyToWrite. */ + private boolean completed; + + /** Recursive demand while emitting cached signals. */ + private long demandBeforeReadyToWrite; + + /** Current state. */ + private State state = State.NEW; + + /** The actual writeSubscriber from the HTTP server adapter. */ + @Nullable + private Subscriber writeSubscriber; + + WriteBarrier(CoreSubscriber completionSubscriber) { + writeCompletionBarrier = new WriteCompletionBarrier(completionSubscriber, this); + } + + // Subscriber methods (we're the subscriber to the write source).. + + @Override + public final void onSubscribe(Subscription s) { + if (Operators.validate(subscription, s)) { + subscription = s; + writeCompletionBarrier.connect(); + s.request(1); + } + } + + @Override + public final void onNext(T item) { + if (state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onNext(item); + return; + } + //FIXME revisit in case of reentrant sync deadlock + synchronized (this) { + if (state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onNext(item); + } else if (state == State.NEW) { + this.item = item; + state = State.FIRST_SIGNAL_RECEIVED; + final Publisher result; + try { + result = writeFunction.apply(this); + } catch (Throwable ex) { + writeCompletionBarrier.onError(ex); + return; + } + result.subscribe(writeCompletionBarrier); + } else { + if (subscription != null) { + subscription.cancel(); + } + writeCompletionBarrier.onError(new IllegalStateException("Unexpected item.")); + } + } + } + + private Subscriber requiredWriteSubscriber() { + Assert.state(writeSubscriber != null, "No write subscriber"); + return writeSubscriber; + } + + @Override + public final void onError(Throwable ex) { + if (state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onError(ex); + return; + } + synchronized (this) { + if (state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onError(ex); + } else if (state == State.NEW) { + state = State.FIRST_SIGNAL_RECEIVED; + writeCompletionBarrier.onError(ex); + } else { + error = ex; + } + } + } + + @Override + public final void onComplete() { + if (state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onComplete(); + return; + } + synchronized (this) { + if (state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onComplete(); + } else if (state == State.NEW) { + completed = true; + state = State.FIRST_SIGNAL_RECEIVED; + final Publisher result; + try { + result = writeFunction.apply(this); + } catch (Throwable ex) { + writeCompletionBarrier.onError(ex); + return; + } + result.subscribe(writeCompletionBarrier); + } else { + completed = true; + } + } + } + + @Override + public Context currentContext() { + return writeCompletionBarrier.currentContext(); + } + + // Subscription methods (we're the Subscription to the writeSubscriber).. + + @Override + public void request(long n) { + final Subscription s = subscription; + if (s == null) { + return; + } + if (state == State.READY_TO_WRITE) { + s.request(n); + return; + } + synchronized (this) { + if (writeSubscriber != null) { + if (state == State.EMITTING_CACHED_SIGNALS) { + demandBeforeReadyToWrite = n; + return; + } + try { + state = State.EMITTING_CACHED_SIGNALS; + if (emitCachedSignals()) { + return; + } + n = n + demandBeforeReadyToWrite - 1; + if (n == 0) { + return; + } + } finally { + state = State.READY_TO_WRITE; + } + } + } + s.request(n); + } + + private boolean emitCachedSignals() { + if (error != null) { + try { + requiredWriteSubscriber().onError(error); + } finally { + releaseCachedItem(); + } + return true; + } + final T item = this.item; + this.item = null; + if (item != null) { + requiredWriteSubscriber().onNext(item); + } + if (completed) { + requiredWriteSubscriber().onComplete(); + return true; + } + return false; + } + + @Override + public void cancel() { + final Subscription s = subscription; + if (s != null) { + subscription = null; + try { + s.cancel(); + } finally { + releaseCachedItem(); + } + } + } + + private void releaseCachedItem() { + synchronized (this) { + final Object item = this.item; + if (item instanceof DataBuffer) { + DataBufferUtils.release((DataBuffer) item); + } + this.item = null; + } + } + + // Publisher methods (we're the Publisher to the writeSubscriber).. + + @Override + public void subscribe(Subscriber writeSubscriber) { + synchronized (this) { + Assert.state(this.writeSubscriber == null, "Only one write subscriber supported"); + this.writeSubscriber = writeSubscriber; + if (error != null || (completed && item == null)) { + this.writeSubscriber.onSubscribe(Operators.emptySubscription()); + emitCachedSignals(); + } else { + this.writeSubscriber.onSubscribe(this); + } + } + } + } + + /** + * We need an extra barrier between the WriteBarrier itself and the actual + * completion subscriber. + * + *

The completionSubscriber is subscribed initially to the WriteBarrier. + * Later after the first signal is received, we need one more subscriber + * instance (per spec can only subscribe once) to subscribe to the write + * function and switch to delegating completion signals from it. + */ + private class WriteCompletionBarrier implements CoreSubscriber, Subscription { + + /* Downstream write completion subscriber */ + private final CoreSubscriber completionSubscriber; + + private final WriteBarrier writeBarrier; + + @Nullable + private Subscription subscription; + + WriteCompletionBarrier(CoreSubscriber subscriber, WriteBarrier writeBarrier) { + completionSubscriber = subscriber; + this.writeBarrier = writeBarrier; + } + + /** + * Connect the underlying completion subscriber to this barrier in order + * to track cancel signals and pass them on to the write barrier. + */ + public void connect() { + completionSubscriber.onSubscribe(this); + } + + // Subscriber methods (we're the subscriber to the write function).. + + @Override + public void onSubscribe(Subscription subscription) { + this.subscription = subscription; + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void aVoid) { + } + + @Override + public void onError(Throwable ex) { + try { + completionSubscriber.onError(ex); + } finally { + writeBarrier.releaseCachedItem(); + } + } + + @Override + public void onComplete() { + completionSubscriber.onComplete(); + } + + @Override + public Context currentContext() { + return completionSubscriber.currentContext(); + } + + @Override + public void request(long n) { + // Ignore: we don't produce data + } + + @Override + public void cancel() { + writeBarrier.cancel(); + final Subscription subscription = this.subscription; + if (subscription != null) { + subscription.cancel(); + } + } + } +} diff --git a/spring/boot3-webflux-autoconfigure/src/test/java/com/linecorp/armeria/spring/web/reactive/ByteBufLeakTest.java b/spring/boot3-webflux-autoconfigure/src/test/java/com/linecorp/armeria/spring/web/reactive/ByteBufLeakTest.java index 0e6774dccac..d7105a7eb33 100644 --- a/spring/boot3-webflux-autoconfigure/src/test/java/com/linecorp/armeria/spring/web/reactive/ByteBufLeakTest.java +++ b/spring/boot3-webflux-autoconfigure/src/test/java/com/linecorp/armeria/spring/web/reactive/ByteBufLeakTest.java @@ -96,8 +96,10 @@ Mono empty() { private static void addListenerForCountingCompletedRequests() { ServiceRequestContext.current().log().whenComplete() - .thenAccept(log -> completed.incrementAndGet()); - requestReceived.set(true); + .thenAccept(log -> { + completed.incrementAndGet(); + requestReceived.set(true); + }); } } } diff --git a/spring/boot3-webflux-autoconfigure/src/test/java/com/linecorp/armeria/spring/web/reactive/MatrixVariablesTest.java b/spring/boot3-webflux-autoconfigure/src/test/java/com/linecorp/armeria/spring/web/reactive/MatrixVariablesTest.java index f952c2638b9..a65584a7846 100644 --- a/spring/boot3-webflux-autoconfigure/src/test/java/com/linecorp/armeria/spring/web/reactive/MatrixVariablesTest.java +++ b/spring/boot3-webflux-autoconfigure/src/test/java/com/linecorp/armeria/spring/web/reactive/MatrixVariablesTest.java @@ -22,6 +22,7 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; import org.springframework.boot.test.web.server.LocalServerPort; +import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.MatrixVariable; @@ -29,6 +30,8 @@ import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.server.logging.LoggingService; +import com.linecorp.armeria.spring.ArmeriaServerConfigurator; import reactor.core.publisher.Flux; @@ -54,6 +57,11 @@ Flux findPet( return Flux.just(q1, q2); } } + + @Bean + public ArmeriaServerConfigurator serverConfigurator() { + return sb -> sb.decorator(LoggingService.newDecorator()); + } } @LocalServerPort