Skip to content

Commit

Permalink
Fix gRPC context propagation.
Browse files Browse the repository at this point in the history
  • Loading branch information
alesj committed Apr 18, 2023
1 parent 0ea2dd2 commit 5f4901a
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package io.quarkus.grpc.client.bd;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.grpc.Deadline;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloRequest;
import io.quarkus.grpc.GrpcClient;
import io.quarkus.test.QuarkusUnitTest;

public class ClientBlockingDeadlineTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest().setArchiveProducer(
() -> ShrinkWrap.create(JavaArchive.class)
.addPackage(GreeterGrpc.class.getPackage()).addClasses(HelloService.class))
.withConfigurationResource("hello-config-deadline.properties");

@GrpcClient("hello-service")
GreeterGrpc.GreeterBlockingStub stub;

@Test
public void testCallOptions() {
Deadline deadline = stub.getCallOptions().getDeadline();
assertNotNull(deadline);
try {
//noinspection ResultOfMethodCallIgnored
stub.sayHello(HelloRequest.newBuilder().setName("Scaladar").build());
} catch (Exception e) {
Assertions.assertTrue(e instanceof StatusRuntimeException);
StatusRuntimeException sre = (StatusRuntimeException) e;
Status status = sre.getStatus();
Assertions.assertNotNull(status);
Assertions.assertEquals(Status.DEADLINE_EXCEEDED.getCode(), status.getCode());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.quarkus.grpc.client.bd;

import java.time.Duration;

import io.grpc.Context;
import io.grpc.Deadline;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import io.grpc.stub.StreamObserver;
import io.quarkus.grpc.GrpcService;
import io.smallrye.common.annotation.Blocking;
import io.smallrye.mutiny.Uni;

@GrpcService
public class HelloService extends GreeterGrpc.GreeterImplBase {

@Override
@Blocking
public void sayHello(HelloRequest request, StreamObserver<HelloReply> observer) {
Deadline deadline = Context.current().getDeadline();
if (deadline == null) {
throw new IllegalStateException("Null deadline");
}
Uni.createFrom()
.item(HelloReply.newBuilder().setMessage("OK").build())
.onItem()
.delayIt()
.by(Duration.ofMillis(400)).invoke(observer::onNext)
.invoke(observer::onCompleted)
.await()
.indefinitely();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import java.time.Duration;

import io.grpc.Context;
import io.grpc.Deadline;
import io.grpc.examples.helloworld.Greeter;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
Expand All @@ -13,6 +15,10 @@ public class HelloService implements Greeter {

@Override
public Uni<HelloReply> sayHello(HelloRequest request) {
Deadline deadline = Context.current().getDeadline();
if (deadline == null) {
throw new IllegalStateException("Null deadline");
}
return Uni.createFrom().item(HelloReply.newBuilder().setMessage("OK").build()).onItem().delayIt()
.by(Duration.ofMillis(400));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package io.quarkus.grpc.runtime.supports.blocking;

import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.function.Consumer;
import java.util.function.Function;

import org.jboss.logging.Logger;

import io.grpc.Context;
import io.grpc.Metadata;
import io.grpc.ServerCall;
Expand All @@ -31,6 +33,7 @@
* For non-annotated methods, the interceptor acts as a pass-through.
*/
public class BlockingServerInterceptor implements ServerInterceptor, Function<String, Boolean> {
private static final Logger log = Logger.getLogger(BlockingServerInterceptor.class);

private final Vertx vertx;
private final Set<String> blockingMethods;
Expand Down Expand Up @@ -102,14 +105,16 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
*/
private class ReplayListener<ReqT> extends ServerCall.Listener<ReqT> {
private final InjectableContext.ContextState requestContextState;
private final Context grpcContext;

// exclusive to event loop context
private ServerCall.Listener<ReqT> delegate;
private final Queue<Consumer<ServerCall.Listener<ReqT>>> incomingEvents = new LinkedList<>();
private boolean isConsumingFromIncomingEvents = false;
private volatile ServerCall.Listener<ReqT> delegate;
private final Queue<Consumer<ServerCall.Listener<ReqT>>> incomingEvents = new ConcurrentLinkedQueue<>();
private volatile boolean isConsumingFromIncomingEvents;

private ReplayListener(InjectableContext.ContextState requestContextState) {
this.requestContextState = requestContextState;
this.grpcContext = Context.current();
}

/**
Expand Down Expand Up @@ -144,7 +149,11 @@ private void executeOnContextOrEnqueue(Consumer<ServerCall.Listener<ReqT>> consu
* @param consumer the original
*/
private void executeBlockingWithRequestContext(Consumer<ServerCall.Listener<ReqT>> consumer) {
final Context grpcContext = Context.current();
if (!isExecutable()) {
log.warn("Not executable, already shutdown? Ignoring execution ...");
return;
}

Handler<Promise<Object>> blockingHandler = new BlockingExecutionHandler<>(consumer, grpcContext, delegate,
requestContextState, getRequestContext(), this);
if (devMode) {
Expand Down Expand Up @@ -189,6 +198,11 @@ public void onReady() {
}

// protected for tests

protected boolean isExecutable() {
return Arc.container() != null;
}

protected ManagedContext getRequestContext() {
return Arc.container().requestContext();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,28 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
setContextSafe(local, true);

// Must be sure to call next.startCall on the right context
return new ListenedOnDuplicatedContext<>(ehp, call, () -> next.startCall(call, headers), local);
return new ListenedOnDuplicatedContext<>(ehp, call, nextCall(call, headers, next), local);
} else {
log.warn("Unable to run on a duplicated context - interceptor not called on the Vert.x event loop");
return next.startCall(call, headers);
}
}

private <ReqT, RespT> Supplier<ServerCall.Listener<ReqT>> nextCall(ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
// Must be sure to call next.startCall on the right context
io.grpc.Context current = io.grpc.Context.current();
return () -> {
io.grpc.Context previous = current.attach();
try {
return next.startCall(call, headers);
} finally {
current.detach(previous);
}
};
}

@Override
public int getPriority() {
return Interceptors.DUPLICATE_CONTEXT;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ void setup() {
ManagedContext requestContext = mock(ManagedContext.class);
when(requestContext.getState()).thenReturn(contextState);
blockingServerInterceptor = new BlockingServerInterceptor(vertx, Collections.singletonList("blocking"), false) {
@Override
protected boolean isExecutable() {
return true;
}

@Override
protected ManagedContext getRequestContext() {
return requestContext;
Expand All @@ -53,21 +58,25 @@ void testContextPropagation() throws Exception {

// setting grpc context
final Context context = Context.current().withValue(USERNAME, "my-user");
Context previous = context.attach();
try {
final ServerCall.Listener listener = blockingServerInterceptor.interceptCall(serverCall, null, serverCallHandler);
serverCallHandler.awaitSetup();

final ServerCall.Listener listener = blockingServerInterceptor.interceptCall(serverCall, null, serverCallHandler);
serverCallHandler.awaitSetup();
// simulate GRPC call
context.wrap(() -> listener.onMessage("hello")).run();

// simulate GRPC call
context.wrap(() -> listener.onMessage("hello")).run();
// await for the message to be received
serverCallHandler.await();

// await for the message to be received
serverCallHandler.await();
// check that the thread is a worker thread
assertThat(serverCallHandler.threadName).contains("vert.x").contains("worker");

// check that the thread is a worker thread
assertThat(serverCallHandler.threadName).contains("vert.x").contains("worker");

// check that the context was propagated correctly
assertThat(serverCallHandler.contextUserName).isEqualTo("my-user");
// check that the context was propagated correctly
assertThat(serverCallHandler.contextUserName).isEqualTo("my-user");
} finally {
context.detach(previous);
}
}

static class BlockingServerCallHandler implements ServerCallHandler {
Expand Down

0 comments on commit 5f4901a

Please sign in to comment.