Skip to content

Commit

Permalink
Merge pull request #867 from iamqizhao/master
Browse files Browse the repository at this point in the history
Support client side interceptor
  • Loading branch information
menghanl authored Sep 2, 2016
2 parents 8d57dd3 + 61f62e0 commit 52f6504
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 19 deletions.
9 changes: 8 additions & 1 deletion call.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,14 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
// Invoke sends the RPC request on the wire and returns after response is received.
// Invoke is called by generated code. Also users can call Invoke directly when it
// is really needed in their use cases.
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) error {
if cc.dopts.unaryInt != nil {
return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
}
return invoke(ctx, method, args, reply, cc, opts...)
}

func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
c := defaultCallInfo
for _, o := range opts {
if err := o.before(&c); err != nil {
Expand Down
34 changes: 25 additions & 9 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,17 @@ var (
// dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial.
type dialOptions struct {
codec Codec
cp Compressor
dc Decompressor
bs backoffStrategy
balancer Balancer
block bool
insecure bool
timeout time.Duration
copts transport.ConnectOptions
unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor
codec Codec
cp Compressor
dc Decompressor
bs backoffStrategy
balancer Balancer
block bool
insecure bool
timeout time.Duration
copts transport.ConnectOptions
}

// DialOption configures how we set up the connection.
Expand Down Expand Up @@ -215,6 +217,20 @@ func WithUserAgent(s string) DialOption {
}
}

// WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs.
func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
return func(o *dialOptions) {
o.unaryInt = f
}
}

// WithStreamInterceptor returns a DialOption that specifies the interceptor for streaming RPCs.
func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
return func(o *dialOptions) {
o.streamInt = f
}
}

// Dial creates a client connection to the given target.
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return DialContext(context.Background(), target, opts...)
Expand Down
16 changes: 16 additions & 0 deletions interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,22 @@ import (
"golang.org/x/net/context"
)

// UnaryInvoker is called by UnaryClientInterceptor to complete RPCs.
type UnaryInvoker func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error

// UnaryClientInterceptor intercepts the execution of a unary RPC on the client. inovker is the handler to complete the RPC
// and it is the responsibility of the interceptor to call it.
// This is the EXPERIMENTAL API.
type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error

// Streamer is called by StreamClientInterceptor to create a ClientStream.
type Streamer func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error)

// StreamClientInterceptor intercepts the creation of ClientStream. It may return a custom ClientStream to intercept all I/O
// operations. streamer is the handlder to create a ClientStream and it is the responsibility of the interceptor to call it.
// This is the EXPERIMENTAL API.
type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error)

// UnaryServerInfo consists of various information about a unary RPC on
// server side. All per-rpc information may be mutated by the interceptor.
type UnaryServerInfo struct {
Expand Down
9 changes: 8 additions & 1 deletion stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,14 @@ type ClientStream interface {

// NewClientStream creates a new Stream for the client side. This is called
// by generated code.
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
if cc.dopts.streamInt != nil {
return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
}
return newClientStream(ctx, desc, cc, method, opts...)
}

func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
var (
t transport.ClientTransport
s *transport.Stream
Expand Down
93 changes: 85 additions & 8 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,10 @@ type test struct {
userAgent string
clientCompression bool
serverCompression bool
unaryInt grpc.UnaryServerInterceptor
streamInt grpc.StreamServerInterceptor
unaryClientInt grpc.UnaryClientInterceptor
streamClientInt grpc.StreamClientInterceptor
unaryServerInt grpc.UnaryServerInterceptor
streamServerInt grpc.StreamServerInterceptor

// srv and srvAddr are set once startServer is called.
srv *grpc.Server
Expand Down Expand Up @@ -423,11 +425,11 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
)
}
if te.unaryInt != nil {
sopts = append(sopts, grpc.UnaryInterceptor(te.unaryInt))
if te.unaryServerInt != nil {
sopts = append(sopts, grpc.UnaryInterceptor(te.unaryServerInt))
}
if te.streamInt != nil {
sopts = append(sopts, grpc.StreamInterceptor(te.streamInt))
if te.streamServerInt != nil {
sopts = append(sopts, grpc.StreamInterceptor(te.streamServerInt))
}
la := "localhost:0"
switch te.e.network {
Expand Down Expand Up @@ -492,6 +494,12 @@ func (te *test) clientConn() *grpc.ClientConn {
grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
)
}
if te.unaryClientInt != nil {
opts = append(opts, grpc.WithUnaryInterceptor(te.unaryClientInt))
}
if te.streamClientInt != nil {
opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt))
}
switch te.e.security {
case "tls":
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
Expand Down Expand Up @@ -2137,6 +2145,75 @@ func testCompressOK(t *testing.T, e env) {
}
}

func TestUnaryClientInterceptor(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testUnaryClientInterceptor(t, e)
}
}

func failOkayRPC(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
err := invoker(ctx, method, req, reply, cc, opts...)
if err == nil {
return grpc.Errorf(codes.NotFound, "")
}
return err
}

func testUnaryClientInterceptor(t *testing.T, e env) {
te := newTest(t, e)
te.userAgent = testAppUA
te.unaryClientInt = failOkayRPC
te.startServer(&testServer{security: e.security})
defer te.tearDown()

tc := testpb.NewTestServiceClient(te.clientConn())
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.NotFound {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, error code %s", tc, err, codes.NotFound)
}
}

func TestStreamClientInterceptor(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testStreamClientInterceptor(t, e)
}
}

func failOkayStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
s, err := streamer(ctx, desc, cc, method, opts...)
if err == nil {
return nil, grpc.Errorf(codes.NotFound, "")
}
return s, nil
}

func testStreamClientInterceptor(t *testing.T, e env) {
te := newTest(t, e)
te.streamClientInt = failOkayStream
te.startServer(&testServer{security: e.security})
defer te.tearDown()

tc := testpb.NewTestServiceClient(te.clientConn())
respParam := []*testpb.ResponseParameters{
{
Size: proto.Int32(int32(1)),
},
}
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1))
if err != nil {
t.Fatal(err)
}
req := &testpb.StreamingOutputCallRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseParameters: respParam,
Payload: payload,
}
if _, err := tc.StreamingOutputCall(context.Background(), req); grpc.Code(err) != codes.NotFound {
t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want _, error code %s", tc, err, codes.NotFound)
}
}

func TestUnaryServerInterceptor(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
Expand All @@ -2150,7 +2227,7 @@ func errInjector(ctx context.Context, req interface{}, info *grpc.UnaryServerInf

func testUnaryServerInterceptor(t *testing.T, e env) {
te := newTest(t, e)
te.unaryInt = errInjector
te.unaryServerInt = errInjector
te.startServer(&testServer{security: e.security})
defer te.tearDown()

Expand Down Expand Up @@ -2181,7 +2258,7 @@ func fullDuplexOnly(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServ

func testStreamServerInterceptor(t *testing.T, e env) {
te := newTest(t, e)
te.streamInt = fullDuplexOnly
te.streamServerInt = fullDuplexOnly
te.startServer(&testServer{security: e.security})
defer te.tearDown()

Expand Down

0 comments on commit 52f6504

Please sign in to comment.