From e9e6ae6215eed864a46b66148f33ae4dfd7cc724 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Thu, 14 Jul 2016 15:19:57 -0700 Subject: [PATCH 01/12] Make Dial() withblock error on bad certificates --- call.go | 18 +++++++++++++++--- clientconn.go | 3 +++ stream.go | 9 ++++++++- test/end2end_test.go | 20 ++++++++++++++++++++ transport/http2_client.go | 19 ++++++++++--------- transport/http2_server.go | 10 +++++----- transport/transport.go | 11 +++++++++-- 7 files changed, 70 insertions(+), 20 deletions(-) diff --git a/call.go b/call.go index a8b6dcfd2a59..f1a40a37a95b 100644 --- a/call.go +++ b/call.go @@ -84,7 +84,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd } defer func() { if err != nil { - if _, ok := err.(transport.ConnectionError); !ok { + if e, ok := err.(transport.ConnectionError); !ok || !e.Temporary() { t.CloseStream(stream, err) } } @@ -190,10 +190,13 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli // Retry a non-failfast RPC when // i) there is a connection error; or // ii) the server started to drain before this RPC was initiated. - if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { + if e, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { return toRPCErr(err) } + if ok && !e.Temporary() { + return toRPCErr(err) + } continue } return toRPCErr(err) @@ -204,7 +207,16 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli put() put = nil } - if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { + if e, ok := err.(transport.ConnectionError); ok { + if c.failFast { + return toRPCErr(err) + } + if !e.Temporary() { + return toRPCErr(err) + } + continue + } + if err == transport.ErrStreamDrain { if c.failFast { return toRPCErr(err) } diff --git a/clientconn.go b/clientconn.go index 5c2bf644f786..f2e568ce63cc 100644 --- a/clientconn.go +++ b/clientconn.go @@ -605,6 +605,9 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { if err != nil { cancel() + if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { + return fmt.Errorf("failed to create client transport: %v", err) + } ac.mu.Lock() if ac.state == Shutdown { // ac.tearDown(...) has been invoked. diff --git a/stream.go b/stream.go index 66bfad812ae5..1d5104f686e7 100644 --- a/stream.go +++ b/stream.go @@ -166,7 +166,14 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth put() put = nil } - if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { + if e, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { + if c.failFast || e.Temporary() { + cs.finish(err) + return nil, toRPCErr(err) + } + continue + } + if err == transport.ErrStreamDrain { if c.failFast { cs.finish(err) return nil, toRPCErr(err) diff --git a/test/end2end_test.go b/test/end2end_test.go index e57ff919545e..ffa4fa35aea2 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -2267,6 +2267,26 @@ func testClientRequestBodyError_Cancel_StreamingInput(t *testing.T, e env) { }) } +func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { + te := newTest(t, env{name: "bad-tls", network: "tcp", security: "bad-tls"}) + te.startServer() + defer te.tearDown() + + var ( + err error + opts []grpc.DialOption + ) + creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "wrong-server.com") + if err != nil { + te.t.Fatalf("Failed to load credentials: %v", err) + } + opts = append(opts, grpc.WithTransportCredentials(creds), grpc.WithBlock()) + te.cc, err = grpc.Dial(te.srvAddr, opts...) + if err == nil { + te.t.Fatalf("Dial(%q) = %v, want ConnectionError: credentials handshake failed", te.srvAddr, err) + } +} + // interestingGoroutines returns all goroutines we care about for the purpose // of leak checking. It excludes testing or runtime ones. func interestingGoroutines() (gs []string) { diff --git a/transport/http2_client.go b/transport/http2_client.go index 6dc487873c5e..a3709d2df8d7 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -121,7 +121,7 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl scheme := "http" conn, connErr := dial(opts.Dialer, ctx, addr) if connErr != nil { - return nil, ConnectionErrorf("transport: %v", connErr) + return nil, ConnectionErrorf(true, "transport: %v", connErr) } var authInfo credentials.AuthInfo if creds := opts.TransportCredentials; creds != nil { @@ -129,7 +129,8 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl conn, authInfo, connErr = creds.ClientHandshake(ctx, addr, conn) } if connErr != nil { - return nil, ConnectionErrorf("transport: %v", connErr) + // Credentials handshake error is not a temporary error. + return nil, ConnectionErrorf(false, "transport: %v", connErr) } defer func() { if err != nil { @@ -173,11 +174,11 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl n, err := t.conn.Write(clientPreface) if err != nil { t.Close() - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, "transport: %v", err) } if n != len(clientPreface) { t.Close() - return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) + return nil, ConnectionErrorf(true, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) } if initialWindowSize != defaultWindowSize { err = t.framer.writeSettings(true, http2.Setting{ @@ -189,13 +190,13 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl } if err != nil { t.Close() - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, "transport: %v", err) } // Adjust the connection flow control window if needed. if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil { t.Close() - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, "transport: %v", err) } } go t.controller() @@ -405,7 +406,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } if err != nil { t.notifyError(err) - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, "transport: %v", err) } } t.writableChan <- 0 @@ -619,7 +620,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { // invoked. if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil { t.notifyError(err) - return ConnectionErrorf("transport: %v", err) + return ConnectionErrorf(true, "transport: %v", err) } if t.framer.adjustNumWriters(-1) == 0 { t.framer.flushWrite() @@ -667,7 +668,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) { func (t *http2Client) handleData(f *http2.DataFrame) { size := len(f.Data()) if err := t.fc.onData(uint32(size)); err != nil { - t.notifyError(ConnectionErrorf("%v", err)) + t.notifyError(ConnectionErrorf(true, "%v", err)) return } // Select the right stream to dispatch. diff --git a/transport/http2_server.go b/transport/http2_server.go index 8ed0cd59cce0..7fd6aeba9e17 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -111,12 +111,12 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI Val: uint32(initialWindowSize)}) } if err := framer.writeSettings(true, settings...); err != nil { - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, "transport: %v", err) } // Adjust the connection flow control window if needed. if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if err := framer.writeWindowUpdate(true, 0, delta); err != nil { - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, "transport: %v", err) } } var buf bytes.Buffer @@ -448,7 +448,7 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e } if err != nil { t.Close() - return ConnectionErrorf("transport: %v", err) + return ConnectionErrorf(true, "transport: %v", err) } } return nil @@ -568,7 +568,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } if err := t.framer.writeHeaders(false, p); err != nil { t.Close() - return ConnectionErrorf("transport: %v", err) + return ConnectionErrorf(true, "transport: %v", err) } t.writableChan <- 0 } @@ -642,7 +642,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil { t.Close() - return ConnectionErrorf("transport: %v", err) + return ConnectionErrorf(true, "transport: %v", err) } if t.framer.adjustNumWriters(-1) == 0 { t.framer.flushWrite() diff --git a/transport/transport.go b/transport/transport.go index f739090c1b5a..e2b671845a4d 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -485,9 +485,10 @@ func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError { } // ConnectionErrorf creates an ConnectionError with the specified error description. -func ConnectionErrorf(format string, a ...interface{}) ConnectionError { +func ConnectionErrorf(temp bool, format string, a ...interface{}) ConnectionError { return ConnectionError{ Desc: fmt.Sprintf(format, a...), + temp: temp, } } @@ -495,15 +496,21 @@ func ConnectionErrorf(format string, a ...interface{}) ConnectionError { // entire connection and the retry of all the active streams. type ConnectionError struct { Desc string + temp bool } func (e ConnectionError) Error() string { return fmt.Sprintf("connection error: desc = %q", e.Desc) } +// Temporary indicates if this connection error is temporary or fatal. +func (e ConnectionError) Temporary() bool { + return e.temp +} + var ( // ErrConnClosing indicates that the transport is closing. - ErrConnClosing = ConnectionError{Desc: "transport is closing"} + ErrConnClosing = ConnectionError{Desc: "transport is closing", temp: true} // ErrStreamDrain indicates that the stream is rejected by the server because // the server stops accepting new RPCs. ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs") From 1d0bea79438555c9839761d980e1a6df7ca38669 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Fri, 15 Jul 2016 16:20:34 -0700 Subject: [PATCH 02/12] Add addrConn tearDownError --- clientconn.go | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/clientconn.go b/clientconn.go index f2e568ce63cc..a12ca73827b7 100644 --- a/clientconn.go +++ b/clientconn.go @@ -429,11 +429,13 @@ func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr err // skipWait may overwrite the decision in ac.dopts.block. if ac.dopts.block && !skipWait { if err := ac.resetTransport(false); err != nil { - ac.cc.mu.Lock() - delete(ac.cc.conns, ac.addr) - ac.cc.mu.Unlock() - ac.tearDown(err) - return err + if err != errConnClosing { + ac.cc.mu.Lock() + delete(ac.cc.conns, ac.addr) + ac.cc.mu.Unlock() + ac.tearDown(err) + } + return fmt.Errorf("failed to create transport: %v", err) } // Start to monitor the error status of transport. go ac.transportMonitor() @@ -442,10 +444,12 @@ func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr err go func() { if err := ac.resetTransport(false); err != nil { grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err) - ac.cc.mu.Lock() - delete(ac.cc.conns, ac.addr) - ac.cc.mu.Unlock() - ac.tearDown(err) + if err != errConnClosing { + ac.cc.mu.Lock() + delete(ac.cc.conns, ac.addr) + ac.cc.mu.Unlock() + ac.tearDown(err) + } return } ac.transportMonitor() @@ -519,6 +523,9 @@ type addrConn struct { // due to timeout. ready chan struct{} transport transport.ClientTransport + + // The reason this addrConn is torn down. + tearDownErr error } // printf records an event in ac's event log, unless ac has been closed. @@ -606,7 +613,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { cancel() if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { - return fmt.Errorf("failed to create client transport: %v", err) + return err } ac.mu.Lock() if ac.state == Shutdown { @@ -708,6 +715,9 @@ func (ac *addrConn) transportMonitor() { ac.printf("transport exiting: %v", err) ac.mu.Unlock() grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err) + if err != errConnClosing { + ac.tearDown(err) + } return } } @@ -722,7 +732,7 @@ func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTr switch { case ac.state == Shutdown: ac.mu.Unlock() - return nil, errConnClosing + return nil, ac.tearDownErr case ac.state == Ready: ct := ac.transport ac.mu.Unlock() @@ -772,6 +782,7 @@ func (ac *addrConn) tearDown(err error) { return } ac.state = Shutdown + ac.tearDownErr = err ac.stateCV.Broadcast() if ac.events != nil { ac.events.Finish() From 779083c6337c1cc7b5dec96332630c4067ab6ec2 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Fri, 15 Jul 2016 17:22:46 -0700 Subject: [PATCH 03/12] Change TestDialWithBlockErrorOnBadCertificates error check --- test/end2end_test.go | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/test/end2end_test.go b/test/end2end_test.go index ffa4fa35aea2..a2aa4143afa7 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -439,12 +439,15 @@ func (te *test) startServer(ts testpb.TestServiceServer) { if err != nil { te.t.Fatalf("Failed to listen: %v", err) } - if te.e.security == "tls" { + switch te.e.security { + case "tls": creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key") if err != nil { te.t.Fatalf("Failed to generate credentials %v", err) } sopts = append(sopts, grpc.Creds(creds)) + case "clientAlwaysFailCred": + sopts = append(sopts, grpc.Creds(clientAlwaysFailCred{})) } s := grpc.NewServer(sopts...) te.srv = s @@ -2267,8 +2270,22 @@ func testClientRequestBodyError_Cancel_StreamingInput(t *testing.T, e env) { }) } +const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails" + +type clientAlwaysFailCred struct{} + +func (c clientAlwaysFailCred) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ credentials.AuthInfo, err error) { + return nil, nil, errors.New(clientAlwaysFailCredErrorMsg) +} +func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} +func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} + func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { - te := newTest(t, env{name: "bad-tls", network: "tcp", security: "bad-tls"}) + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred"}) te.startServer() defer te.tearDown() @@ -2276,14 +2293,10 @@ func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { err error opts []grpc.DialOption ) - creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "wrong-server.com") - if err != nil { - te.t.Fatalf("Failed to load credentials: %v", err) - } - opts = append(opts, grpc.WithTransportCredentials(creds), grpc.WithBlock()) + opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{}), grpc.WithBlock()) te.cc, err = grpc.Dial(te.srvAddr, opts...) - if err == nil { - te.t.Fatalf("Dial(%q) = %v, want ConnectionError: credentials handshake failed", te.srvAddr, err) + if !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { + te.t.Fatalf("Dial(%q) = %v, want err.Error() contains %q", te.srvAddr, err, clientAlwaysFailCredErrorMsg) } } From 558ecfb3a61c86ae735dfd2417283ef05ab8f4d1 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Mon, 18 Jul 2016 11:58:15 -0700 Subject: [PATCH 04/12] Remove RPC non-fail-fast return --- call.go | 16 ++-------------- stream.go | 9 +-------- 2 files changed, 3 insertions(+), 22 deletions(-) diff --git a/call.go b/call.go index f1a40a37a95b..0f34b5e84c86 100644 --- a/call.go +++ b/call.go @@ -190,13 +190,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli // Retry a non-failfast RPC when // i) there is a connection error; or // ii) the server started to drain before this RPC was initiated. - if e, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { + if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { return toRPCErr(err) } - if ok && !e.Temporary() { - return toRPCErr(err) - } continue } return toRPCErr(err) @@ -207,16 +204,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli put() put = nil } - if e, ok := err.(transport.ConnectionError); ok { - if c.failFast { - return toRPCErr(err) - } - if !e.Temporary() { - return toRPCErr(err) - } - continue - } - if err == transport.ErrStreamDrain { + if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { return toRPCErr(err) } diff --git a/stream.go b/stream.go index 1d5104f686e7..66bfad812ae5 100644 --- a/stream.go +++ b/stream.go @@ -166,14 +166,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth put() put = nil } - if e, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { - if c.failFast || e.Temporary() { - cs.finish(err) - return nil, toRPCErr(err) - } - continue - } - if err == transport.ErrStreamDrain { + if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { cs.finish(err) return nil, toRPCErr(err) From b41c9e8e145172019d4ae867143216b9f3ba6a9b Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Mon, 18 Jul 2016 11:58:38 -0700 Subject: [PATCH 05/12] Add TestFailFastRPCErrorOnBadCertificates --- test/end2end_test.go | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/test/end2end_test.go b/test/end2end_test.go index a2aa4143afa7..55fa0ff56c8b 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -490,13 +490,16 @@ func (te *test) clientConn() *grpc.ClientConn { grpc.WithDecompressor(grpc.NewGZIPDecompressor()), ) } - if te.e.security == "tls" { + switch te.e.security { + case "tls": creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") if err != nil { te.t.Fatalf("Failed to load credentials: %v", err) } opts = append(opts, grpc.WithTransportCredentials(creds)) - } else { + case "clientAlwaysFailCred": + opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{})) + default: opts = append(opts, grpc.WithInsecure()) } var err error @@ -2286,7 +2289,7 @@ func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo { func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred"}) - te.startServer() + te.startServer(&testServer{security: "clientAlwaysFailCred"}) defer te.tearDown() var ( @@ -2300,6 +2303,18 @@ func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { } } +func TestFailFastRPCErrorOnBadCertificates(t *testing.T) { + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred"}) + te.startServer(&testServer{security: "clientAlwaysFailCred"}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { + te.t.Fatalf("Dial(%q) = %v, want err.Error() contains %q", te.srvAddr, err, clientAlwaysFailCredErrorMsg) + } +} + // interestingGoroutines returns all goroutines we care about for the purpose // of leak checking. It excludes testing or runtime ones. func interestingGoroutines() (gs []string) { From 7c6eabc607eedee135cc335ec2c65162541a176f Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Mon, 18 Jul 2016 15:00:37 -0700 Subject: [PATCH 06/12] Make Dial() return original error --- clientconn.go | 5 ++++- test/end2end_test.go | 8 +++++--- transport/http2_client.go | 18 +++++++++--------- transport/http2_server.go | 10 +++++----- transport/transport.go | 17 ++++++++++++----- 5 files changed, 35 insertions(+), 23 deletions(-) diff --git a/clientconn.go b/clientconn.go index a12ca73827b7..f17ea50f13dd 100644 --- a/clientconn.go +++ b/clientconn.go @@ -435,7 +435,10 @@ func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr err ac.cc.mu.Unlock() ac.tearDown(err) } - return fmt.Errorf("failed to create transport: %v", err) + if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { + return e.OriginalError() + } + return err } // Start to monitor the error status of transport. go ac.transportMonitor() diff --git a/test/end2end_test.go b/test/end2end_test.go index 55fa0ff56c8b..b258893a44dc 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -2275,10 +2275,12 @@ func testClientRequestBodyError_Cancel_StreamingInput(t *testing.T, e env) { const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails" +var clientAlwaysFailCredError = errors.New(clientAlwaysFailCredErrorMsg) + type clientAlwaysFailCred struct{} func (c clientAlwaysFailCred) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ credentials.AuthInfo, err error) { - return nil, nil, errors.New(clientAlwaysFailCredErrorMsg) + return nil, nil, clientAlwaysFailCredError } func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { return rawConn, nil, nil @@ -2298,8 +2300,8 @@ func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { ) opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{}), grpc.WithBlock()) te.cc, err = grpc.Dial(te.srvAddr, opts...) - if !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { - te.t.Fatalf("Dial(%q) = %v, want err.Error() contains %q", te.srvAddr, err, clientAlwaysFailCredErrorMsg) + if err != clientAlwaysFailCredError { + te.t.Fatalf("Dial(%q) = %v, want %v", te.srvAddr, err, clientAlwaysFailCredError) } } diff --git a/transport/http2_client.go b/transport/http2_client.go index a3709d2df8d7..8fb1e61ccdb1 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -121,7 +121,7 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl scheme := "http" conn, connErr := dial(opts.Dialer, ctx, addr) if connErr != nil { - return nil, ConnectionErrorf(true, "transport: %v", connErr) + return nil, ConnectionErrorf(true, connErr, "transport: %v", connErr) } var authInfo credentials.AuthInfo if creds := opts.TransportCredentials; creds != nil { @@ -130,7 +130,7 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl } if connErr != nil { // Credentials handshake error is not a temporary error. - return nil, ConnectionErrorf(false, "transport: %v", connErr) + return nil, ConnectionErrorf(false, connErr, "transport: %v", connErr) } defer func() { if err != nil { @@ -174,11 +174,11 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl n, err := t.conn.Write(clientPreface) if err != nil { t.Close() - return nil, ConnectionErrorf(true, "transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } if n != len(clientPreface) { t.Close() - return nil, ConnectionErrorf(true, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) + return nil, ConnectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) } if initialWindowSize != defaultWindowSize { err = t.framer.writeSettings(true, http2.Setting{ @@ -190,13 +190,13 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl } if err != nil { t.Close() - return nil, ConnectionErrorf(true, "transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } // Adjust the connection flow control window if needed. if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil { t.Close() - return nil, ConnectionErrorf(true, "transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } } go t.controller() @@ -406,7 +406,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } if err != nil { t.notifyError(err) - return nil, ConnectionErrorf(true, "transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } } t.writableChan <- 0 @@ -620,7 +620,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { // invoked. if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil { t.notifyError(err) - return ConnectionErrorf(true, "transport: %v", err) + return ConnectionErrorf(true, err, "transport: %v", err) } if t.framer.adjustNumWriters(-1) == 0 { t.framer.flushWrite() @@ -668,7 +668,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) { func (t *http2Client) handleData(f *http2.DataFrame) { size := len(f.Data()) if err := t.fc.onData(uint32(size)); err != nil { - t.notifyError(ConnectionErrorf(true, "%v", err)) + t.notifyError(ConnectionErrorf(true, err, "%v", err)) return } // Select the right stream to dispatch. diff --git a/transport/http2_server.go b/transport/http2_server.go index 7fd6aeba9e17..16010d55fb21 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -111,12 +111,12 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI Val: uint32(initialWindowSize)}) } if err := framer.writeSettings(true, settings...); err != nil { - return nil, ConnectionErrorf(true, "transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } // Adjust the connection flow control window if needed. if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if err := framer.writeWindowUpdate(true, 0, delta); err != nil { - return nil, ConnectionErrorf(true, "transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } } var buf bytes.Buffer @@ -448,7 +448,7 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e } if err != nil { t.Close() - return ConnectionErrorf(true, "transport: %v", err) + return ConnectionErrorf(true, err, "transport: %v", err) } } return nil @@ -568,7 +568,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } if err := t.framer.writeHeaders(false, p); err != nil { t.Close() - return ConnectionErrorf(true, "transport: %v", err) + return ConnectionErrorf(true, err, "transport: %v", err) } t.writableChan <- 0 } @@ -642,7 +642,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil { t.Close() - return ConnectionErrorf(true, "transport: %v", err) + return ConnectionErrorf(true, err, "transport: %v", err) } if t.framer.adjustNumWriters(-1) == 0 { t.framer.flushWrite() diff --git a/transport/transport.go b/transport/transport.go index e2b671845a4d..2a53b4e5a669 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -485,18 +485,20 @@ func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError { } // ConnectionErrorf creates an ConnectionError with the specified error description. -func ConnectionErrorf(temp bool, format string, a ...interface{}) ConnectionError { +func ConnectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError { return ConnectionError{ - Desc: fmt.Sprintf(format, a...), - temp: temp, + Desc: fmt.Sprintf(format, a...), + temp: temp, + origErr: e, } } // ConnectionError is an error that results in the termination of the // entire connection and the retry of all the active streams. type ConnectionError struct { - Desc string - temp bool + Desc string + temp bool + origErr error } func (e ConnectionError) Error() string { @@ -508,6 +510,11 @@ func (e ConnectionError) Temporary() bool { return e.temp } +// OriginalError returns the original error of this connection error. +func (e ConnectionError) OriginalError() error { + return e.origErr +} + var ( // ErrConnClosing indicates that the transport is closing. ErrConnClosing = ConnectionError{Desc: "transport is closing", temp: true} From d7d831d95e30305a2cad711679032c1e5890289b Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Mon, 18 Jul 2016 15:58:41 -0700 Subject: [PATCH 07/12] Do not create RoundRobin if balancer is not specified --- clientconn.go | 77 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 28 deletions(-) diff --git a/clientconn.go b/clientconn.go index f17ea50f13dd..29de0fa61e26 100644 --- a/clientconn.go +++ b/clientconn.go @@ -233,25 +233,27 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { if cc.dopts.bs == nil { cc.dopts.bs = DefaultBackoffConfig } - if cc.dopts.balancer == nil { - cc.dopts.balancer = RoundRobin(nil) - } - if err := cc.dopts.balancer.Start(target); err != nil { - return nil, err - } var ( ok bool addrs []Address ) - ch := cc.dopts.balancer.Notify() - if ch == nil { - // There is no name resolver installed. + if cc.dopts.balancer == nil { + // Connect to target directly if balancer is nil. addrs = append(addrs, Address{Addr: target}) } else { - addrs, ok = <-ch - if !ok || len(addrs) == 0 { - return nil, errNoAddr + if err := cc.dopts.balancer.Start(target); err != nil { + return nil, err + } + ch := cc.dopts.balancer.Notify() + if ch == nil { + // There is no name resolver installed. + addrs = append(addrs, Address{Addr: target}) + } else { + addrs, ok = <-ch + if !ok || len(addrs) == 0 { + return nil, errNoAddr + } } } waitC := make(chan error, 1) @@ -282,6 +284,8 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { return nil, ErrClientConnTimeout } if ok { + // If balancer is nil or balancer.Notify() is nil, ok will false here. + // Then this goroutine will not be created. go cc.lbWatcher() } colonPos := strings.LastIndex(target, ":") @@ -462,22 +466,35 @@ func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr err } func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) { - addr, put, err := cc.dopts.balancer.Get(ctx, opts) - if err != nil { - return nil, nil, toRPCErr(err) - } - cc.mu.RLock() - if cc.conns == nil { + var ( + ac *addrConn + ok bool + put func() + ) + if cc.dopts.balancer == nil { + // If balancer is nil, there should be only one addrConn available. + for _, ac = range cc.conns { + // Break after the first loop to get the first addrConn. + break + } + } else { + addr, put, err := cc.dopts.balancer.Get(ctx, opts) + if err != nil { + return nil, nil, toRPCErr(err) + } + cc.mu.RLock() + if cc.conns == nil { + cc.mu.RUnlock() + return nil, nil, toRPCErr(ErrClientConnClosing) + } + ac, ok = cc.conns[addr] cc.mu.RUnlock() - return nil, nil, toRPCErr(ErrClientConnClosing) - } - ac, ok := cc.conns[addr] - cc.mu.RUnlock() - if !ok { - if put != nil { - put() + if !ok { + if put != nil { + put() + } + return nil, nil, Errorf(codes.Internal, "grpc: failed to find the transport to send the rpc") } - return nil, nil, Errorf(codes.Internal, "grpc: failed to find the transport to send the rpc") } t, err := ac.wait(ctx, !opts.BlockingWait) if err != nil { @@ -501,7 +518,9 @@ func (cc *ClientConn) Close() error { conns := cc.conns cc.conns = nil cc.mu.Unlock() - cc.dopts.balancer.Close() + if cc.dopts.balancer != nil { + cc.dopts.balancer.Close() + } for _, ac := range conns { ac.tearDown(ErrClientConnClosing) } @@ -657,7 +676,9 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { close(ac.ready) ac.ready = nil } - ac.down = ac.cc.dopts.balancer.Up(ac.addr) + if ac.cc.dopts.balancer != nil { + ac.down = ac.cc.dopts.balancer.Up(ac.addr) + } ac.mu.Unlock() return nil } From 1a571b746ace04b72607c3f9ad2399442b58614d Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Mon, 18 Jul 2016 16:00:17 -0700 Subject: [PATCH 08/12] Add TestFailFastRPCWithNoBalancerErrorOnBadCertificates TestNonFailFastRPCWithNoBalancerErrorOnBadCertificates --- test/end2end_test.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/end2end_test.go b/test/end2end_test.go index b258893a44dc..32dd4e696eb2 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -369,6 +369,7 @@ type test struct { serverCompression bool unaryInt grpc.UnaryServerInterceptor streamInt grpc.StreamServerInterceptor + balancer grpc.Balancer // srv and srvAddr are set once startServer is called. srv *grpc.Server @@ -404,6 +405,8 @@ func newTest(t *testing.T, e env) *test { maxStream: math.MaxUint32, } te.ctx, te.cancel = context.WithCancel(context.Background()) + // Install roundrobin balancer by default. + te.balancer = grpc.RoundRobin(nil) return te } @@ -502,6 +505,9 @@ func (te *test) clientConn() *grpc.ClientConn { default: opts = append(opts, grpc.WithInsecure()) } + if te.balancer != nil { + opts = append(opts, grpc.WithBalancer(te.balancer)) + } var err error te.cc, err = grpc.Dial(te.srvAddr, opts...) if err != nil { @@ -2317,6 +2323,34 @@ func TestFailFastRPCErrorOnBadCertificates(t *testing.T) { } } +func TestFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) { + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred"}) + // Uninstall balancer. + te.balancer = nil + te.startServer(&testServer{security: "clientAlwaysFailCred"}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { + te.t.Fatalf("Dial(%q) = %v, want err.Error() contains %q", te.srvAddr, err, clientAlwaysFailCredErrorMsg) + } +} + +func TestNonFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) { + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred"}) + // Uninstall balancer. + te.balancer = nil + te.startServer(&testServer{security: "clientAlwaysFailCred"}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { + te.t.Fatalf("Dial(%q) = %v, want err.Error() contains %q", te.srvAddr, err, clientAlwaysFailCredErrorMsg) + } +} + // interestingGoroutines returns all goroutines we care about for the purpose // of leak checking. It excludes testing or runtime ones. func interestingGoroutines() (gs []string) { From 4bbb9d8142ae36d8f0e985f056617374de3af648 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Tue, 19 Jul 2016 12:02:58 -0700 Subject: [PATCH 09/12] Add test env for no-balancer --- test/end2end_test.go | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/test/end2end_test.go b/test/end2end_test.go index 32dd4e696eb2..1c763544b907 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -305,6 +305,7 @@ type env struct { network string // The type of network such as tcp, unix, etc. security string // The security protocol such as TLS, SSH, etc. httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS + balancer bool // whether to use balancer } func (e env) runnable() bool { @@ -319,12 +320,13 @@ func (e env) dialer(addr string, timeout time.Duration) (net.Conn, error) { } var ( - tcpClearEnv = env{name: "tcp-clear", network: "tcp"} - tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"} - unixClearEnv = env{name: "unix-clear", network: "unix"} - unixTLSEnv = env{name: "unix-tls", network: "unix", security: "tls"} - handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true} - allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv} + tcpClearEnv = env{name: "tcp-clear", network: "tcp", balancer: true} + tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls", balancer: true} + unixClearEnv = env{name: "unix-clear", network: "unix", balancer: true} + unixTLSEnv = env{name: "unix-tls", network: "unix", security: "tls", balancer: true} + handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true, balancer: true} + noBalancerEnv = env{name: "no-balancer", network: "tcp", security: "tls", balancer: false} + allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv, noBalancerEnv} ) var onlyEnv = flag.String("only_env", "", "If non-empty, one of 'tcp-clear', 'tcp-tls', 'unix-clear', 'unix-tls', or 'handler-tls' to only run the tests for that environment. Empty means all.") @@ -369,7 +371,6 @@ type test struct { serverCompression bool unaryInt grpc.UnaryServerInterceptor streamInt grpc.StreamServerInterceptor - balancer grpc.Balancer // srv and srvAddr are set once startServer is called. srv *grpc.Server @@ -405,8 +406,6 @@ func newTest(t *testing.T, e env) *test { maxStream: math.MaxUint32, } te.ctx, te.cancel = context.WithCancel(context.Background()) - // Install roundrobin balancer by default. - te.balancer = grpc.RoundRobin(nil) return te } @@ -505,8 +504,8 @@ func (te *test) clientConn() *grpc.ClientConn { default: opts = append(opts, grpc.WithInsecure()) } - if te.balancer != nil { - opts = append(opts, grpc.WithBalancer(te.balancer)) + if te.e.balancer { + opts = append(opts, grpc.WithBalancer(grpc.RoundRobin(nil))) } var err error te.cc, err = grpc.Dial(te.srvAddr, opts...) @@ -2296,7 +2295,7 @@ func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo { } func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { - te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred"}) + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: true}) te.startServer(&testServer{security: "clientAlwaysFailCred"}) defer te.tearDown() @@ -2312,7 +2311,7 @@ func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { } func TestFailFastRPCErrorOnBadCertificates(t *testing.T) { - te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred"}) + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: true}) te.startServer(&testServer{security: "clientAlwaysFailCred"}) defer te.tearDown() @@ -2324,9 +2323,7 @@ func TestFailFastRPCErrorOnBadCertificates(t *testing.T) { } func TestFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) { - te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred"}) - // Uninstall balancer. - te.balancer = nil + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: false}) te.startServer(&testServer{security: "clientAlwaysFailCred"}) defer te.tearDown() @@ -2338,9 +2335,7 @@ func TestFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) { } func TestNonFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) { - te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred"}) - // Uninstall balancer. - te.balancer = nil + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: false}) te.startServer(&testServer{security: "clientAlwaysFailCred"}) defer te.tearDown() From f6b46c17871b2b6841661665d294ce289ef4c865 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Tue, 26 Jul 2016 14:28:36 -0700 Subject: [PATCH 10/12] Fix errors after rebasing --- clientconn.go | 19 +++++++++++++------ test/end2end_test.go | 2 +- transport/http2_client.go | 4 ++-- transport/transport.go | 5 +++++ 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/clientconn.go b/clientconn.go index 29de0fa61e26..52e57ca2ba67 100644 --- a/clientconn.go +++ b/clientconn.go @@ -473,12 +473,19 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) ) if cc.dopts.balancer == nil { // If balancer is nil, there should be only one addrConn available. + cc.mu.RLock() for _, ac = range cc.conns { // Break after the first loop to get the first addrConn. + ok = true break } + cc.mu.RUnlock() } else { - addr, put, err := cc.dopts.balancer.Get(ctx, opts) + var ( + addr Address + err error + ) + addr, put, err = cc.dopts.balancer.Get(ctx, opts) if err != nil { return nil, nil, toRPCErr(err) } @@ -489,12 +496,12 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) } ac, ok = cc.conns[addr] cc.mu.RUnlock() - if !ok { - if put != nil { - put() - } - return nil, nil, Errorf(codes.Internal, "grpc: failed to find the transport to send the rpc") + } + if !ok { + if put != nil { + put() } + return nil, nil, Errorf(codes.Internal, "grpc: failed to find the transport to send the rpc") } t, err := ac.wait(ctx, !opts.BlockingWait) if err != nil { diff --git a/test/end2end_test.go b/test/end2end_test.go index 1c763544b907..0702cb7142cc 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -2284,7 +2284,7 @@ var clientAlwaysFailCredError = errors.New(clientAlwaysFailCredErrorMsg) type clientAlwaysFailCred struct{} -func (c clientAlwaysFailCred) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ credentials.AuthInfo, err error) { +func (c clientAlwaysFailCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { return nil, nil, clientAlwaysFailCredError } func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { diff --git a/transport/http2_client.go b/transport/http2_client.go index 8fb1e61ccdb1..9a8384ee1344 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -774,7 +774,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { if t.state == reachable || t.state == draining { if f.LastStreamID > 0 && f.LastStreamID%2 != 1 { t.mu.Unlock() - t.notifyError(ConnectionErrorf("received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID)) + t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID)) return } select { @@ -783,7 +783,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { // t.goAway has been closed (i.e.,multiple GoAways). if id < f.LastStreamID { t.mu.Unlock() - t.notifyError(ConnectionErrorf("received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) + t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) return } t.prevGoAwayID = id diff --git a/transport/transport.go b/transport/transport.go index 2a53b4e5a669..311158f7f7d7 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -512,6 +512,11 @@ func (e ConnectionError) Temporary() bool { // OriginalError returns the original error of this connection error. func (e ConnectionError) OriginalError() error { + // Never return nil error here. + // If original error is nil, return itself. + if e.origErr == nil { + return e + } return e.origErr } From fa5748afd3be06741332800992d15972f3fed74f Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Wed, 27 Jul 2016 11:57:23 -0700 Subject: [PATCH 11/12] Change error returned for transport not found --- clientconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clientconn.go b/clientconn.go index 52e57ca2ba67..946cca9c9fc6 100644 --- a/clientconn.go +++ b/clientconn.go @@ -501,7 +501,7 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) if put != nil { put() } - return nil, nil, Errorf(codes.Internal, "grpc: failed to find the transport to send the rpc") + return nil, nil, errConnClosing } t, err := ac.wait(ctx, !opts.BlockingWait) if err != nil { From a4587cd3f03fe017b6dcfd0f42dd550faae8d171 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Fri, 29 Jul 2016 17:17:36 -0700 Subject: [PATCH 12/12] Fix review comments --- call.go | 3 ++- clientconn.go | 23 ++++++++++++----------- transport/transport.go | 22 +++++++++++----------- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/call.go b/call.go index 0f34b5e84c86..5fba11eb0881 100644 --- a/call.go +++ b/call.go @@ -84,7 +84,8 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd } defer func() { if err != nil { - if e, ok := err.(transport.ConnectionError); !ok || !e.Temporary() { + // If err is connection error, t will be closed, no need to close stream here. + if _, ok := err.(transport.ConnectionError); !ok { t.CloseStream(stream, err) } } diff --git a/clientconn.go b/clientconn.go index 946cca9c9fc6..8d9d8d1c6c4c 100644 --- a/clientconn.go +++ b/clientconn.go @@ -283,9 +283,9 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { cc.Close() return nil, ErrClientConnTimeout } + // If balancer is nil or balancer.Notify() is nil, ok will be false here. + // The lbWatcher goroutine will not be created. if ok { - // If balancer is nil or balancer.Notify() is nil, ok will false here. - // Then this goroutine will not be created. go cc.lbWatcher() } colonPos := strings.LastIndex(target, ":") @@ -434,13 +434,14 @@ func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr err if ac.dopts.block && !skipWait { if err := ac.resetTransport(false); err != nil { if err != errConnClosing { - ac.cc.mu.Lock() - delete(ac.cc.conns, ac.addr) - ac.cc.mu.Unlock() + // Tear down ac and delete it from cc.conns. + cc.mu.Lock() + delete(cc.conns, ac.addr) + cc.mu.Unlock() ac.tearDown(err) } if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { - return e.OriginalError() + return e.Origin() } return err } @@ -452,9 +453,7 @@ func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr err if err := ac.resetTransport(false); err != nil { grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err) if err != errConnClosing { - ac.cc.mu.Lock() - delete(ac.cc.conns, ac.addr) - ac.cc.mu.Unlock() + // Keep this ac in cc.conns, to get the reason it's torn down. ac.tearDown(err) } return @@ -475,7 +474,7 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) // If balancer is nil, there should be only one addrConn available. cc.mu.RLock() for _, ac = range cc.conns { - // Break after the first loop to get the first addrConn. + // Break after the first iteration to get the first addrConn. ok = true break } @@ -747,6 +746,7 @@ func (ac *addrConn) transportMonitor() { ac.mu.Unlock() grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err) if err != errConnClosing { + // Keep this ac in cc.conns, to get the reason it's torn down. ac.tearDown(err) } return @@ -762,8 +762,9 @@ func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTr ac.mu.Lock() switch { case ac.state == Shutdown: + err := ac.tearDownErr ac.mu.Unlock() - return nil, ac.tearDownErr + return nil, err case ac.state == Ready: ct := ac.transport ac.mu.Unlock() diff --git a/transport/transport.go b/transport/transport.go index 311158f7f7d7..d59e511372a1 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -487,18 +487,18 @@ func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError { // ConnectionErrorf creates an ConnectionError with the specified error description. func ConnectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError { return ConnectionError{ - Desc: fmt.Sprintf(format, a...), - temp: temp, - origErr: e, + Desc: fmt.Sprintf(format, a...), + temp: temp, + err: e, } } // ConnectionError is an error that results in the termination of the // entire connection and the retry of all the active streams. type ConnectionError struct { - Desc string - temp bool - origErr error + Desc string + temp bool + err error } func (e ConnectionError) Error() string { @@ -510,14 +510,14 @@ func (e ConnectionError) Temporary() bool { return e.temp } -// OriginalError returns the original error of this connection error. -func (e ConnectionError) OriginalError() error { +// Origin returns the original error of this connection error. +func (e ConnectionError) Origin() error { // Never return nil error here. - // If original error is nil, return itself. - if e.origErr == nil { + // If the original error is nil, return itself. + if e.err == nil { return e } - return e.origErr + return e.err } var (