Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
51025: rpc: use grpc.Chain{Unary,Stream}Interceptor r=nvanbenschoten,andreimatei a=tbg

These were not available when the code was first written but they are
very helpful for keeping things tidy.

Release note: None

Co-authored-by: Tobias Schottdorf <[email protected]>
  • Loading branch information
craig[bot] and tbg committed Jul 9, 2020
2 parents 1b5d070 + 9f77450 commit 492cde2
Showing 1 changed file with 42 additions and 66 deletions.
108 changes: 42 additions & 66 deletions pkg/rpc/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,99 +184,75 @@ func NewServerWithInterceptor(
opts = append(opts, grpc.Creds(credentials.NewTLS(tlsConfig)))
}

var unaryInterceptor grpc.UnaryServerInterceptor
var streamInterceptor grpc.StreamServerInterceptor
// These interceptors will be called in the order in which they appear, i.e.
// The last element will wrap the actual handler.
var unaryInterceptor []grpc.UnaryServerInterceptor
var streamInterceptor []grpc.StreamServerInterceptor

if tracer := ctx.AmbientCtx.Tracer; tracer != nil {
// We use a SpanInclusionFunc to save a bit of unnecessary work when
// tracing is disabled.
unaryInterceptor = otgrpc.OpenTracingServerInterceptor(
tracer,
otgrpc.IncludingSpans(otgrpc.SpanInclusionFunc(
func(
parentSpanCtx opentracing.SpanContext,
method string,
req, resp interface{}) bool {
// This anonymous func serves to bind the tracer for
// spanInclusionFuncForServer.
return spanInclusionFuncForServer(
tracer.(*tracing.Tracer), parentSpanCtx, method, req, resp)
})),
)
// TODO(tschottdorf): should set up tracing for stream-based RPCs as
// well. The otgrpc package has no such facility, but there's also this:
//
// https://github.com/grpc-ecosystem/go-grpc-middleware/tree/master/tracing/opentracing
}

// TODO(tschottdorf): when setting up the interceptors below, could make the
// functions a wee bit more performant by hoisting some of the nil checks
// out. Doubt measurements can tell the difference though.

if interceptor != nil {
prevUnaryInterceptor := unaryInterceptor
unaryInterceptor = func(
if !ctx.Config.Insecure {
unaryInterceptor = append(unaryInterceptor, func(
ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler,
) (interface{}, error) {
if err := interceptor(info.FullMethod); err != nil {
if err := requireSuperUser(ctx); err != nil {
return nil, err
}
if prevUnaryInterceptor != nil {
return prevUnaryInterceptor(ctx, req, info, handler)
}
return handler(ctx, req)
}
}

if interceptor != nil {
prevStreamInterceptor := streamInterceptor
streamInterceptor = func(
})
streamInterceptor = append(streamInterceptor, func(
srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler,
) error {
if err := interceptor(info.FullMethod); err != nil {
if err := requireSuperUser(stream.Context()); err != nil {
return err
}
if prevStreamInterceptor != nil {
return prevStreamInterceptor(srv, stream, info, handler)
}
return handler(srv, stream)
}
})
}

if !ctx.Config.Insecure {
prevUnaryInterceptor := unaryInterceptor
unaryInterceptor = func(
if interceptor != nil {
unaryInterceptor = append(unaryInterceptor, func(
ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler,
) (interface{}, error) {
if err := requireSuperUser(ctx); err != nil {
if err := interceptor(info.FullMethod); err != nil {
return nil, err
}
if prevUnaryInterceptor != nil {
return prevUnaryInterceptor(ctx, req, info, handler)
}
return handler(ctx, req)
}
prevStreamInterceptor := streamInterceptor
streamInterceptor = func(
})

streamInterceptor = append(streamInterceptor, func(
srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler,
) error {
if err := requireSuperUser(stream.Context()); err != nil {
if err := interceptor(info.FullMethod); err != nil {
return err
}
if prevStreamInterceptor != nil {
return prevStreamInterceptor(srv, stream, info, handler)
}
return handler(srv, stream)
}
})
}

if unaryInterceptor != nil {
opts = append(opts, grpc.UnaryInterceptor(unaryInterceptor))
}
if streamInterceptor != nil {
opts = append(opts, grpc.StreamInterceptor(streamInterceptor))
if tracer := ctx.AmbientCtx.Tracer; tracer != nil {
// We use a SpanInclusionFunc to save a bit of unnecessary work when
// tracing is disabled.
unaryInterceptor = append(unaryInterceptor, otgrpc.OpenTracingServerInterceptor(
tracer,
otgrpc.IncludingSpans(otgrpc.SpanInclusionFunc(
func(
parentSpanCtx opentracing.SpanContext,
method string,
req, resp interface{}) bool {
// This anonymous func serves to bind the tracer for
// spanInclusionFuncForServer.
return spanInclusionFuncForServer(
tracer.(*tracing.Tracer), parentSpanCtx, method, req, resp)
})),
))
// TODO(tschottdorf): should set up tracing for stream-based RPCs as
// well. The otgrpc package has no such facility, but there's also this:
//
// https://github.com/grpc-ecosystem/go-grpc-middleware/tree/master/tracing/opentracing
}

opts = append(opts, grpc.ChainUnaryInterceptor(unaryInterceptor...))
opts = append(opts, grpc.ChainStreamInterceptor(streamInterceptor...))

s := grpc.NewServer(opts...)
RegisterHeartbeatServer(s, &HeartbeatService{
clock: ctx.Clock,
Expand Down

0 comments on commit 492cde2

Please sign in to comment.