Skip to content

Commit

Permalink
Handle client close/cancel on grpc mutiny streaming service
Browse files Browse the repository at this point in the history
  • Loading branch information
pcasaes committed Apr 21, 2022
1 parent 3ec92db commit ee94e06
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import javax.interceptor.Interceptor;
import javax.interceptor.InvocationContext;

import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import io.quarkus.grpc.stubs.ServerCalls;
import io.quarkus.grpc.stubs.StreamCollector;
Expand Down Expand Up @@ -45,7 +46,7 @@ Object collect(InvocationContext context) throws Exception {
Object[] newParams = new Object[params.length];
for (int i = 0; i < params.length; i++) {
if (i == streamIndex) {
newParams[i] = new StreamObserverWrapper<>(stream);
newParams[i] = wrap(stream);
} else {
newParams[i] = params[i];
}
Expand All @@ -54,6 +55,13 @@ Object collect(InvocationContext context) throws Exception {
return context.proceed();
}

private StreamObserver<Object> wrap(StreamObserver<Object> stream) {
if (stream instanceof ServerCallStreamObserver) {
return new ServerCallStreamObserverWrapper<>((ServerCallStreamObserver<Object>) stream);
}
return new StreamObserverWrapper<>(stream);
}

private final class StreamObserverWrapper<T> implements StreamObserver<T> {

private final StreamObserver<T> delegate;
Expand Down Expand Up @@ -81,4 +89,80 @@ public void onCompleted() {

}

private final class ServerCallStreamObserverWrapper<T> extends ServerCallStreamObserver<T> {

private final ServerCallStreamObserver<T> delegate;

public ServerCallStreamObserverWrapper(ServerCallStreamObserver<T> delegate) {
this.delegate = delegate;
}

@Override
public void onNext(T value) {
delegate.onNext(value);
}

@Override
public void onError(Throwable t) {
delegate.onError(t);
streamCollector.remove(delegate);
}

@Override
public void onCompleted() {
delegate.onCompleted();
streamCollector.remove(delegate);
}

@Override
public boolean isCancelled() {
return delegate.isCancelled();
}

@Override
public void setOnCancelHandler(Runnable runnable) {
delegate.setOnCancelHandler(runnable);
}

@Override
public void setCompression(String s) {
delegate.setCompression(s);
}

@Override
public void disableAutoRequest() {
delegate.disableAutoRequest();
}

@Override
public boolean isReady() {
return delegate.isReady();
}

@Override
public void setOnReadyHandler(Runnable runnable) {
delegate.setOnReadyHandler(runnable);
}

@Override
public void request(int i) {
delegate.request(i);
}

@Override
public void setMessageCompression(boolean b) {
delegate.setMessageCompression(b);
}

@Override
public void setOnCloseHandler(Runnable onCloseHandler) {
delegate.setOnCloseHandler(onCloseHandler);
}

@Override
public void disableAutoInboundFlowControl() {
delegate.disableAutoInboundFlowControl();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
import io.smallrye.mutiny.operators.multi.processors.UnicastProcessor;
import io.smallrye.mutiny.subscription.Cancellable;

public class ServerCalls {
private static final Logger log = Logger.getLogger(ServerCalls.class);
Expand Down Expand Up @@ -64,7 +65,7 @@ public static <I, O> void oneToMany(I request, StreamObserver<O> response, Strin
response.onError(Status.fromCode(Status.Code.INTERNAL).asException());
return;
}
returnValue.subscribe().with(
handleSubscription(returnValue.subscribe().with(
new Consumer<O>() {
@Override
public void accept(O v) {
Expand All @@ -82,7 +83,7 @@ public void accept(Throwable throwable) {
public void run() {
onCompleted(response);
}
});
}), response);
} catch (Throwable throwable) {
onError(response, toStatusFailure(throwable));
}
Expand Down Expand Up @@ -124,6 +125,22 @@ public void accept(Throwable failure) {
}
}

private static <O> void handleSubscription(Cancellable cancellable, StreamObserver<O> response) {
if (response instanceof ServerCallStreamObserver) {
ServerCallStreamObserver<O> serverCallResponse = (ServerCallStreamObserver<O>) response;

Runnable cancel = new Runnable() {
@Override
public void run() {
cancellable.cancel();
}
};

serverCallResponse.setOnCloseHandler(cancel);
serverCallResponse.setOnCancelHandler(cancel);
}
}

public static <I, O> StreamObserver<I> manyToMany(StreamObserver<O> response,
Function<Multi<I>, Multi<O>> implementation) {
try {
Expand All @@ -133,11 +150,11 @@ public static <I, O> StreamObserver<I> manyToMany(StreamObserver<O> response,
Multi<O> multi = implementation.apply(input);
if (multi == null) {
log.error("gRPC service method returned null instead of Multi. " +
"Please change the implementation to return a Multi object or throw StatusRuntimeException");
"Please change the implementation to rceturn a Multi object or throw StatusRuntimeException");
response.onError(Status.fromCode(Status.Code.INTERNAL).asException());
return null;
}
multi.subscribe().with(
handleSubscription(multi.subscribe().with(
new Consumer<O>() {
@Override
public void accept(O v) {
Expand All @@ -155,7 +172,8 @@ public void accept(Throwable failure) {
public void run() {
onCompleted(response);
}
});
}), response);

return pump;
} catch (Throwable throwable) {
onError(response, toStatusFailure(throwable));
Expand Down

0 comments on commit ee94e06

Please sign in to comment.