Skip to content

Commit

Permalink
Allow SSE events to be filtered out from REST Client
Browse files Browse the repository at this point in the history
  • Loading branch information
geoand committed Nov 21, 2023
1 parent 6b66359 commit c9d1eea
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Predicate;

import jakarta.ws.rs.GET;
import jakarta.ws.rs.POST;
Expand All @@ -23,6 +24,7 @@
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;
import org.jboss.resteasy.reactive.RestStreamElementType;
import org.jboss.resteasy.reactive.client.SseEvent;
import org.jboss.resteasy.reactive.client.SseEventFilter;
import org.jboss.resteasy.reactive.server.jackson.JacksonBasicMessageBodyReader;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
Expand Down Expand Up @@ -136,6 +138,25 @@ public void accept(SseEvent<Dto> event) {
new EventContainer("id1", "name1", new Dto("name1", "1"))));
}

@Test
void shouldBeAbleReadEntireEventWhileAlsoBeingAbleToFilterEvents() {
var resultList = new CopyOnWriteArrayList<>();
createClient()
.eventWithFilter()
.subscribe().with(new Consumer<>() {
@Override
public void accept(SseEvent<Dto> event) {
resultList.add(new EventContainer(event.id(), event.name(), event.data()));
}
});
await().atMost(5, TimeUnit.SECONDS)
.untilAsserted(
() -> assertThat(resultList).containsExactly(
new EventContainer("id", "n0", new Dto("name0", "0")),
new EventContainer("id", "n1", new Dto("name1", "1")),
new EventContainer("id", "n2", new Dto("name2", "2"))));
}

static class EventContainer {
final String id;
final String name;
Expand Down Expand Up @@ -212,6 +233,26 @@ public interface SseClient {
@Path("/event")
@Produces(MediaType.SERVER_SENT_EVENTS)
Multi<SseEvent<Dto>> event();

@GET
@Path("/event-with-filter")
@Produces(MediaType.SERVER_SENT_EVENTS)
@SseEventFilter(CustomFilter.class)
Multi<SseEvent<Dto>> eventWithFilter();
}

public static class CustomFilter implements Predicate<SseEvent<String>> {

@Override
public boolean test(SseEvent<String> event) {
if ("heartbeat".equals(event.id())) {
return false;
}
if ("END".equals(event.data())) {
return false;
}
return true;
}
}

@Path("/sse")
Expand Down Expand Up @@ -261,6 +302,50 @@ public void event(@Context SseEventSink sink, @Context Sse sse) {
}
}
}

@GET
@Path("/event-with-filter")
@Produces(MediaType.SERVER_SENT_EVENTS)
public void eventWithFilter(@Context SseEventSink sink, @Context Sse sse) {
try (sink) {
sink.send(sse.newEventBuilder()
.id("id")
.mediaType(MediaType.APPLICATION_JSON_TYPE)
.data(Dto.class, new Dto("name0", "0"))
.name("n0")
.build());

sink.send(sse.newEventBuilder()
.id("heartbeat")
.comment("heartbeat")
.mediaType(MediaType.APPLICATION_JSON_TYPE)
.build());

sink.send(sse.newEventBuilder()
.id("id")
.mediaType(MediaType.APPLICATION_JSON_TYPE)
.data(Dto.class, new Dto("name1", "1"))
.name("n1")
.build());

sink.send(sse.newEventBuilder()
.id("heartbeat")
.comment("heartbeat")
.build());

sink.send(sse.newEventBuilder()
.id("id")
.mediaType(MediaType.APPLICATION_JSON_TYPE)
.data(Dto.class, new Dto("name2", "2"))
.name("n2")
.build());

sink.send(sse.newEventBuilder()
.id("end")
.data("END")
.build());
}
}
}

@Path("/sse-rest-stream-element-type")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.eclipse.microprofile.rest.client.annotation.RegisterProviders;
import org.eclipse.microprofile.rest.client.ext.ResponseExceptionMapper;
import org.jboss.jandex.DotName;
import org.jboss.resteasy.reactive.client.SseEventFilter;

