Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Catch exception happening in the gRPC interceptors #28063

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package io.quarkus.grpc.server.interceptors;

import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.time.Duration;

import javax.enterprise.context.ApplicationScoped;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.grpc.ForwardingServerCall;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.examples.helloworld.*;
import io.quarkus.grpc.GlobalInterceptor;
import io.quarkus.grpc.GrpcClient;
import io.quarkus.grpc.server.services.HelloService;
import io.quarkus.test.QuarkusUnitTest;
import io.smallrye.mutiny.Uni;

public class FailingInInterceptorTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest().setArchiveProducer(
() -> ShrinkWrap.create(JavaArchive.class)
.addPackage(GreeterGrpc.class.getPackage())
.addClasses(MyFailingInterceptor.class, GreeterBean.class, HelloRequest.class, HelloService.class));

@GrpcClient
Greeter greeter;

@Test
void test() {
Uni<HelloReply> result = greeter.sayHello(HelloRequest.newBuilder().setName("ServiceA").build());
assertThatThrownBy(() -> result.await().atMost(Duration.ofSeconds(4)))
.isInstanceOf(StatusRuntimeException.class)
.hasMessageContaining("UNKNOWN");
}

@ApplicationScoped
@GlobalInterceptor
public static class MyFailingInterceptor implements ServerInterceptor {

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
return next
.startCall(new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(call) {

@Override
public void sendMessage(RespT message) {
throw new IllegalArgumentException("BOOM");
}

@Override
public void close(Status status, Metadata trailers) {
super.close(status, trailers);
}
}, headers);
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package io.quarkus.grpc.server.interceptors;

import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.time.Duration;

import javax.enterprise.context.ApplicationScoped;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.StatusRuntimeException;
import io.grpc.examples.helloworld.Greeter;
import io.grpc.examples.helloworld.GreeterBean;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import io.quarkus.grpc.GlobalInterceptor;
import io.quarkus.grpc.GrpcClient;
import io.quarkus.grpc.server.services.HelloService;
import io.quarkus.test.QuarkusUnitTest;
import io.smallrye.mutiny.Uni;

public class FailingInterceptorTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest().setArchiveProducer(
() -> ShrinkWrap.create(JavaArchive.class)
.addPackage(GreeterGrpc.class.getPackage())
.addClasses(MyFailingInterceptor.class, GreeterBean.class, HelloRequest.class, HelloService.class));

@GrpcClient
Greeter greeter;

@Test
void test() {
Uni<HelloReply> result = greeter.sayHello(HelloRequest.newBuilder().setName("ServiceA").build());
assertThatThrownBy(() -> result.await().atMost(Duration.ofSeconds(4)))
.isInstanceOf(StatusRuntimeException.class)
.hasMessageContaining("UNKNOWN");
}

@ApplicationScoped
@GlobalInterceptor
public static class MyFailingInterceptor implements ServerInterceptor {

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
throw new IllegalArgumentException("BOOM!");
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import static io.quarkus.vertx.core.runtime.context.VertxContextSafetyToggle.setContextSafe;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Supplier;

import javax.enterprise.context.ApplicationScoped;
Expand All @@ -13,9 +15,11 @@
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.quarkus.grpc.GlobalInterceptor;
import io.smallrye.common.vertx.VertxContext;
import io.vertx.core.Context;
import io.vertx.core.Handler;
import io.vertx.core.Vertx;

@ApplicationScoped
Expand Down Expand Up @@ -44,7 +48,7 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
setContextSafe(local, true);

// Must be sure to call next.startCall on the right context
return new ListenedOnDuplicatedContext<>(() -> next.startCall(call, headers), local);
return new ListenedOnDuplicatedContext<>(call, () -> next.startCall(call, headers), local);
} else {
log.warn("Unable to run on a duplicated context - interceptor not called on the Vert.x event loop");
return next.startCall(call, headers);
Expand All @@ -56,67 +60,99 @@ public int getPriority() {
return Integer.MAX_VALUE;
}

static class ListenedOnDuplicatedContext<ReqT> extends ServerCall.Listener<ReqT> {
static class ListenedOnDuplicatedContext<ReqT, RespT> extends ServerCall.Listener<ReqT> {

private final Context context;
private final Supplier<ServerCall.Listener<ReqT>> supplier;
private final ServerCall<ReqT, RespT> call;
private ServerCall.Listener<ReqT> delegate;

public ListenedOnDuplicatedContext(Supplier<ServerCall.Listener<ReqT>> supplier, Context context) {
private final AtomicBoolean closed = new AtomicBoolean();

public ListenedOnDuplicatedContext(ServerCall<ReqT, RespT> call, Supplier<ServerCall.Listener<ReqT>> supplier,
Context context) {
this.context = context;
this.supplier = supplier;
this.call = call;
}

private synchronized ServerCall.Listener<ReqT> getDelegate() {
if (delegate == null) {
delegate = supplier.get();
try {
delegate = supplier.get();
} catch (Throwable t) {
// If the interceptor supplier throws an exception, catch it, and close the call.
log.warnf("Unable to retrieve gRPC Server call listener", t);
close(t);
return null;
}
}
return delegate;
}

@Override
public void onMessage(ReqT message) {
private void close(Throwable t) {
if (closed.compareAndSet(false, true)) {
call.close(Status.fromThrowable(t), new Metadata());
}
}

private void invoke(Consumer<ServerCall.Listener<ReqT>> invocation) {
if (Vertx.currentContext() == context) {
getDelegate().onMessage(message);
ServerCall.Listener<ReqT> listener = getDelegate();
if (listener == null) {
return;
}
try {
invocation.accept(listener);
} catch (Throwable t) {
close(t);
}
} else {
context.runOnContext(x -> getDelegate().onMessage(message));
context.runOnContext(new Handler<Void>() {
@Override
public void handle(Void x) {
ServerCall.Listener<ReqT> listener = ListenedOnDuplicatedContext.this.getDelegate();
if (listener == null) {
return;
}
try {
invocation.accept(listener);
} catch (Throwable t) {
close(t);
}
}
});
}
}

@Override
public void onMessage(ReqT message) {
invoke(new Consumer<ServerCall.Listener<ReqT>>() {
@Override
public void accept(ServerCall.Listener<ReqT> listener) {
listener.onMessage(message);
}
});
}

@Override
public void onReady() {
if (Vertx.currentContext() == context) {
getDelegate().onReady();
} else {
context.runOnContext(x -> getDelegate().onReady());
}
invoke(ServerCall.Listener::onReady);
}

@Override
public void onHalfClose() {
if (Vertx.currentContext() == context) {
getDelegate().onHalfClose();
} else {
context.runOnContext(x -> getDelegate().onHalfClose());
}
invoke(ServerCall.Listener::onHalfClose);
}

@Override
public void onCancel() {
if (Vertx.currentContext() == context) {
getDelegate().onCancel();
} else {
context.runOnContext(x -> getDelegate().onCancel());
}
invoke(ServerCall.Listener::onCancel);
}

@Override
public void onComplete() {
if (Vertx.currentContext() == context) {
getDelegate().onComplete();
} else {
context.runOnContext(x -> getDelegate().onComplete());
}
invoke(ServerCall.Listener::onComplete);
}
}
}