Skip to content

Commit

Permalink
gRPC: fix request context propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
michalszynkiewicz committed Jun 17, 2021
1 parent cceeab7 commit c62fe3b
Show file tree
Hide file tree
Showing 32 changed files with 1,104 additions and 177 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import io.quarkus.grpc.runtime.MutinyStub;
import io.quarkus.grpc.runtime.supports.Channels;
import io.quarkus.grpc.runtime.supports.GrpcClientConfigProvider;
import io.quarkus.grpc.runtime.supports.context.GrpcEnableRequestContext;
import io.smallrye.common.annotation.Blocking;
import io.smallrye.common.annotation.NonBlocking;

Expand All @@ -28,7 +27,6 @@ public class GrpcDotNames {
public static final DotName CHANNEL = DotName.createSimple(Channel.class.getName());
public static final DotName GRPC_CLIENT = DotName.createSimple(GrpcClient.class.getName());
public static final DotName GRPC_SERVICE = DotName.createSimple(GrpcService.class.getName());
public static final DotName GRPC_ENABLE_REQUEST_CONTEXT = DotName.createSimple(GrpcEnableRequestContext.class.getName());

public static final DotName BLOCKING = DotName.createSimple(Blocking.class.getName());
public static final DotName NON_BLOCKING = DotName.createSimple(NonBlocking.class.getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import io.quarkus.arc.processor.AnnotationsTransformer;
import io.quarkus.arc.processor.BeanInfo;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.arc.processor.Transformation;
import io.quarkus.deployment.IsDevelopment;
import io.quarkus.deployment.IsNormal;
import io.quarkus.deployment.annotations.BuildProducer;
Expand All @@ -60,8 +59,6 @@
import io.quarkus.grpc.runtime.config.GrpcServerBuildTimeConfig;
import io.quarkus.grpc.runtime.health.GrpcHealthEndpoint;
import io.quarkus.grpc.runtime.health.GrpcHealthStorage;
import io.quarkus.grpc.runtime.supports.context.GrpcEnableRequestContext;
import io.quarkus.grpc.runtime.supports.context.GrpcRequestContextCdiInterceptor;
import io.quarkus.kubernetes.spi.KubernetesPortBuildItem;
import io.quarkus.netty.deployment.MinNettyAllocatorMaxOrderBuildItem;
import io.quarkus.runtime.LaunchMode;
Expand Down Expand Up @@ -240,14 +237,11 @@ public boolean appliesTo(Kind kind) {
@Override
public void transform(TransformationContext context) {
ClassInfo clazz = context.getTarget().asClass();
if (userDefinedServices.contains(clazz.name())) {
// Add @GrpcEnableRequestContext to activate the request context during each call
Transformation transform = context.transform().add(GrpcDotNames.GRPC_ENABLE_REQUEST_CONTEXT);
if (!customScopes.isScopeDeclaredOn(clazz)) {
// Add @Singleton to make it a bean
transform.add(BuiltinScope.SINGLETON.getName());
}
transform.done();
if (userDefinedServices.contains(clazz.name()) && !customScopes.isScopeDeclaredOn(clazz)) {
// Add @Singleton to make it a bean
context.transform()
.add(BuiltinScope.SINGLETON.getName())
.done();
}
}
});
Expand Down Expand Up @@ -303,8 +297,6 @@ void registerBeans(BuildProducer<AdditionalBeanBuildItem> beans,
List<BindableServiceBuildItem> bindables, BuildProducer<FeatureBuildItem> features) {
// @GrpcService is a CDI qualifier
beans.produce(new AdditionalBeanBuildItem(GrpcService.class));
beans.produce(new AdditionalBeanBuildItem(GrpcRequestContextCdiInterceptor.class));
beans.produce(new AdditionalBeanBuildItem(GrpcEnableRequestContext.class));

if (!bindables.isEmpty() || LaunchMode.current() == LaunchMode.DEVELOPMENT) {
beans.produce(AdditionalBeanBuildItem.unremovableOf(GrpcContainer.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
import io.quarkus.grpc.runtime.devmode.GrpcServerReloader;
import io.quarkus.grpc.runtime.health.GrpcHealthStorage;
import io.quarkus.grpc.runtime.reflection.ReflectionService;
import io.quarkus.grpc.runtime.supports.BlockingServerInterceptor;
import io.quarkus.grpc.runtime.supports.CompressionInterceptor;
import io.quarkus.grpc.runtime.supports.blocking.BlockingServerInterceptor;
import io.quarkus.grpc.runtime.supports.context.GrpcRequestContextGrpcInterceptor;
import io.quarkus.runtime.LaunchMode;
import io.quarkus.runtime.RuntimeValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@
import grpc.health.v1.HealthOuterClass.HealthCheckResponse.ServingStatus;
import grpc.health.v1.MutinyHealthGrpc;
import io.quarkus.grpc.GrpcService;
import io.quarkus.grpc.runtime.supports.context.GrpcEnableRequestContext;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
import io.smallrye.mutiny.operators.multi.processors.BroadcastProcessor;

// Note that we need to add the scope and interceptor binding explicitly because this class is not part of the index
// Note that we need to add the scope explicitly because this class is not part of the index
@Singleton
@GrpcEnableRequestContext
@GrpcService
public class GrpcHealthEndpoint extends MutinyHealthGrpc.HealthImplBase {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ public static Channel createChannel(String name) throws SSLException {
GrpcClientConfiguration config = configProvider.getConfiguration(name);

if (config == null && LaunchMode.current() == LaunchMode.TEST) {
LOGGER.infof(
"gRPC client %s created without configuration. We are assuming that it's created to test your gRPC services.",
name);
config = testConfig(configProvider.getServerConfiguration());
}

Expand Down Expand Up @@ -164,7 +167,6 @@ public static Channel createChannel(String name) throws SSLException {
}

private static GrpcClientConfiguration testConfig(GrpcServerConfiguration serverConfiguration) {
LOGGER.info("gRPC client created without configuration. We are assuming that it's created to test your gRPC services.");
GrpcClientConfiguration config = new GrpcClientConfiguration();
config.port = serverConfiguration.testPort;
config.host = serverConfiguration.host;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package io.quarkus.grpc.runtime.supports.blocking;

import java.util.function.Consumer;

import io.grpc.Context;
import io.grpc.ServerCall;
import io.quarkus.arc.InjectableContext;
import io.quarkus.arc.ManagedContext;
import io.vertx.core.Handler;
import io.vertx.core.Promise;

class BlockingExecutionHandler<ReqT> implements Handler<Promise<Object>> {
private final ServerCall.Listener<ReqT> delegate;
private final Context grpcContext;
private final Consumer<ServerCall.Listener<ReqT>> consumer;
private final InjectableContext.ContextState state;
private final ManagedContext requestContext;

public BlockingExecutionHandler(Consumer<ServerCall.Listener<ReqT>> consumer, Context grpcContext,
ServerCall.Listener<ReqT> delegate, InjectableContext.ContextState state,
ManagedContext requestContext) {
this.consumer = consumer;
this.grpcContext = grpcContext;
this.delegate = delegate;
this.state = state;
this.requestContext = requestContext;
}

@Override
public void handle(Promise<Object> event) {
final Context previous = Context.current();
grpcContext.attach();
try {
requestContext.activate(state);
try {
consumer.accept(delegate);
} catch (Throwable any) {
event.fail(any);
return;
} finally {
requestContext.deactivate();
}
event.complete();
} finally {
grpcContext.detach(previous);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package io.quarkus.grpc.runtime.supports;
package io.quarkus.grpc.runtime.supports.blocking;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
Expand All @@ -13,12 +12,15 @@
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.quarkus.arc.Arc;
import io.quarkus.arc.InjectableContext.ContextState;
import io.quarkus.arc.ManagedContext;
import io.vertx.core.Handler;
import io.vertx.core.Promise;
import io.vertx.core.Vertx;

/**
* gRPC Server interceptor offloading the execution of the gRPC method on a wroker thread if the method is annotated
* gRPC Server interceptor offloading the execution of the gRPC method on a worker thread if the method is annotated
* with {@link io.smallrye.common.annotation.Blocking}.
*
* For non-annotated methods, the interceptor acts as a pass-through.
Expand Down Expand Up @@ -62,13 +64,20 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
boolean isBlocking = cache.computeIfAbsent(fullMethodName, this);

if (isBlocking) {
ReplayListener<ReqT> replay = new ReplayListener<>();

final ManagedContext requestContext = getRequestContext();
ContextState state = requestContext.getState();
ReplayListener<ReqT> replay = new ReplayListener<>(state);
vertx.executeBlocking(new Handler<Promise<Object>>() {
@Override
public void handle(Promise<Object> f) {
ServerCall.Listener<ReqT> listener = next.startCall(call, headers);
replay.setDelegate(listener);
ServerCall.Listener<ReqT> listener;
try {
requestContext.activate(state);
listener = next.startCall(call, headers);
} finally {
requestContext.deactivate();
}
replay.setDelegate(listener, requestContext);
f.complete(null);
}
}, null);
Expand All @@ -87,30 +96,46 @@ public void handle(Promise<Object> f) {
*/
private class ReplayListener<ReqT> extends ServerCall.Listener<ReqT> {
private ServerCall.Listener<ReqT> delegate;
private final List<Consumer<ServerCall.Listener<ReqT>>> incomingEvents = new LinkedList<>();
private final List<Consumer<ServerCall.Listener<ReqT>>> incomingEvents = new ArrayList<>();
private final ContextState requestContextState;

private ReplayListener(ContextState requestContextState) {
this.requestContextState = requestContextState;
}

synchronized void setDelegate(ServerCall.Listener<ReqT> delegate) {
synchronized void setDelegate(ServerCall.Listener<ReqT> delegate,
ManagedContext requestContext) {
this.delegate = delegate;
for (Consumer<ServerCall.Listener<ReqT>> event : incomingEvents) {
event.accept(delegate);
requestContext.activate(requestContextState);
try {
for (Consumer<ServerCall.Listener<ReqT>> event : incomingEvents) {
event.accept(delegate);
}
} finally {
requestContext.deactivate();
}
incomingEvents.clear();
}

private synchronized void executeOnContextOrEnqueue(Consumer<ServerCall.Listener<ReqT>> consumer) {
if (this.delegate != null) {
final Context grpcContext = Context.current();
Handler<Promise<Object>> blockingHandler = new BlockingExecutionHandler<>(consumer, grpcContext, delegate);
if (devMode) {
blockingHandler = new DevModeBlockingExecutionHandler<ReqT>(Thread.currentThread().getContextClassLoader(),
blockingHandler);
}
vertx.executeBlocking(blockingHandler, true, null);
executeBlockingWithRequestContext(consumer);
} else {
incomingEvents.add(consumer);
}
}

private void executeBlockingWithRequestContext(Consumer<ServerCall.Listener<ReqT>> consumer) {
final Context grpcContext = Context.current();
Handler<Promise<Object>> blockingHandler = new BlockingExecutionHandler<>(consumer, grpcContext, delegate,
requestContextState, getRequestContext());
if (devMode) {
blockingHandler = new DevModeBlockingExecutionHandler(Thread.currentThread().getContextClassLoader(),
blockingHandler);
}
vertx.executeBlocking(blockingHandler, true, null);
}

@Override
public void onMessage(ReqT message) {
executeOnContextOrEnqueue(new Consumer<ServerCall.Listener<ReqT>>() {
Expand Down Expand Up @@ -142,50 +167,8 @@ public void onReady() {
}
}

private static class DevModeBlockingExecutionHandler<ReqT> implements Handler<Promise<Object>> {

final ClassLoader tccl;
final Handler<Promise<Object>> delegate;

public DevModeBlockingExecutionHandler(ClassLoader tccl, Handler<Promise<Object>> delegate) {
this.tccl = tccl;
this.delegate = delegate;
}

@Override
public void handle(Promise<Object> event) {
ClassLoader originalTccl = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(tccl);
try {
delegate.handle(event);
} finally {
Thread.currentThread().setContextClassLoader(originalTccl);
}
}
}

private static class BlockingExecutionHandler<ReqT> implements Handler<Promise<Object>> {
private final ServerCall.Listener<ReqT> delegate;
private final Context grpcContext;
private final Consumer<ServerCall.Listener<ReqT>> consumer;

public BlockingExecutionHandler(Consumer<ServerCall.Listener<ReqT>> consumer, Context grpcContext,
ServerCall.Listener<ReqT> delegate) {
this.consumer = consumer;
this.grpcContext = grpcContext;
this.delegate = delegate;
}

@Override
public void handle(Promise<Object> event) {
final Context previous = Context.current();
grpcContext.attach();
try {
consumer.accept(delegate);
event.complete();
} finally {
grpcContext.detach(previous);
}
}
// protected for tests
protected ManagedContext getRequestContext() {
return Arc.container().requestContext();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkus.grpc.runtime.supports.blocking;

import io.vertx.core.Handler;
import io.vertx.core.Promise;

class DevModeBlockingExecutionHandler implements Handler<Promise<Object>> {

final ClassLoader tccl;
final Handler<Promise<Object>> delegate;

public DevModeBlockingExecutionHandler(ClassLoader tccl, Handler<Promise<Object>> delegate) {
this.tccl = tccl;
this.delegate = delegate;
}

@Override
public void handle(Promise<Object> event) {
ClassLoader originalTccl = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(tccl);
try {
delegate.handle(event);
} finally {
Thread.currentThread().setContextClassLoader(originalTccl);
}
}
}

This file was deleted.

Loading

0 comments on commit c62fe3b

Please sign in to comment.