Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
The problem comes from the default gRPC context storage using a thread-local.
This commit overrides the storage implementation (using the recommended method) to use the duplicated context and fallback to a thread-local.
  • Loading branch information
cescoffier committed Sep 26, 2022
1 parent cd4f781 commit b82b235
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void shouldSecureUniEndpoint() {
client.unaryCall(Security.Container.newBuilder().setText("woo-hoo").build())
.subscribe().with(e -> resultCount.incrementAndGet());

await().atMost(5, TimeUnit.SECONDS)
await().atMost(10, TimeUnit.SECONDS)
.until(() -> resultCount.get() == 1);
}

Expand All @@ -82,7 +82,7 @@ void shouldSecureMultiEndpoint() {
.supplier(() -> (Security.Container.newBuilder().setText("woo-hoo").build())).atMost(4))
.subscribe().with(e -> results.add(e.getIsOnEventLoop()));

await().atMost(5, TimeUnit.SECONDS)
await().atMost(10, TimeUnit.SECONDS)
.until(() -> results.size() == 5);

assertThat(results.stream().filter(e -> !e)).isEmpty();
Expand All @@ -101,7 +101,7 @@ void shouldFailWithInvalidCredentials() {
.onFailure().invoke(error::set)
.subscribe().with(e -> resultCount.incrementAndGet());

await().atMost(5, TimeUnit.SECONDS)
await().atMost(10, TimeUnit.SECONDS)
.until(() -> error.get() != null);
}

Expand All @@ -118,7 +118,7 @@ void shouldFailWithInvalidInsufficientRole() {
.onFailure().invoke(error::set)
.subscribe().with(e -> resultCount.incrementAndGet());

await().atMost(5, TimeUnit.SECONDS)
await().atMost(10, TimeUnit.SECONDS)
.until(() -> error.get() != null);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package io.quarkus.grpc.server.interceptors;

import static org.assertj.core.api.Assertions.assertThat;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.grpc.examples.helloworld.Greeter;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import io.quarkus.grpc.GrpcClient;
import io.quarkus.test.QuarkusUnitTest;

/**
* Test reproducing <a href="https://github.com/quarkusio/quarkus/issues/26830">#26830</a>.
*/
public class GrpcContextPropagationTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest().setArchiveProducer(
() -> ShrinkWrap.create(JavaArchive.class)
.addPackage(GreeterGrpc.class.getPackage())
.addClasses(MyFirstInterceptor.class, MyInterceptedGreeting.class));

@GrpcClient
Greeter greeter;

@Test
void test() {
HelloReply foo = greeter.sayHello(HelloRequest.newBuilder().setName("foo").build()).await().indefinitely();
assertThat(foo.getMessage()).isEqualTo("hello k1 - 1");
foo = greeter.sayHello(HelloRequest.newBuilder().setName("foo").build()).await().indefinitely();
assertThat(foo.getMessage()).isEqualTo("hello k1 - 2");
}

}
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package io.quarkus.grpc.server.interceptors;

import java.util.concurrent.atomic.AtomicInteger;

import javax.enterprise.context.ApplicationScoped;
import javax.enterprise.inject.spi.Prioritized;

import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.ForwardingServerCall;
import io.grpc.Metadata;
import io.grpc.ServerCall;
Expand All @@ -15,19 +19,26 @@
@GlobalInterceptor
public class MyFirstInterceptor implements ServerInterceptor, Prioritized {

public static Context.Key<String> KEY_1 = Context.key("X-TEST_1");
public static Context.Key<Integer> KEY_2 = Context.keyWithDefault("X-TEST_2", -1);
private volatile long callTime;

private AtomicInteger counter = new AtomicInteger();

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> serverCall,
Metadata metadata, ServerCallHandler<ReqT, RespT> serverCallHandler) {
return serverCallHandler
.startCall(new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(serverCall) {
@Override
public void close(Status status, Metadata trailers) {
callTime = System.nanoTime();
super.close(status, trailers);
}
}, metadata);

Context ctx = Context.current().withValue(KEY_1, "k1").withValue(KEY_2, counter.incrementAndGet());
return Contexts.interceptCall(ctx, new ForwardingServerCall.SimpleForwardingServerCall<>(serverCall) {

@Override
public void close(Status status, Metadata trailers) {
callTime = System.nanoTime();
super.close(status, trailers);
}
}, metadata, serverCallHandler);

}

public long getLastCall() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package io.quarkus.grpc.server.interceptors;

import io.grpc.examples.helloworld.Greeter;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import io.quarkus.grpc.GrpcService;
import io.smallrye.common.annotation.Blocking;
import io.smallrye.mutiny.Uni;

@GrpcService
public class MyInterceptedGreeting implements Greeter {
@Override
@Blocking
public Uni<HelloReply> sayHello(HelloRequest request) {
return Uni.createFrom().item(() -> HelloReply.newBuilder()
.setMessage("hello " + MyFirstInterceptor.KEY_1.get() + " - " + MyFirstInterceptor.KEY_2.get()).build());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package io.grpc.override;

import io.grpc.Context;
import io.smallrye.common.vertx.VertxContext;
import io.vertx.core.Vertx;

/**
* Override gRPC context storage to rely on duplicated context when available.
*/
public class ContextStorageOverride extends Context.Storage {

private static final ThreadLocal<Context> fallback = new ThreadLocal<>();

private static final String GRPC_CONTEXT = "GRPC_CONTEXT";

@Override
public Context doAttach(Context toAttach) {
Context current = current();
io.vertx.core.Context dc = Vertx.currentContext();
if (dc != null && VertxContext.isDuplicatedContext(dc)) {
dc.putLocal(GRPC_CONTEXT, toAttach);
} else {
fallback.set(toAttach);
}
return current;
}

@Override
public void detach(Context context, Context toRestore) {
io.vertx.core.Context dc = Vertx.currentContext();
if (toRestore != Context.ROOT) {
if (dc != null && VertxContext.isDuplicatedContext(dc)) {
dc.putLocal(GRPC_CONTEXT, toRestore);
} else {
fallback.set(toRestore);
}
} else {
if (dc != null && VertxContext.isDuplicatedContext(dc)) {
// Do nothing duplicated context are not shared.
} else {
fallback.set(null);
}
}
}

@Override
public Context current() {
if (VertxContext.isOnDuplicatedContext()) {
Context current = Vertx.currentContext().getLocal(GRPC_CONTEXT);
if (current == null) {
return Context.ROOT;
}
return current;
} else {
Context current = fallback.get();
if (current == null) {
return Context.ROOT;
}
return current;
}
}

@Override
public void attach(Context toAttach) {
// do nothing, should not be called.
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ private synchronized ServerCall.Listener<ReqT> getDelegate() {
delegate = supplier.get();
} catch (Throwable t) {
// If the interceptor supplier throws an exception, catch it, and close the call.
log.warnf("Unable to retrieve gRPC Server call listener", t);
log.warn("Unable to retrieve gRPC Server call listener", t);
close(t);
return null;
}
Expand Down

0 comments on commit b82b235

Please sign in to comment.