Skip to content

Commit

Permalink
Overwrite authority if creds servername is specified
Browse files Browse the repository at this point in the history
  • Loading branch information
menghanl committed Sep 6, 2016
1 parent 52f6504 commit a00cbfe
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 11 deletions.
13 changes: 9 additions & 4 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,16 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
if ok {
go cc.lbWatcher()
}
colonPos := strings.LastIndex(target, ":")
if colonPos == -1 {
colonPos = len(target)
creds := cc.dopts.copts.TransportCredentials
if creds != nil && creds.Info().ServerName != "" {
cc.authority = creds.Info().ServerName
} else {
colonPos := strings.LastIndex(target, ":")
if colonPos == -1 {
colonPos = len(target)
}
cc.authority = target[:colonPos]
}
cc.authority = target[:colonPos]
return cc, nil
}

Expand Down
20 changes: 18 additions & 2 deletions clientconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,31 @@ func TestTLSDialTimeout(t *testing.T) {
conn.Close()
}
if err != ErrClientConnTimeout {
t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout)
t.Fatalf("Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout)
}
}

func TestTLSServerNameOverwrite(t *testing.T) {
overwriteServerName := "over.write.server.name"
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", overwriteServerName)
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
conn, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond))
if err != nil {
t.Fatalf("Dial(_, _) = _, %v, want _, <nil>", err)
}
conn.Close()
if conn.authority != overwriteServerName {
t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName)
}
}

func TestDialContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
if _, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithInsecure()); err != context.Canceled {
t.Fatalf("grpc.DialContext(%v, _) = _, %v, want _, %v", ctx, err, context.Canceled)
t.Fatalf("DialContext(%v, _) = _, %v, want _, %v", ctx, err, context.Canceled)
}
}

Expand Down
17 changes: 12 additions & 5 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,16 @@ type PerRPCCredentials interface {
}

// ProtocolInfo provides information regarding the gRPC wire protocol version,
// security protocol, security protocol version in use, etc.
// security protocol, security protocol version in use, server name, etc.
type ProtocolInfo struct {
// ProtocolVersion is the gRPC wire protocol version.
ProtocolVersion string
// SecurityProtocol is the security protocol in use.
SecurityProtocol string
// SecurityVersion is the security protocol version.
SecurityVersion string
// ServerName is the user-configured server name.
ServerName string
}

// AuthInfo defines the common interface for the auth information the users are interested in.
Expand Down Expand Up @@ -130,6 +132,7 @@ func (c tlsCreds) Info() ProtocolInfo {
return ProtocolInfo{
SecurityProtocol: "tls",
SecurityVersion: "1.2",
ServerName: c.config.ServerName,
}
}

Expand Down Expand Up @@ -187,12 +190,16 @@ func NewTLS(c *tls.Config) TransportCredentials {
}

// NewClientTLSFromCert constructs a TLS from the input certificate for client.
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportCredentials {
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp})
// serverNameOverwrite is for testing only. If set to a non empty string,
// it will overwrite the virtual host name of authority (e.g. :authority header field) in requests.
func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverwrite string) TransportCredentials {
return NewTLS(&tls.Config{ServerName: serverNameOverwrite, RootCAs: cp})
}

// NewClientTLSFromFile constructs a TLS from the input certificate file for client.
func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, error) {
// serverNameOverwrite is for testing only. If set to a non empty string,
// it will overwrite the virtual host name of authority (e.g. :authority header field) in requests.
func NewClientTLSFromFile(certFile, serverNameOverwrite string) (TransportCredentials, error) {
b, err := ioutil.ReadFile(certFile)
if err != nil {
return nil, err
Expand All @@ -201,7 +208,7 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, er
if !cp.AppendCertsFromPEM(b) {
return nil, fmt.Errorf("credentials: failed to append certificates")
}
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}), nil
return NewTLS(&tls.Config{ServerName: serverNameOverwrite, RootCAs: cp}), nil
}

// NewServerTLSFromCert constructs a TLS from the input certificate for server.
Expand Down

0 comments on commit a00cbfe

Please sign in to comment.