import io.quarkus.rest.client.reactive.ClientExceptionMapper;
import io.quarkus.rest.client.reactive.ClientFormParam;
Expand Down Expand Up @@ -41,6 +42,8 @@ public class DotNames {

static final DotName METHOD = DotName.createSimple(Method.class.getName());

public static final DotName SSE_EVENT_FILTER = DotName.createSimple(SseEventFilter.class);

private DotNames() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import org.jboss.resteasy.reactive.common.util.QuarkusMultivaluedHashMap;

import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.BeanArchiveIndexBuildItem;
import io.quarkus.arc.deployment.CustomScopeAnnotationsBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor;
Expand Down Expand Up @@ -371,6 +372,42 @@ void registerCompressionInterceptors(BuildProducer<ReflectiveClassBuildItem> ref
}
}

@BuildStep
void handleSseEventFilter(BuildProducer<ReflectiveClassBuildItem> reflectiveClasses,
BeanArchiveIndexBuildItem beanArchiveIndexBuildItem) {
var index = beanArchiveIndexBuildItem.getIndex();
Collection<AnnotationInstance> instances = index.getAnnotations(DotNames.SSE_EVENT_FILTER);
if (instances.isEmpty()) {
return;
}

List<String> filterClassNames = new ArrayList<>(instances.size());
for (AnnotationInstance instance : instances) {
if (instance.target().kind() != AnnotationTarget.Kind.METHOD) {
continue;
}
if (instance.value() == null) {
continue; // can't happen
}
Type filterType = instance.value().asClass();
DotName filterClassName = filterType.name();
ClassInfo filterClassInfo = index.getClassByName(filterClassName.toString());
if (filterClassInfo == null) {
log.warn("Unable to find class '" + filterType.name() + "' in index");
} else if (!filterClassInfo.hasNoArgsConstructor()) {
throw new RestClientDefinitionException(
"Classes used in @SseEventFilter must have a no-args constructor. Offending class is '"
+ filterClassName + "'");
} else {
filterClassNames.add(filterClassName.toString());
}
}
reflectiveClasses.produce(ReflectiveClassBuildItem
.builder(filterClassNames.toArray(new String[0]))
.constructors(true)
.build());
}

