Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ClientHandshake should get the dialing endpoint as the authority #1607

Merged
merged 4 commits into from
Oct 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,13 +431,16 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
if cc.dopts.bs == nil {
cc.dopts.bs = DefaultBackoffConfig
}
cc.parsedTarget = parseTarget(cc.target)
creds := cc.dopts.copts.TransportCredentials
if creds != nil && creds.Info().ServerName != "" {
cc.authority = creds.Info().ServerName
} else if cc.dopts.insecure && cc.dopts.copts.Authority != "" {
cc.authority = cc.dopts.copts.Authority
} else {
cc.authority = target
// Use endpoint from "scheme://authority/endpoint" as the default
// authority for ClientConn.
cc.authority = cc.parsedTarget.Endpoint
}

if cc.dopts.scChan != nil && !scSet {
Expand Down Expand Up @@ -541,10 +544,11 @@ type ClientConn struct {
ctx context.Context
cancel context.CancelFunc

target string
authority string
dopts dialOptions
csMgr *connectivityStateManager
target string
parsedTarget resolver.Target
authority string
dopts dialOptions
csMgr *connectivityStateManager

customBalancer bool // If this is true, switching balancer will be disabled.
balancerBuildOpts balancer.BuildOptions
Expand Down Expand Up @@ -953,8 +957,9 @@ func (ac *addrConn) resetTransport() error {
}
ac.mu.Unlock()
sinfo := transport.TargetInfo{
Addr: addr.Addr,
Metadata: addr.Metadata,
Addr: addr.Addr,
Metadata: addr.Metadata,
Authority: ac.cc.authority,
}
newTransport, err := transport.NewClientTransport(ac.cc.ctx, sinfo, copts, timeout)
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,15 @@ func (c tlsCreds) Info() ProtocolInfo {
}
}

func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
// use local cfg to avoid clobbering ServerName if using multiple endpoints
cfg := cloneTLSConfig(c.config)
if cfg.ServerName == "" {
colonPos := strings.LastIndex(addr, ":")
colonPos := strings.LastIndex(authority, ":")
if colonPos == -1 {
colonPos = len(addr)
colonPos = len(authority)
}
cfg.ServerName = addr[:colonPos]
cfg.ServerName = authority[:colonPos]
}
conn := tls.Client(rawConn, cfg)
errChannel := make(chan error, 1)
Expand Down
9 changes: 4 additions & 5 deletions resolver_conn_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,11 @@ func parseTarget(target string) (ret resolver.Target) {
// builder for this scheme. It then builds the resolver and starts the
// monitoring goroutine for it.
func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) {
target := parseTarget(cc.target)
grpclog.Infof("dialing to target with scheme: %q", target.Scheme)
grpclog.Infof("dialing to target with scheme: %q", cc.parsedTarget.Scheme)

rb := resolver.Get(target.Scheme)
rb := resolver.Get(cc.parsedTarget.Scheme)
if rb == nil {
return nil, fmt.Errorf("could not get resolver for scheme: %q", target.Scheme)
return nil, fmt.Errorf("could not get resolver for scheme: %q", cc.parsedTarget.Scheme)
}

ccr := &ccResolverWrapper{
Expand All @@ -74,7 +73,7 @@ func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) {
}

var err error
ccr.resolver, err = rb.Build(target, ccr, resolver.BuildOption{})
ccr.resolver, err = rb.Build(cc.parsedTarget, ccr, resolver.BuildOption{})
if err != nil {
return nil, err
}
Expand Down
63 changes: 63 additions & 0 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4287,6 +4287,69 @@ func TestServerCredsDispatch(t *testing.T) {
}
}

type authorityCheckCreds struct {
got string
}

func (c *authorityCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, nil, nil
}
func (c *authorityCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
c.got = authority
return rawConn, nil, nil
}
func (c *authorityCheckCreds) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}
func (c *authorityCheckCreds) Clone() credentials.TransportCredentials {
return c
}
func (c *authorityCheckCreds) OverrideServerName(s string) error {
return nil
}

// This test makes sure that the authority client handshake gets is the endpoint
// in dial target, not the resolved ip address.
func TestCredsHandshakeAuthority(t *testing.T) {
const testAuthority = "test.auth.ori.ty"

lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
cred := &authorityCheckCreds{}
s := grpc.NewServer()
go s.Serve(lis)
defer s.Stop()

r, rcleanup := manual.GenerateAndRegisterManualResolver()
defer rcleanup()

cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred))
if err != nil {
t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
}
defer cc.Close()
r.NewAddress([]resolver.Address{{Addr: lis.Addr().String()}})

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
for {
s := cc.GetState()
if s == connectivity.Ready {
break
}
if !cc.WaitForStateChange(ctx, s) {
// ctx got timeout or canceled.
t.Fatalf("ClientConn is not ready after 100 ms")
}
}

if cred.got != testAuthority {
t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority)
}
}

func TestFlowControlLogicalRace(t *testing.T) {
// Test for a regression of https://github.com/grpc/grpc-go/issues/632,
// and other flow control bugs.
Expand Down
4 changes: 1 addition & 3 deletions transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ import (
type http2Client struct {
ctx context.Context
cancel context.CancelFunc
target string // server name/addr
userAgent string
md interface{}
conn net.Conn // underlying communication channel
Expand Down Expand Up @@ -175,7 +174,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, t
)
if creds := opts.TransportCredentials; creds != nil {
scheme = "https"
conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Addr, conn)
conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Authority, conn)
if err != nil {
// Credentials handshake errors are typically considered permanent
// to avoid retrying on e.g. bad certificates.
Expand Down Expand Up @@ -210,7 +209,6 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, t
t := &http2Client{
ctx: ctx,
cancel: cancel,
target: addr.Addr,
userAgent: opts.UserAgent,
md: addr.Metadata,
conn: conn,
Expand Down
5 changes: 3 additions & 2 deletions transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,9 @@ type ConnectOptions struct {

// TargetInfo contains the information of the target such as network address and metadata.
type TargetInfo struct {
Addr string
Metadata interface{}
Addr string
Metadata interface{}
Authority string
}

// NewClientTransport establishes the transport with the required ConnectOptions
Expand Down