diff --git a/independent-projects/resteasy-reactive/server/runtime/src/main/java/org/jboss/resteasy/reactive/server/handlers/PublisherResponseHandler.java b/independent-projects/resteasy-reactive/server/runtime/src/main/java/org/jboss/resteasy/reactive/server/handlers/PublisherResponseHandler.java index 9a82c247bca5f..bf20a89b8d648 100644 --- a/independent-projects/resteasy-reactive/server/runtime/src/main/java/org/jboss/resteasy/reactive/server/handlers/PublisherResponseHandler.java +++ b/independent-projects/resteasy-reactive/server/runtime/src/main/java/org/jboss/resteasy/reactive/server/handlers/PublisherResponseHandler.java @@ -10,6 +10,7 @@ import java.util.concurrent.Flow.Publisher; import java.util.concurrent.Flow.Subscriber; import java.util.concurrent.Flow.Subscription; +import java.util.function.Consumer; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.sse.OutboundSseEvent; @@ -46,32 +47,104 @@ public void setStreamingResponseCustomizers(List st this.streamingResponseCustomizers = streamingResponseCustomizers; } + @SuppressWarnings("rawtypes") private static class SseMultiSubscriber extends AbstractMultiSubscriber { + private final Publisher publisher; + // 0: no items have been pushed by the stream + // 1: the first item has been pushed by the stream, and we have yet to send the empty buffer (with the headers) + // 2: the empty buffer (with the headers) was sent, and we have received a response + // 3: all items pulled from upstream and successfully sent downstream + // 4: we got an error sending an item + private volatile int state = 0; + SseMultiSubscriber(ResteasyReactiveRequestContext requestContext, List staticCustomizers, - long demand) { + Publisher publisher, long demand) { super(requestContext, staticCustomizers, demand); + this.publisher = publisher; + } + + @Override + public void onSubscribe(Subscription s) { + this.subscription = s; + // we only request one item initially because we need to use that item to create the headers + // that will be sent in the first empty response + s.request(1); + } + + @Override + public void onComplete() { + // make sure we don't trigger cancel with our onCloseHandler + weClosed = true; + if (state == 1) { // stream only had one item that we have yet to send (we are waiting for the empty buffer to be sent) + // do nothing as we still need to send the first item + // the connection will be closed by doSend when the item is sent + } else if (state < 3) { + doClose(); + } else { + handleException(requestContext, new IllegalStateException("Unexpected state: " + state)); + } } @Override public void onNext(Object item) { - OutboundSseEvent event; - if (item instanceof OutboundSseEvent) { - event = (OutboundSseEvent) item; + if (state == 0) { // first item + state = 1; + SseUtil.setHeaders(requestContext, requestContext.serverResponse(), + determineCustomizers(publisher, true, staticCustomizers)); + + requestContext.serverResponse().write(EMPTY_BUFFER, new Consumer<>() { + @Override + public void accept(Throwable throwable) { + if (throwable == null) { + state = 2; + // now we can actually send the first item + doSend(item); + } else { + state = 4; + requestContext.resume(throwable); + } + } + }); + } else if (state == 2) { // the only should have got here is when the empty buffer was sent + doSend(item); } else { - event = new OutboundSseEventImpl.BuilderImpl().data(item).build(); + handleException(requestContext, new IllegalStateException("Unexpected state: " + state)); } - SseUtil.send(requestContext, event, staticCustomizers).whenComplete((v, t) -> { + } + + private void doSend(Object item) { + SseUtil.send(requestContext, fromItem(item), staticCustomizers).whenComplete((v, t) -> { if (t != null) { + state = 4; // need to cancel because the exception didn't come from the Multi subscription.cancel(); handleException(requestContext, t); + } else if (weClosed && !requestContext.serverResponse().closed()) { + // this is the case where the stream only had one item so we need to close the connection as onComplete could not do it at the time it was called + doClose(); } else { // send in the next item subscription.request(demand); } }); } + + private void doClose() { + state = 3; + requestContext.serverResponse().end(); + requestContext.close(); + } + + private OutboundSseEvent fromItem(Object item) { + OutboundSseEvent event; + if (item instanceof OutboundSseEvent) { + event = (OutboundSseEvent) item; + } else { + event = new OutboundSseEventImpl.BuilderImpl().data(item).build(); + } + return event; + } } @SuppressWarnings("rawtypes") @@ -103,7 +176,7 @@ private static class StreamingMultiSubscriber extends AbstractMultiSubscriber { @Override public void onNext(Object item) { - List customizers = determineCustomizers(!hadItem); + List customizers = determineCustomizers(publisher, !hadItem, staticCustomizers); hadItem = true; StreamingUtil.send(requestContext, customizers, item, messagePrefix(), messageSuffix()) .handle((v, t) -> { @@ -125,33 +198,12 @@ public void onNext(Object item) { }); } - private List determineCustomizers(boolean isFirst) { - // we only need to obtain the customizers from the Publisher if it's the first time we are sending data and the Publisher has customizable data - // at this point no matter the type of RestMulti we can safely obtain the headers and status - if (isFirst && (publisher instanceof RestMulti restMulti)) { - Map> headers = restMulti.getHeaders(); - Integer status = restMulti.getStatus(); - if (headers.isEmpty() && (status == null)) { - return staticCustomizers; - } - List result = new ArrayList<>(staticCustomizers.size() + 2); - result.addAll(staticCustomizers); // these are added first so that the result specific values will take precedence if there are conflicts - if (!headers.isEmpty()) { - result.add(new StreamingResponseCustomizer.AddHeadersCustomizer(headers)); - } - if (status != null) { - result.add(new StreamingResponseCustomizer.StatusCustomizer(status)); - } - return result; - } - - return staticCustomizers; - } - @Override public void onComplete() { if (!hadItem) { - StreamingUtil.setHeaders(requestContext, requestContext.serverResponse(), this.determineCustomizers(true)); + StreamingUtil.setHeaders(requestContext, requestContext.serverResponse(), determineCustomizers( + this.publisher, true, + this.staticCustomizers)); } if (json) { String postfix = onCompleteText(); @@ -202,7 +254,7 @@ static abstract class AbstractMultiSubscriber implements Subscriber { protected final long demand; protected volatile Subscription subscription; - private volatile boolean weClosed = false; + protected volatile boolean weClosed = false; AbstractMultiSubscriber(ResteasyReactiveRequestContext requestContext, List staticCustomizers, long demand) { @@ -218,6 +270,31 @@ static abstract class AbstractMultiSubscriber implements Subscriber { }); } + @SuppressWarnings("rawtypes") + protected static List determineCustomizers(Publisher publisher, boolean isFirst, + List staticCustomizers) { + // we only need to obtain the customizers from the Publisher if it's the first time we are sending data and the Publisher has customizable data + // at this point no matter the type of RestMulti we can safely obtain the headers and status + if (isFirst && (publisher instanceof RestMulti restMulti)) { + Map> headers = restMulti.getHeaders(); + Integer status = restMulti.getStatus(); + if (headers.isEmpty() && (status == null)) { + return staticCustomizers; + } + List result = new ArrayList<>(staticCustomizers.size() + 2); + result.addAll(staticCustomizers); // these are added first so that the result specific values will take precedence if there are conflicts + if (!headers.isEmpty()) { + result.add(new StreamingResponseCustomizer.AddHeadersCustomizer(headers)); + } + if (status != null) { + result.add(new StreamingResponseCustomizer.StatusCustomizer(status)); + } + return result; + } + + return staticCustomizers; + } + @Override public void onSubscribe(Subscription s) { this.subscription = s; @@ -343,15 +420,8 @@ private void handleSse(ResteasyReactiveRequestContext requestContext, Publisher< demand = 1L; } - SseUtil.setHeaders(requestContext, requestContext.serverResponse(), streamingResponseCustomizers); requestContext.suspend(); - requestContext.serverResponse().write(EMPTY_BUFFER, throwable -> { - if (throwable == null) { - result.subscribe(new SseMultiSubscriber(requestContext, streamingResponseCustomizers, demand)); - } else { - requestContext.resume(throwable); - } - }); + result.subscribe(new SseMultiSubscriber(requestContext, streamingResponseCustomizers, result, demand)); } public interface StreamingResponseCustomizer {