Skip to content

Commit

Permalink
ClientHandshake should get the dialing endpoint as the authority (#1607)
Browse files Browse the repository at this point in the history
  • Loading branch information
menghanl authored Oct 23, 2017
1 parent a5986a5 commit 1687ce5
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 21 deletions.
19 changes: 12 additions & 7 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,13 +431,16 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
if cc.dopts.bs == nil {
cc.dopts.bs = DefaultBackoffConfig
}
cc.parsedTarget = parseTarget(cc.target)
creds := cc.dopts.copts.TransportCredentials
if creds != nil && creds.Info().ServerName != "" {
cc.authority = creds.Info().ServerName
} else if cc.dopts.insecure && cc.dopts.copts.Authority != "" {
cc.authority = cc.dopts.copts.Authority
} else {
cc.authority = target
// Use endpoint from "scheme://authority/endpoint" as the default
// authority for ClientConn.
cc.authority = cc.parsedTarget.Endpoint
}

if cc.dopts.scChan != nil && !scSet {
Expand Down Expand Up @@ -541,10 +544,11 @@ type ClientConn struct {
ctx context.Context
cancel context.CancelFunc

target string
authority string
dopts dialOptions
csMgr *connectivityStateManager
target string
parsedTarget resolver.Target
authority string
dopts dialOptions
csMgr *connectivityStateManager

customBalancer bool // If this is true, switching balancer will be disabled.
balancerBuildOpts balancer.BuildOptions
Expand Down Expand Up @@ -953,8 +957,9 @@ func (ac *addrConn) resetTransport() error {
}
ac.mu.Unlock()
sinfo := transport.TargetInfo{
Addr: addr.Addr,
Metadata: addr.Metadata,
Addr: addr.Addr,
Metadata: addr.Metadata,
Authority: ac.cc.authority,
}
newTransport, err := transport.NewClientTransport(ac.cc.ctx, sinfo, copts, timeout)
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,15 @@ func (c tlsCreds) Info() ProtocolInfo {
}
}

func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
// use local cfg to avoid clobbering ServerName if using multiple endpoints
cfg := cloneTLSConfig(c.config)
if cfg.ServerName == "" {
colonPos := strings.LastIndex(addr, ":")
colonPos := strings.LastIndex(authority, ":")
if colonPos == -1 {
colonPos = len(addr)
colonPos = len(authority)
}
cfg.ServerName = addr[:colonPos]
cfg.ServerName = authority[:colonPos]
}
conn := tls.Client(rawConn, cfg)
errChannel := make(chan error, 1)
Expand Down
9 changes: 4 additions & 5 deletions resolver_conn_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,11 @@ func parseTarget(target string) (ret resolver.Target) {
// builder for this scheme. It then builds the resolver and starts the
// monitoring goroutine for it.
func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) {
target := parseTarget(cc.target)
grpclog.Infof("dialing to target with scheme: %q", target.Scheme)
grpclog.Infof("dialing to target with scheme: %q", cc.parsedTarget.Scheme)

rb := resolver.Get(target.Scheme)
rb := resolver.Get(cc.parsedTarget.Scheme)
if rb == nil {
return nil, fmt.Errorf("could not get resolver for scheme: %q", target.Scheme)
return nil, fmt.Errorf("could not get resolver for scheme: %q", cc.parsedTarget.Scheme)
}

ccr := &ccResolverWrapper{
Expand All @@ -74,7 +73,7 @@ func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) {
}

var err error
ccr.resolver, err = rb.Build(target, ccr, resolver.BuildOption{})
ccr.resolver, err = rb.Build(cc.parsedTarget, ccr, resolver.BuildOption{})
if err != nil {
return nil, err
}
Expand Down
63 changes: 63 additions & 0 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4287,6 +4287,69 @@ func TestServerCredsDispatch(t *testing.T) {
}
}

type authorityCheckCreds struct {
got string
}

func (c *authorityCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, nil, nil
}
func (c *authorityCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
c.got = authority
return rawConn, nil, nil
}
func (c *authorityCheckCreds) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}
func (c *authorityCheckCreds) Clone() credentials.TransportCredentials {
return c
}
func (c *authorityCheckCreds) OverrideServerName(s string) error {
return nil
}

// This test makes sure that the authority client handshake gets is the endpoint
// in dial target, not the resolved ip address.
func TestCredsHandshakeAuthority(t *testing.T) {
const testAuthority = "test.auth.ori.ty"

lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
cred := &authorityCheckCreds{}
s := grpc.NewServer()
go s.Serve(lis)
defer s.Stop()

r, rcleanup := manual.GenerateAndRegisterManualResolver()
defer rcleanup()

cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred))
if err != nil {
t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
}
defer cc.Close()
r.NewAddress([]resolver.Address{{Addr: lis.Addr().String()}})

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
for {
s := cc.GetState()
if s == connectivity.Ready {
break
}
if !cc.WaitForStateChange(ctx, s) {
// ctx got timeout or canceled.
t.Fatalf("ClientConn is not ready after 100 ms")
}
}

if cred.got != testAuthority {
t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority)
}
}

func TestFlowControlLogicalRace(t *testing.T) {
// Test for a regression of https://github.com/grpc/grpc-go/issues/632,
// and other flow control bugs.
Expand Down
4 changes: 1 addition & 3 deletions transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ import (
type http2Client struct {
ctx context.Context
cancel context.CancelFunc
target string // server name/addr
userAgent string
md interface{}
conn net.Conn // underlying communication channel
Expand Down Expand Up @@ -175,7 +174,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, t
)
if creds := opts.TransportCredentials; creds != nil {
scheme = "https"
conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Addr, conn)
conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Authority, conn)
if err != nil {
// Credentials handshake errors are typically considered permanent
// to avoid retrying on e.g. bad certificates.
Expand Down Expand Up @@ -210,7 +209,6 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, t
t := &http2Client{
ctx: ctx,
cancel: cancel,
target: addr.Addr,
userAgent: opts.UserAgent,
md: addr.Metadata,
conn: conn,
Expand Down
5 changes: 3 additions & 2 deletions transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,9 @@ type ConnectOptions struct {

// TargetInfo contains the information of the target such as network address and metadata.
type TargetInfo struct {
Addr string
Metadata interface{}
Addr string
Metadata interface{}
Authority string
}

// NewClientTransport establishes the transport with the required ConnectOptions
Expand Down

0 comments on commit 1687ce5

Please sign in to comment.