Skip to content

Commit

Permalink
Break StreamDirector interface, fix metadata propagation for gRPC-Go>…
Browse files Browse the repository at this point in the history
…1.5. (#20)
  • Loading branch information
Michal Witkowski authored Nov 20, 2017
1 parent 97396d9 commit 67591eb
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 16 deletions.
6 changes: 5 additions & 1 deletion proxy/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@ import (
// The presence of the `Context` allows for rich filtering, e.g. based on Metadata (headers).
// If no handling is meant to be done, a `codes.NotImplemented` gRPC error should be returned.
//
// The context returned from this function should be the context for the *outgoing* (to backend) call. In case you want
// to forward any Metadata between the inbound request and outbound requests, you should do it manually. However, you
// *must* propagate the cancel function (`context.WithCancel`) of the inbound context to the one returned.
//
// It is worth noting that the StreamDirector will be fired *after* all server-side stream interceptors
// are invoked. So decisions around authorization, monitoring etc. are better to be handled there.
//
// See the rather rich example.
type StreamDirector func(ctx context.Context, fullMethodName string) (*grpc.ClientConn, error)
type StreamDirector func(ctx context.Context, fullMethodName string) (context.Context, *grpc.ClientConn, error)
17 changes: 11 additions & 6 deletions proxy/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,26 @@ func ExampleTransparentHandler() {
// Provide sa simple example of a director that shields internal services and dials a staging or production backend.
// This is a *very naive* implementation that creates a new connection on every request. Consider using pooling.
func ExampleStreamDirector() {
director = func(ctx context.Context, fullMethodName string) (*grpc.ClientConn, error) {
director = func(ctx context.Context, fullMethodName string) (context.Context, *grpc.ClientConn, error) {
// Make sure we never forward internal services.
if strings.HasPrefix(fullMethodName, "/com.example.internal.") {
return nil, grpc.Errorf(codes.Unimplemented, "Unknown method")
return nil, nil, grpc.Errorf(codes.Unimplemented, "Unknown method")
}
md, ok := metadata.FromContext(ctx)
md, ok := metadata.FromIncomingContext(ctx)
// Copy the inbound metadata explicitly.
outCtx, _ := context.WithCancel(ctx)
outCtx = metadata.NewOutgoingContext(outCtx, md.Copy())
if ok {
// Decide on which backend to dial
if val, exists := md[":authority"]; exists && val[0] == "staging.api.example.com" {
// Make sure we use DialContext so the dialing can be cancelled/time out together with the context.
return grpc.DialContext(ctx, "api-service.staging.svc.local", grpc.WithCodec(proxy.Codec()))
conn, err := grpc.DialContext(ctx, "api-service.staging.svc.local", grpc.WithCodec(proxy.Codec()))
return outCtx, conn, err
} else if val, exists := md[":authority"]; exists && val[0] == "api.example.com" {
return grpc.DialContext(ctx, "api-service.prod.svc.local", grpc.WithCodec(proxy.Codec()))
conn, err := grpc.DialContext(ctx, "api-service.prod.svc.local", grpc.WithCodec(proxy.Codec()))
return outCtx, conn, err
}
}
return nil, grpc.Errorf(codes.Unimplemented, "Unknown method")
return nil, nil, grpc.Errorf(codes.Unimplemented, "Unknown method")
}
}
5 changes: 3 additions & 2 deletions proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
return grpc.Errorf(codes.Internal, "lowLevelServerStream not exists in context")
}
fullMethodName := lowLevelServerStream.Method()
clientCtx, clientCancel := context.WithCancel(serverStream.Context())
backendConn, err := s.director(serverStream.Context(), fullMethodName)
// We require that the director's returned context inherits from the serverStream.Context().
outgoingCtx, backendConn, err := s.director(serverStream.Context(), fullMethodName)
clientCtx, clientCancel := context.WithCancel(outgoingCtx)
if err != nil {
return err
}
Expand Down
17 changes: 10 additions & 7 deletions proxy/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ type assertingService struct {

func (s *assertingService) PingEmpty(ctx context.Context, _ *pb.Empty) (*pb.PingResponse, error) {
// Check that this call has client's metadata.
md, ok := metadata.FromContext(ctx)
md, ok := metadata.FromIncomingContext(ctx)
assert.True(s.t, ok, "PingEmpty call must have metadata in context")
_, ok = md[clientMdKey]
assert.True(s.t, ok, "PingEmpty call must have clients's custom headers in metadata")
Expand Down Expand Up @@ -116,7 +116,7 @@ func (s *ProxyHappySuite) ctx() context.Context {
}

func (s *ProxyHappySuite) TestPingEmptyCarriesClientMetadata() {
ctx := metadata.NewContext(s.ctx(), metadata.Pairs(clientMdKey, "true"))
ctx := metadata.NewOutgoingContext(s.ctx(), metadata.Pairs(clientMdKey, "true"))
out, err := s.testClient.PingEmpty(ctx, &pb.Empty{})
require.NoError(s.T(), err, "PingEmpty should succeed without errors")
require.Equal(s.T(), &pb.PingResponse{Value: pingDefaultValue, Counter: 42}, out)
Expand Down Expand Up @@ -148,7 +148,7 @@ func (s *ProxyHappySuite) TestPingErrorPropagatesAppError() {

func (s *ProxyHappySuite) TestDirectorErrorIsPropagated() {
// See SetupSuite where the StreamDirector has a special case.
ctx := metadata.NewContext(s.ctx(), metadata.Pairs(rejectingMdKey, "true"))
ctx := metadata.NewOutgoingContext(s.ctx(), metadata.Pairs(rejectingMdKey, "true"))
_, err := s.testClient.Ping(ctx, &pb.PingRequest{Value: "foo"})
require.Error(s.T(), err, "Director should reject this RPC")
assert.Equal(s.T(), codes.PermissionDenied, grpc.Code(err))
Expand Down Expand Up @@ -204,14 +204,17 @@ func (s *ProxyHappySuite) SetupSuite() {
// Setup of the proxy's Director.
s.serverClientConn, err = grpc.Dial(s.serverListener.Addr().String(), grpc.WithInsecure(), grpc.WithCodec(proxy.Codec()))
require.NoError(s.T(), err, "must not error on deferred client Dial")
director := func(ctx context.Context, fullName string) (*grpc.ClientConn, error) {
md, ok := metadata.FromContext(ctx)
director := func(ctx context.Context, fullName string) (context.Context, *grpc.ClientConn, error) {
md, ok := metadata.FromIncomingContext(ctx)
if ok {
if _, exists := md[rejectingMdKey]; exists {
return nil, grpc.Errorf(codes.PermissionDenied, "testing rejection")
return ctx, nil, grpc.Errorf(codes.PermissionDenied, "testing rejection")
}
}
return s.serverClientConn, nil
// Explicitly copy the metadata, otherwise the tests will fail.
outCtx, _ := context.WithCancel(ctx)
outCtx = metadata.NewOutgoingContext(outCtx, md.Copy())
return outCtx, s.serverClientConn, nil
}
s.proxy = grpc.NewServer(
grpc.CustomCodec(proxy.Codec()),
Expand Down

0 comments on commit 67591eb

Please sign in to comment.