From e63e0d4095c31d4576f1ba4903afedea6fbe2cde Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Tue, 13 Sep 2016 11:56:46 -0700 Subject: [PATCH] Add Clone() and OverrideServerName() to TransportCredentials --- credentials/credentials.go | 41 ++++++++++++---------- credentials/credentials_test.go | 61 +++++++++++++++++++++++++++++++++ test/end2end_test.go | 18 ++++++++++ 3 files changed, 102 insertions(+), 18 deletions(-) create mode 100644 credentials/credentials_test.go diff --git a/credentials/credentials.go b/credentials/credentials.go index a6285e62e365..5555ef024f67 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -109,6 +109,12 @@ type TransportCredentials interface { ServerHandshake(net.Conn) (net.Conn, AuthInfo, error) // Info provides the ProtocolInfo of this TransportCredentials. Info() ProtocolInfo + // Clone makes a copy of this TransportCredentials. + Clone() TransportCredentials + // OverrideServerName overrides the server name used to verify the hostname on the returned certificates from the server. + // gRPC internals also use it to override the virtual hosting name if it is set. + // It must be called before dialing. Currently, this is only used by grpclb. + OverrideServerName(string) error } // TLSInfo contains the auth information for a TLS authenticated connection. @@ -136,16 +142,6 @@ func (c tlsCreds) Info() ProtocolInfo { } } -// GetRequestMetadata returns nil, nil since TLS credentials does not have -// metadata. -func (c *tlsCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - return nil, nil -} - -func (c *tlsCreds) RequireTransportSecurity() bool { - return true -} - func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) { // use local cfg to avoid clobbering ServerName if using multiple endpoints cfg := cloneTLSConfig(c.config) @@ -182,6 +178,15 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) return conn, TLSInfo{conn.ConnectionState()}, nil } +func (c *tlsCreds) Clone() TransportCredentials { + return NewTLS(c.config) +} + +func (c *tlsCreds) OverrideServerName(serverNameOverride string) error { + c.config.ServerName = serverNameOverride + return nil +} + // NewTLS uses c to construct a TransportCredentials based on TLS. func NewTLS(c *tls.Config) TransportCredentials { tc := &tlsCreds{cloneTLSConfig(c)} @@ -190,16 +195,16 @@ func NewTLS(c *tls.Config) TransportCredentials { } // NewClientTLSFromCert constructs a TLS from the input certificate for client. -// serverNameOverwrite is for testing only. If set to a non empty string, -// it will overwrite the virtual host name of authority (e.g. :authority header field) in requests. -func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverwrite string) TransportCredentials { - return NewTLS(&tls.Config{ServerName: serverNameOverwrite, RootCAs: cp}) +// serverNameOverride is for testing only. If set to a non empty string, +// it will override the virtual host name of authority (e.g. :authority header field) in requests. +func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) TransportCredentials { + return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}) } // NewClientTLSFromFile constructs a TLS from the input certificate file for client. -// serverNameOverwrite is for testing only. If set to a non empty string, -// it will overwrite the virtual host name of authority (e.g. :authority header field) in requests. -func NewClientTLSFromFile(certFile, serverNameOverwrite string) (TransportCredentials, error) { +// serverNameOverride is for testing only. If set to a non empty string, +// it will override the virtual host name of authority (e.g. :authority header field) in requests. +func NewClientTLSFromFile(certFile, serverNameOverride string) (TransportCredentials, error) { b, err := ioutil.ReadFile(certFile) if err != nil { return nil, err @@ -208,7 +213,7 @@ func NewClientTLSFromFile(certFile, serverNameOverwrite string) (TransportCreden if !cp.AppendCertsFromPEM(b) { return nil, fmt.Errorf("credentials: failed to append certificates") } - return NewTLS(&tls.Config{ServerName: serverNameOverwrite, RootCAs: cp}), nil + return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}), nil } // NewServerTLSFromCert constructs a TLS from the input certificate for server. diff --git a/credentials/credentials_test.go b/credentials/credentials_test.go new file mode 100644 index 000000000000..caf35b2feca8 --- /dev/null +++ b/credentials/credentials_test.go @@ -0,0 +1,61 @@ +/* + * + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package credentials + +import ( + "testing" +) + +func TestTLSOverrideServerName(t *testing.T) { + expectedServerName := "server.name" + c := NewTLS(nil) + c.OverrideServerName(expectedServerName) + if c.Info().ServerName != expectedServerName { + t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) + } +} + +func TestTLSClone(t *testing.T) { + expectedServerName := "server.name" + c := NewTLS(nil) + c.OverrideServerName(expectedServerName) + cc := c.Clone() + if cc.Info().ServerName != expectedServerName { + t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName) + } + cc.OverrideServerName("") + if c.Info().ServerName != expectedServerName { + t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) + } +} diff --git a/test/end2end_test.go b/test/end2end_test.go index 7adc247c9e87..fcd94971c9e2 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -2450,6 +2450,12 @@ func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, crede func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo { return credentials.ProtocolInfo{} } +func (c clientAlwaysFailCred) Clone() credentials.TransportCredentials { + return nil +} +func (c clientAlwaysFailCred) OverrideServerName(s string) error { + return nil +} func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: true}) @@ -2520,6 +2526,12 @@ func (c *clientTimeoutCreds) ServerHandshake(rawConn net.Conn) (net.Conn, creden func (c *clientTimeoutCreds) Info() credentials.ProtocolInfo { return credentials.ProtocolInfo{} } +func (c *clientTimeoutCreds) Clone() credentials.TransportCredentials { + return nil +} +func (c *clientTimeoutCreds) OverrideServerName(s string) error { + return nil +} func TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) { te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "clientTimeoutCreds", balancer: false}) @@ -2556,6 +2568,12 @@ func (c *serverDispatchCred) ServerHandshake(rawConn net.Conn) (net.Conn, creden func (c *serverDispatchCred) Info() credentials.ProtocolInfo { return credentials.ProtocolInfo{} } +func (c *serverDispatchCred) Clone() credentials.TransportCredentials { + return nil +} +func (c *serverDispatchCred) OverrideServerName(s string) error { + return nil +} func (c *serverDispatchCred) getRawConn() net.Conn { <-c.ready return c.rawConn