diff --git a/pkg/rpc/context.go b/pkg/rpc/context.go index ac300f364997..a922f888b59a 100644 --- a/pkg/rpc/context.go +++ b/pkg/rpc/context.go @@ -184,99 +184,75 @@ func NewServerWithInterceptor( opts = append(opts, grpc.Creds(credentials.NewTLS(tlsConfig))) } - var unaryInterceptor grpc.UnaryServerInterceptor - var streamInterceptor grpc.StreamServerInterceptor + // These interceptors will be called in the order in which they appear, i.e. + // The last element will wrap the actual handler. + var unaryInterceptor []grpc.UnaryServerInterceptor + var streamInterceptor []grpc.StreamServerInterceptor - if tracer := ctx.AmbientCtx.Tracer; tracer != nil { - // We use a SpanInclusionFunc to save a bit of unnecessary work when - // tracing is disabled. - unaryInterceptor = otgrpc.OpenTracingServerInterceptor( - tracer, - otgrpc.IncludingSpans(otgrpc.SpanInclusionFunc( - func( - parentSpanCtx opentracing.SpanContext, - method string, - req, resp interface{}) bool { - // This anonymous func serves to bind the tracer for - // spanInclusionFuncForServer. - return spanInclusionFuncForServer( - tracer.(*tracing.Tracer), parentSpanCtx, method, req, resp) - })), - ) - // TODO(tschottdorf): should set up tracing for stream-based RPCs as - // well. The otgrpc package has no such facility, but there's also this: - // - // https://github.com/grpc-ecosystem/go-grpc-middleware/tree/master/tracing/opentracing - } - - // TODO(tschottdorf): when setting up the interceptors below, could make the - // functions a wee bit more performant by hoisting some of the nil checks - // out. Doubt measurements can tell the difference though. - - if interceptor != nil { - prevUnaryInterceptor := unaryInterceptor - unaryInterceptor = func( + if !ctx.Config.Insecure { + unaryInterceptor = append(unaryInterceptor, func( ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, ) (interface{}, error) { - if err := interceptor(info.FullMethod); err != nil { + if err := requireSuperUser(ctx); err != nil { return nil, err } - if prevUnaryInterceptor != nil { - return prevUnaryInterceptor(ctx, req, info, handler) - } return handler(ctx, req) - } - } - - if interceptor != nil { - prevStreamInterceptor := streamInterceptor - streamInterceptor = func( + }) + streamInterceptor = append(streamInterceptor, func( srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler, ) error { - if err := interceptor(info.FullMethod); err != nil { + if err := requireSuperUser(stream.Context()); err != nil { return err } - if prevStreamInterceptor != nil { - return prevStreamInterceptor(srv, stream, info, handler) - } return handler(srv, stream) - } + }) } - if !ctx.Config.Insecure { - prevUnaryInterceptor := unaryInterceptor - unaryInterceptor = func( + if interceptor != nil { + unaryInterceptor = append(unaryInterceptor, func( ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, ) (interface{}, error) { - if err := requireSuperUser(ctx); err != nil { + if err := interceptor(info.FullMethod); err != nil { return nil, err } - if prevUnaryInterceptor != nil { - return prevUnaryInterceptor(ctx, req, info, handler) - } return handler(ctx, req) - } - prevStreamInterceptor := streamInterceptor - streamInterceptor = func( + }) + + streamInterceptor = append(streamInterceptor, func( srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler, ) error { - if err := requireSuperUser(stream.Context()); err != nil { + if err := interceptor(info.FullMethod); err != nil { return err } - if prevStreamInterceptor != nil { - return prevStreamInterceptor(srv, stream, info, handler) - } return handler(srv, stream) - } + }) } - if unaryInterceptor != nil { - opts = append(opts, grpc.UnaryInterceptor(unaryInterceptor)) - } - if streamInterceptor != nil { - opts = append(opts, grpc.StreamInterceptor(streamInterceptor)) + if tracer := ctx.AmbientCtx.Tracer; tracer != nil { + // We use a SpanInclusionFunc to save a bit of unnecessary work when + // tracing is disabled. + unaryInterceptor = append(unaryInterceptor, otgrpc.OpenTracingServerInterceptor( + tracer, + otgrpc.IncludingSpans(otgrpc.SpanInclusionFunc( + func( + parentSpanCtx opentracing.SpanContext, + method string, + req, resp interface{}) bool { + // This anonymous func serves to bind the tracer for + // spanInclusionFuncForServer. + return spanInclusionFuncForServer( + tracer.(*tracing.Tracer), parentSpanCtx, method, req, resp) + })), + )) + // TODO(tschottdorf): should set up tracing for stream-based RPCs as + // well. The otgrpc package has no such facility, but there's also this: + // + // https://github.com/grpc-ecosystem/go-grpc-middleware/tree/master/tracing/opentracing } + opts = append(opts, grpc.ChainUnaryInterceptor(unaryInterceptor...)) + opts = append(opts, grpc.ChainStreamInterceptor(streamInterceptor...)) + s := grpc.NewServer(opts...) RegisterHeartbeatServer(s, &HeartbeatService{ clock: ctx.Clock,