diff --git a/credentials/credentials.go b/credentials/credentials.go index 3f17b70628ed..13be45742dd1 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -40,6 +40,7 @@ package credentials // import "google.golang.org/grpc/credentials" import ( "crypto/tls" "crypto/x509" + "errors" "fmt" "io/ioutil" "net" @@ -86,6 +87,12 @@ type AuthInfo interface { AuthType() string } +var ( + // ErrConnDispatched indicates that rawConn has been dispatched out of gRPC + // and the caller should not close rawConn. + ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC") +) + // TransportCredentials defines the common interface for all the live gRPC wire // protocols and supported transport security protocols (e.g., TLS, SSL). type TransportCredentials interface { diff --git a/server.go b/server.go index 1ed8aac9eb0d..b2a825ad0bfc 100644 --- a/server.go +++ b/server.go @@ -367,7 +367,10 @@ func (s *Server) handleRawConn(rawConn net.Conn) { s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) s.mu.Unlock() grpclog.Printf("grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err) - rawConn.Close() + // If serverHandShake returns ErrConnDispatched, keep rawConn open. + if err != credentials.ErrConnDispatched { + rawConn.Close() + } return } diff --git a/test/end2end_test.go b/test/end2end_test.go index c76c58b1cf18..09d389714304 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -2349,6 +2349,54 @@ func TestNonFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) { } } +type serverDispatchCred struct { + ready chan struct{} + rawConn net.Conn +} + +func newServerDispatchCred() *serverDispatchCred { + return &serverDispatchCred{ + ready: make(chan struct{}), + } +} +func (c *serverDispatchCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} +func (c *serverDispatchCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + c.rawConn = rawConn + close(c.ready) + return nil, nil, credentials.ErrConnDispatched +} +func (c *serverDispatchCred) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} +func (c *serverDispatchCred) getRawConn() net.Conn { + <-c.ready + return c.rawConn +} + +func TestServerCredsDispatch(t *testing.T) { + lis, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + cred := newServerDispatchCred() + s := grpc.NewServer(grpc.Creds(cred)) + go s.Serve(lis) + defer s.Stop() + + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(cred)) + if err != nil { + t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) + } + defer cc.Close() + + // Check rawConn is not closed. + if n, err := cred.getRawConn().Write([]byte{0}); n <= 0 || err != nil { + t.Errorf("Read() = %v, %v; want n>0, ", n, err) + } +} + // interestingGoroutines returns all goroutines we care about for the purpose // of leak checking. It excludes testing or runtime ones. func interestingGoroutines() (gs []string) {