@BuildStep
@Record(ExecutionTime.STATIC_INIT)
void addRestClientBeans(Capabilities capabilities,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,38 @@
*/
public interface SseEvent<T> {

/**
* Get event identifier.
* <p>
* Contains value of SSE {@code "id"} field. This field is optional. Method may return {@code null}, if the event
* identifier is not specified.
*
* @return event id.
*/
String id();

/**
* Get event name.
* <p>
* Contains value of SSE {@code "event"} field. This field is optional. Method may return {@code null}, if the event
* name is not specified.
*
* @return event name, or {@code null} if not set.
*/
String name();

/**
* Get a comment string that accompanies the event.
* <p>
* Contains value of the comment associated with SSE event. This field is optional. Method may return {@code null}, if
* the event comment is not specified.
*
* @return comment associated with the event.
*/
String comment();

/**
* Get event data.
*/
T data();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package org.jboss.resteasy.reactive.client;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.function.Predicate;

/**
* Used when not all SSE events streamed from the server should be included in the event stream returned by the client.
* <p>
* IMPORTANT: implementations MUST contain a no-args constructor
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface SseEventFilter {

/**
* Predicate which decides whether an event should be included in the event stream returned by the client.
*/
Class<? extends Predicate<SseEvent<String>>> value();
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;

import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.GenericType;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;

import org.jboss.resteasy.reactive.client.SseEvent;
import org.jboss.resteasy.reactive.client.SseEventFilter;
import org.jboss.resteasy.reactive.common.jaxrs.ResponseImpl;
import org.jboss.resteasy.reactive.common.util.RestMediaType;

Expand Down Expand Up @@ -45,8 +48,8 @@ public <R> Multi<R> get(GenericType<R> responseType) {

/**
* We need this class to work around a bug in Mutiny where we can register our cancel listener
* after the subscription is cancelled and we never get notified
* See https://github.com/smallrye/smallrye-mutiny/issues/417
* after the subscription is cancelled, and we never get notified
* See <a href="https://github.com/smallrye/smallrye-mutiny/issues/417">...</a>
*/
static class MultiRequest<R> {

Expand Down Expand Up @@ -127,9 +130,11 @@ public <R> Multi<R> method(String name, Entity<?> entity, GenericType<R> respons
if (!emitter.isCancelled()) {
if (response.getStatus() == 200
&& MediaType.SERVER_SENT_EVENTS_TYPE.isCompatible(response.getMediaType())) {
registerForSse(multiRequest, responseType, response, vertxResponse,
registerForSse(
multiRequest, responseType, vertxResponse,
(String) restClientRequestContext.getProperties()
.get(RestClientRequestContext.DEFAULT_CONTENT_TYPE_PROP));
.get(RestClientRequestContext.DEFAULT_CONTENT_TYPE_PROP),
restClientRequestContext.getInvokedMethod());
} else if (response.getStatus() == 200
&& RestMediaType.APPLICATION_STREAM_JSON_TYPE.isCompatible(response.getMediaType())) {
registerForJsonStream(multiRequest, restClientRequestContext, responseType, response,
Expand All @@ -156,14 +161,16 @@ private boolean isNewlineDelimited(ResponseImpl response) {
@SuppressWarnings({ "unchecked", "rawtypes" })
private <R> void registerForSse(MultiRequest<? super R> multiRequest,
GenericType<R> responseType,
Response response,
HttpClientResponse vertxResponse, String defaultContentType) {
HttpClientResponse vertxResponse, String defaultContentType,
Method invokedMethod) {

boolean returnSseEvent = SseEvent.class.equals(responseType.getRawType());
GenericType responseTypeFirstParam = responseType.getType() instanceof ParameterizedType
? new GenericType(((ParameterizedType) responseType.getType()).getActualTypeArguments()[0])
: null;

Predicate<SseEvent<String>> eventPredicate = createEventPredicate(invokedMethod);

// honestly, isn't reconnect contradictory with completion?
// FIXME: Reconnect settings?
// For now we don't want multi to reconnect
Expand All @@ -172,8 +179,39 @@ private <R> void registerForSse(MultiRequest<? super R> multiRequest,

multiRequest.onCancel(sseSource::close);
sseSource.register(event -> {

// TODO: we might want to cut down on the allocations here...

if (eventPredicate != null) {
boolean keep = eventPredicate.test(new SseEvent<>() {
@Override
public String id() {
return event.getId();
}

@Override
public String name() {
return event.getName();
}

@Override
public String comment() {
return event.getComment();
}

@Override
public String data() {
return event.readData();
}
});
if (!keep) {
return;
}
}

// DO NOT pass the response mime type because it's SSE: let the event pick between the X-SSE-Content-Type header or
// the content-type SSE field

if (returnSseEvent) {
multiRequest.emit((R) new SseEvent() {
@Override
Expand Down Expand Up @@ -212,6 +250,23 @@ public Object data() {
sseSource.registerAfterRequest(vertxResponse);
}

private Predicate<SseEvent<String>> createEventPredicate(Method invokedMethod) {
if (invokedMethod == null) {
return null; // should never happen
}

SseEventFilter filterAnnotation = invokedMethod.getAnnotation(SseEventFilter.class);
if (filterAnnotation == null) {
return null;
}

try {
return filterAnnotation.value().getConstructor().newInstance();
} catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
throw new RuntimeException(e);
}
}

private <R> void registerForChunks(MultiRequest<? super R> multiRequest,
RestClientRequestContext restClientRequestContext,
GenericType<R> responseType,
Expand Down

0 comments on commit c9d1eea

Please sign in to comment.