Skip to content

Commit

Permalink
Take RestMulti headers and status into account when using SSE resourc…
Browse files Browse the repository at this point in the history
…e method
  • Loading branch information
geoand committed Dec 20, 2024
1 parent 718c582 commit f0e1e4e
Showing 1 changed file with 110 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.concurrent.Flow.Publisher;
import java.util.concurrent.Flow.Subscriber;
import java.util.concurrent.Flow.Subscription;
import java.util.function.Consumer;

import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.sse.OutboundSseEvent;
Expand Down Expand Up @@ -46,32 +47,104 @@ public void setStreamingResponseCustomizers(List<StreamingResponseCustomizer> st
this.streamingResponseCustomizers = streamingResponseCustomizers;
}

@SuppressWarnings("rawtypes")
private static class SseMultiSubscriber extends AbstractMultiSubscriber {

private final Publisher publisher;
// 0: no items have been pushed by the stream
// 1: the first item has been pushed by the stream, and we have yet to send the empty buffer (with the headers)
// 2: the empty buffer (with the headers) was sent, and we have received a response
// 3: all items pulled from upstream and successfully sent downstream
// 4: we got an error sending an item
private volatile int state = 0;

SseMultiSubscriber(ResteasyReactiveRequestContext requestContext, List<StreamingResponseCustomizer> staticCustomizers,
long demand) {
Publisher publisher, long demand) {
super(requestContext, staticCustomizers, demand);
this.publisher = publisher;
}

@Override
public void onSubscribe(Subscription s) {
this.subscription = s;
// we only request one item initially because we need to use that item to create the headers
// that will be sent in the first empty response
s.request(1);
}

@Override
public void onComplete() {
// make sure we don't trigger cancel with our onCloseHandler
weClosed = true;
if (state == 1) { // stream only had one item that we have yet to send (we are waiting for the empty buffer to be sent)
// do nothing as we still need to send the first item
// the connection will be closed by doSend when the item is sent
} else if (state < 3) {
doClose();
} else {
handleException(requestContext, new IllegalStateException("Unexpected state: " + state));
}
}

@Override
public void onNext(Object item) {
OutboundSseEvent event;
if (item instanceof OutboundSseEvent) {
event = (OutboundSseEvent) item;
if (state == 0) { // first item
state = 1;
SseUtil.setHeaders(requestContext, requestContext.serverResponse(),
determineCustomizers(publisher, true, staticCustomizers));

requestContext.serverResponse().write(EMPTY_BUFFER, new Consumer<>() {
@Override
public void accept(Throwable throwable) {
if (throwable == null) {
state = 2;
// now we can actually send the first item
doSend(item);
} else {
state = 4;
requestContext.resume(throwable);
}
}
});
} else if (state == 2) { // the only should have got here is when the empty buffer was sent
doSend(item);
} else {
event = new OutboundSseEventImpl.BuilderImpl().data(item).build();
handleException(requestContext, new IllegalStateException("Unexpected state: " + state));
}
SseUtil.send(requestContext, event, staticCustomizers).whenComplete((v, t) -> {
}

private void doSend(Object item) {
SseUtil.send(requestContext, fromItem(item), staticCustomizers).whenComplete((v, t) -> {
if (t != null) {
state = 4;
// need to cancel because the exception didn't come from the Multi
subscription.cancel();
handleException(requestContext, t);
} else if (weClosed && !requestContext.serverResponse().closed()) {
// this is the case where the stream only had one item so we need to close the connection as onComplete could not do it at the time it was called
doClose();
} else {
// send in the next item
subscription.request(demand);
}
});
}

private void doClose() {
state = 3;
requestContext.serverResponse().end();
requestContext.close();
}

private OutboundSseEvent fromItem(Object item) {
OutboundSseEvent event;
if (item instanceof OutboundSseEvent) {
event = (OutboundSseEvent) item;
} else {
event = new OutboundSseEventImpl.BuilderImpl().data(item).build();
}
return event;
}
}

@SuppressWarnings("rawtypes")
Expand Down Expand Up @@ -103,7 +176,7 @@ private static class StreamingMultiSubscriber extends AbstractMultiSubscriber {

@Override
public void onNext(Object item) {
List<StreamingResponseCustomizer> customizers = determineCustomizers(!hadItem);
List<StreamingResponseCustomizer> customizers = determineCustomizers(publisher, !hadItem, staticCustomizers);
hadItem = true;
StreamingUtil.send(requestContext, customizers, item, messagePrefix(), messageSuffix())
.handle((v, t) -> {
Expand All @@ -125,33 +198,12 @@ public void onNext(Object item) {
});
}

private List<StreamingResponseCustomizer> determineCustomizers(boolean isFirst) {
// we only need to obtain the customizers from the Publisher if it's the first time we are sending data and the Publisher has customizable data
// at this point no matter the type of RestMulti we can safely obtain the headers and status
if (isFirst && (publisher instanceof RestMulti<?> restMulti)) {
Map<String, List<String>> headers = restMulti.getHeaders();
Integer status = restMulti.getStatus();
if (headers.isEmpty() && (status == null)) {
return staticCustomizers;
}
List<StreamingResponseCustomizer> result = new ArrayList<>(staticCustomizers.size() + 2);
result.addAll(staticCustomizers); // these are added first so that the result specific values will take precedence if there are conflicts
if (!headers.isEmpty()) {
result.add(new StreamingResponseCustomizer.AddHeadersCustomizer(headers));
}
if (status != null) {
result.add(new StreamingResponseCustomizer.StatusCustomizer(status));
}
return result;
}

return staticCustomizers;
}

@Override
public void onComplete() {
if (!hadItem) {
StreamingUtil.setHeaders(requestContext, requestContext.serverResponse(), this.determineCustomizers(true));
StreamingUtil.setHeaders(requestContext, requestContext.serverResponse(), determineCustomizers(
this.publisher, true,
this.staticCustomizers));
}
if (json) {
String postfix = onCompleteText();
Expand Down Expand Up @@ -202,7 +254,7 @@ static abstract class AbstractMultiSubscriber implements Subscriber<Object> {
protected final long demand;

protected volatile Subscription subscription;
private volatile boolean weClosed = false;
protected volatile boolean weClosed = false;

AbstractMultiSubscriber(ResteasyReactiveRequestContext requestContext,
List<StreamingResponseCustomizer> staticCustomizers, long demand) {
Expand All @@ -218,6 +270,31 @@ static abstract class AbstractMultiSubscriber implements Subscriber<Object> {
});
}

@SuppressWarnings("rawtypes")
protected static List<StreamingResponseCustomizer> determineCustomizers(Publisher publisher, boolean isFirst,
List<StreamingResponseCustomizer> staticCustomizers) {
// we only need to obtain the customizers from the Publisher if it's the first time we are sending data and the Publisher has customizable data
// at this point no matter the type of RestMulti we can safely obtain the headers and status
if (isFirst && (publisher instanceof RestMulti<?> restMulti)) {
Map<String, List<String>> headers = restMulti.getHeaders();
Integer status = restMulti.getStatus();
if (headers.isEmpty() && (status == null)) {
return staticCustomizers;
}
List<StreamingResponseCustomizer> result = new ArrayList<>(staticCustomizers.size() + 2);
result.addAll(staticCustomizers); // these are added first so that the result specific values will take precedence if there are conflicts
if (!headers.isEmpty()) {
result.add(new StreamingResponseCustomizer.AddHeadersCustomizer(headers));
}
if (status != null) {
result.add(new StreamingResponseCustomizer.StatusCustomizer(status));
}
return result;
}

return staticCustomizers;
}

@Override
public void onSubscribe(Subscription s) {
this.subscription = s;
Expand Down Expand Up @@ -343,15 +420,8 @@ private void handleSse(ResteasyReactiveRequestContext requestContext, Publisher<
demand = 1L;
}

SseUtil.setHeaders(requestContext, requestContext.serverResponse(), streamingResponseCustomizers);
requestContext.suspend();
requestContext.serverResponse().write(EMPTY_BUFFER, throwable -> {
if (throwable == null) {
result.subscribe(new SseMultiSubscriber(requestContext, streamingResponseCustomizers, demand));
} else {
requestContext.resume(throwable);
}
});
result.subscribe(new SseMultiSubscriber(requestContext, streamingResponseCustomizers, result, demand));
}

public interface StreamingResponseCustomizer {
Expand Down

0 comments on commit f0e1e4e

Please sign in to comment.