diff --git a/clientconn.go b/clientconn.go index a257f01583c3..6428f8a23eb1 100644 --- a/clientconn.go +++ b/clientconn.go @@ -322,11 +322,16 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * if ok { go cc.lbWatcher() } - colonPos := strings.LastIndex(target, ":") - if colonPos == -1 { - colonPos = len(target) + creds := cc.dopts.copts.TransportCredentials + if creds != nil && creds.Info().ServerName != "" { + cc.authority = creds.Info().ServerName + } else { + colonPos := strings.LastIndex(target, ":") + if colonPos == -1 { + colonPos = len(target) + } + cc.authority = target[:colonPos] } - cc.authority = target[:colonPos] return cc, nil } diff --git a/clientconn_test.go b/clientconn_test.go index c49548dcd175..3d635c73bba0 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -65,7 +65,23 @@ func TestTLSDialTimeout(t *testing.T) { conn.Close() } if err != ErrClientConnTimeout { - t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout) + t.Fatalf("Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout) + } +} + +func TestTLSServerNameOverwrite(t *testing.T) { + overwriteServerName := "over.write.server.name" + creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", overwriteServerName) + if err != nil { + t.Fatalf("Failed to create credentials %v", err) + } + conn, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond)) + if err != nil { + t.Fatalf("Dial(_, _) = _, %v, want _, ", err) + } + conn.Close() + if conn.authority != overwriteServerName { + t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) } } @@ -73,7 +89,7 @@ func TestDialContextCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() if _, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithInsecure()); err != context.Canceled { - t.Fatalf("grpc.DialContext(%v, _) = _, %v, want _, %v", ctx, err, context.Canceled) + t.Fatalf("DialContext(%v, _) = _, %v, want _, %v", ctx, err, context.Canceled) } } diff --git a/credentials/credentials.go b/credentials/credentials.go index 13be45742dd1..a6285e62e365 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -72,7 +72,7 @@ type PerRPCCredentials interface { } // ProtocolInfo provides information regarding the gRPC wire protocol version, -// security protocol, security protocol version in use, etc. +// security protocol, security protocol version in use, server name, etc. type ProtocolInfo struct { // ProtocolVersion is the gRPC wire protocol version. ProtocolVersion string @@ -80,6 +80,8 @@ type ProtocolInfo struct { SecurityProtocol string // SecurityVersion is the security protocol version. SecurityVersion string + // ServerName is the user-configured server name. + ServerName string } // AuthInfo defines the common interface for the auth information the users are interested in. @@ -130,6 +132,7 @@ func (c tlsCreds) Info() ProtocolInfo { return ProtocolInfo{ SecurityProtocol: "tls", SecurityVersion: "1.2", + ServerName: c.config.ServerName, } } @@ -187,12 +190,16 @@ func NewTLS(c *tls.Config) TransportCredentials { } // NewClientTLSFromCert constructs a TLS from the input certificate for client. -func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportCredentials { - return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}) +// serverNameOverwrite is for testing only. If set to a non empty string, +// it will overwrite the virtual host name of authority (e.g. :authority header field) in requests. +func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverwrite string) TransportCredentials { + return NewTLS(&tls.Config{ServerName: serverNameOverwrite, RootCAs: cp}) } // NewClientTLSFromFile constructs a TLS from the input certificate file for client. -func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, error) { +// serverNameOverwrite is for testing only. If set to a non empty string, +// it will overwrite the virtual host name of authority (e.g. :authority header field) in requests. +func NewClientTLSFromFile(certFile, serverNameOverwrite string) (TransportCredentials, error) { b, err := ioutil.ReadFile(certFile) if err != nil { return nil, err @@ -201,7 +208,7 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, er if !cp.AppendCertsFromPEM(b) { return nil, fmt.Errorf("credentials: failed to append certificates") } - return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}), nil + return NewTLS(&tls.Config{ServerName: serverNameOverwrite, RootCAs: cp}), nil } // NewServerTLSFromCert constructs a TLS from the input certificate for server.