Skip to content

Commit

Permalink
Add Clone() and OverrideServerName() to TransportCredentials
Browse files Browse the repository at this point in the history
  • Loading branch information
menghanl committed Sep 26, 2016
1 parent 3644242 commit e63e0d4
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 18 deletions.
41 changes: 23 additions & 18 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ type TransportCredentials interface {
ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
// Info provides the ProtocolInfo of this TransportCredentials.
Info() ProtocolInfo
// Clone makes a copy of this TransportCredentials.
Clone() TransportCredentials
// OverrideServerName overrides the server name used to verify the hostname on the returned certificates from the server.
// gRPC internals also use it to override the virtual hosting name if it is set.
// It must be called before dialing. Currently, this is only used by grpclb.
OverrideServerName(string) error
}

// TLSInfo contains the auth information for a TLS authenticated connection.
Expand Down Expand Up @@ -136,16 +142,6 @@ func (c tlsCreds) Info() ProtocolInfo {
}
}

// GetRequestMetadata returns nil, nil since TLS credentials does not have
// metadata.
func (c *tlsCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return nil, nil
}

func (c *tlsCreds) RequireTransportSecurity() bool {
return true
}

func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
// use local cfg to avoid clobbering ServerName if using multiple endpoints
cfg := cloneTLSConfig(c.config)
Expand Down Expand Up @@ -182,6 +178,15 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
return conn, TLSInfo{conn.ConnectionState()}, nil
}

func (c *tlsCreds) Clone() TransportCredentials {
return NewTLS(c.config)
}

func (c *tlsCreds) OverrideServerName(serverNameOverride string) error {
c.config.ServerName = serverNameOverride
return nil
}

// NewTLS uses c to construct a TransportCredentials based on TLS.
func NewTLS(c *tls.Config) TransportCredentials {
tc := &tlsCreds{cloneTLSConfig(c)}
Expand All @@ -190,16 +195,16 @@ func NewTLS(c *tls.Config) TransportCredentials {
}

// NewClientTLSFromCert constructs a TLS from the input certificate for client.
// serverNameOverwrite is for testing only. If set to a non empty string,
// it will overwrite the virtual host name of authority (e.g. :authority header field) in requests.
func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverwrite string) TransportCredentials {
return NewTLS(&tls.Config{ServerName: serverNameOverwrite, RootCAs: cp})
// serverNameOverride is for testing only. If set to a non empty string,
// it will override the virtual host name of authority (e.g. :authority header field) in requests.
func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) TransportCredentials {
return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp})
}

// NewClientTLSFromFile constructs a TLS from the input certificate file for client.
// serverNameOverwrite is for testing only. If set to a non empty string,
// it will overwrite the virtual host name of authority (e.g. :authority header field) in requests.
func NewClientTLSFromFile(certFile, serverNameOverwrite string) (TransportCredentials, error) {
// serverNameOverride is for testing only. If set to a non empty string,
// it will override the virtual host name of authority (e.g. :authority header field) in requests.
func NewClientTLSFromFile(certFile, serverNameOverride string) (TransportCredentials, error) {
b, err := ioutil.ReadFile(certFile)
if err != nil {
return nil, err
Expand All @@ -208,7 +213,7 @@ func NewClientTLSFromFile(certFile, serverNameOverwrite string) (TransportCreden
if !cp.AppendCertsFromPEM(b) {
return nil, fmt.Errorf("credentials: failed to append certificates")
}
return NewTLS(&tls.Config{ServerName: serverNameOverwrite, RootCAs: cp}), nil
return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}), nil
}

// NewServerTLSFromCert constructs a TLS from the input certificate for server.
Expand Down
61 changes: 61 additions & 0 deletions credentials/credentials_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/

package credentials

import (
"testing"
)

func TestTLSOverrideServerName(t *testing.T) {
expectedServerName := "server.name"
c := NewTLS(nil)
c.OverrideServerName(expectedServerName)
if c.Info().ServerName != expectedServerName {
t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
}
}

func TestTLSClone(t *testing.T) {
expectedServerName := "server.name"
c := NewTLS(nil)
c.OverrideServerName(expectedServerName)
cc := c.Clone()
if cc.Info().ServerName != expectedServerName {
t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName)
}
cc.OverrideServerName("")
if c.Info().ServerName != expectedServerName {
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
}
}
18 changes: 18 additions & 0 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2450,6 +2450,12 @@ func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, crede
func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}
func (c clientAlwaysFailCred) Clone() credentials.TransportCredentials {
return nil
}
func (c clientAlwaysFailCred) OverrideServerName(s string) error {
return nil
}

func TestDialWithBlockErrorOnBadCertificates(t *testing.T) {
te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: true})
Expand Down Expand Up @@ -2520,6 +2526,12 @@ func (c *clientTimeoutCreds) ServerHandshake(rawConn net.Conn) (net.Conn, creden
func (c *clientTimeoutCreds) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}
func (c *clientTimeoutCreds) Clone() credentials.TransportCredentials {
return nil
}
func (c *clientTimeoutCreds) OverrideServerName(s string) error {
return nil
}

func TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) {
te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "clientTimeoutCreds", balancer: false})
Expand Down Expand Up @@ -2556,6 +2568,12 @@ func (c *serverDispatchCred) ServerHandshake(rawConn net.Conn) (net.Conn, creden
func (c *serverDispatchCred) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}
func (c *serverDispatchCred) Clone() credentials.TransportCredentials {
return nil
}
func (c *serverDispatchCred) OverrideServerName(s string) error {
return nil
}
func (c *serverDispatchCred) getRawConn() net.Conn {
<-c.ready
return c.rawConn
Expand Down

0 comments on commit e63e0d4

Please sign in to comment.