Skip to content

Commit

Permalink
Merge pull request #768 from menghanl/fatal_on_bad_certificates
Browse files Browse the repository at this point in the history
Return error on bad certificates
  • Loading branch information
menghanl authored Aug 3, 2016
2 parents 2456736 + a4587cd commit 35896af
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 60 deletions.
1 change: 1 addition & 0 deletions call.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
}
defer func() {
if err != nil {
// If err is connection error, t will be closed, no need to close stream here.
if _, ok := err.(transport.ConnectionError); !ok {
t.CloseStream(stream, err)
}
Expand Down
112 changes: 79 additions & 33 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,25 +233,27 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
if cc.dopts.bs == nil {
cc.dopts.bs = DefaultBackoffConfig
}
if cc.dopts.balancer == nil {
cc.dopts.balancer = RoundRobin(nil)
}

if err := cc.dopts.balancer.Start(target); err != nil {
return nil, err
}
var (
ok bool
addrs []Address
)
ch := cc.dopts.balancer.Notify()
if ch == nil {
// There is no name resolver installed.
if cc.dopts.balancer == nil {
// Connect to target directly if balancer is nil.
addrs = append(addrs, Address{Addr: target})
} else {
addrs, ok = <-ch
if !ok || len(addrs) == 0 {
return nil, errNoAddr
if err := cc.dopts.balancer.Start(target); err != nil {
return nil, err
}
ch := cc.dopts.balancer.Notify()
if ch == nil {
// There is no name resolver installed.
addrs = append(addrs, Address{Addr: target})
} else {
addrs, ok = <-ch
if !ok || len(addrs) == 0 {
return nil, errNoAddr
}
}
}
waitC := make(chan error, 1)
Expand Down Expand Up @@ -281,6 +283,8 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
cc.Close()
return nil, ErrClientConnTimeout
}
// If balancer is nil or balancer.Notify() is nil, ok will be false here.
// The lbWatcher goroutine will not be created.
if ok {
go cc.lbWatcher()
}
Expand Down Expand Up @@ -429,10 +433,16 @@ func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr err
// skipWait may overwrite the decision in ac.dopts.block.
if ac.dopts.block && !skipWait {
if err := ac.resetTransport(false); err != nil {
ac.cc.mu.Lock()
delete(ac.cc.conns, ac.addr)
ac.cc.mu.Unlock()
ac.tearDown(err)
if err != errConnClosing {
// Tear down ac and delete it from cc.conns.
cc.mu.Lock()
delete(cc.conns, ac.addr)
cc.mu.Unlock()
ac.tearDown(err)
}
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
return e.Origin()
}
return err
}
// Start to monitor the error status of transport.
Expand All @@ -442,10 +452,10 @@ func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr err
go func() {
if err := ac.resetTransport(false); err != nil {
grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err)
ac.cc.mu.Lock()
delete(ac.cc.conns, ac.addr)
ac.cc.mu.Unlock()
ac.tearDown(err)
if err != errConnClosing {
// Keep this ac in cc.conns, to get the reason it's torn down.
ac.tearDown(err)
}
return
}
ac.transportMonitor()
Expand All @@ -455,22 +465,42 @@ func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr err
}

func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
addr, put, err := cc.dopts.balancer.Get(ctx, opts)
if err != nil {
return nil, nil, toRPCErr(err)
}
cc.mu.RLock()
if cc.conns == nil {
var (
ac *addrConn
ok bool
put func()
)
if cc.dopts.balancer == nil {
// If balancer is nil, there should be only one addrConn available.
cc.mu.RLock()
for _, ac = range cc.conns {
// Break after the first iteration to get the first addrConn.
ok = true
break
}
cc.mu.RUnlock()
} else {
var (
addr Address
err error
)
addr, put, err = cc.dopts.balancer.Get(ctx, opts)
if err != nil {
return nil, nil, toRPCErr(err)
}
cc.mu.RLock()
if cc.conns == nil {
cc.mu.RUnlock()
return nil, nil, toRPCErr(ErrClientConnClosing)
}
ac, ok = cc.conns[addr]
cc.mu.RUnlock()
return nil, nil, toRPCErr(ErrClientConnClosing)
}
ac, ok := cc.conns[addr]
cc.mu.RUnlock()
if !ok {
if put != nil {
put()
}
return nil, nil, Errorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
return nil, nil, errConnClosing
}
t, err := ac.wait(ctx, !opts.BlockingWait)
if err != nil {
Expand All @@ -494,7 +524,9 @@ func (cc *ClientConn) Close() error {
conns := cc.conns
cc.conns = nil
cc.mu.Unlock()
cc.dopts.balancer.Close()
if cc.dopts.balancer != nil {
cc.dopts.balancer.Close()
}
for _, ac := range conns {
ac.tearDown(ErrClientConnClosing)
}
Expand All @@ -519,6 +551,9 @@ type addrConn struct {
// due to timeout.
ready chan struct{}
transport transport.ClientTransport

// The reason this addrConn is torn down.
tearDownErr error
}

// printf records an event in ac's event log, unless ac has been closed.
Expand Down Expand Up @@ -605,6 +640,9 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
if err != nil {
cancel()

if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
return err
}
ac.mu.Lock()
if ac.state == Shutdown {
// ac.tearDown(...) has been invoked.
Expand Down Expand Up @@ -644,7 +682,9 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
close(ac.ready)
ac.ready = nil
}
ac.down = ac.cc.dopts.balancer.Up(ac.addr)
if ac.cc.dopts.balancer != nil {
ac.down = ac.cc.dopts.balancer.Up(ac.addr)
}
ac.mu.Unlock()
return nil
}
Expand Down Expand Up @@ -705,6 +745,10 @@ func (ac *addrConn) transportMonitor() {
ac.printf("transport exiting: %v", err)
ac.mu.Unlock()
grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
if err != errConnClosing {
// Keep this ac in cc.conns, to get the reason it's torn down.
ac.tearDown(err)
}
return
}
}
Expand All @@ -718,8 +762,9 @@ func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTr
ac.mu.Lock()
switch {
case ac.state == Shutdown:
err := ac.tearDownErr
ac.mu.Unlock()
return nil, errConnClosing
return nil, err
case ac.state == Ready:
ct := ac.transport
ac.mu.Unlock()
Expand Down Expand Up @@ -769,6 +814,7 @@ func (ac *addrConn) tearDown(err error) {
return
}
ac.state = Shutdown
ac.tearDownErr = err
ac.stateCV.Broadcast()
if ac.events != nil {
ac.events.Finish()
Expand Down
97 changes: 88 additions & 9 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ type env struct {
network string // The type of network such as tcp, unix, etc.
security string // The security protocol such as TLS, SSH, etc.
httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS
balancer bool // whether to use balancer
}

func (e env) runnable() bool {
Expand All @@ -319,12 +320,13 @@ func (e env) dialer(addr string, timeout time.Duration) (net.Conn, error) {
}

var (
tcpClearEnv = env{name: "tcp-clear", network: "tcp"}
tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"}
unixClearEnv = env{name: "unix-clear", network: "unix"}
unixTLSEnv = env{name: "unix-tls", network: "unix", security: "tls"}
handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true}
allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv}
tcpClearEnv = env{name: "tcp-clear", network: "tcp", balancer: true}
tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls", balancer: true}
unixClearEnv = env{name: "unix-clear", network: "unix", balancer: true}
unixTLSEnv = env{name: "unix-tls", network: "unix", security: "tls", balancer: true}
handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true, balancer: true}
noBalancerEnv = env{name: "no-balancer", network: "tcp", security: "tls", balancer: false}
allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv, noBalancerEnv}
)

var onlyEnv = flag.String("only_env", "", "If non-empty, one of 'tcp-clear', 'tcp-tls', 'unix-clear', 'unix-tls', or 'handler-tls' to only run the tests for that environment. Empty means all.")
Expand Down Expand Up @@ -439,12 +441,15 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
if err != nil {
te.t.Fatalf("Failed to listen: %v", err)
}
if te.e.security == "tls" {
switch te.e.security {
case "tls":
creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
te.t.Fatalf("Failed to generate credentials %v", err)
}
sopts = append(sopts, grpc.Creds(creds))
case "clientAlwaysFailCred":
sopts = append(sopts, grpc.Creds(clientAlwaysFailCred{}))
}
s := grpc.NewServer(sopts...)
te.srv = s
Expand Down Expand Up @@ -487,15 +492,21 @@ func (te *test) clientConn() *grpc.ClientConn {
grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
)
}
if te.e.security == "tls" {
switch te.e.security {
case "tls":
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
if err != nil {
te.t.Fatalf("Failed to load credentials: %v", err)
}
opts = append(opts, grpc.WithTransportCredentials(creds))
} else {
case "clientAlwaysFailCred":
opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{}))
default:
opts = append(opts, grpc.WithInsecure())
}
if te.e.balancer {
opts = append(opts, grpc.WithBalancer(grpc.RoundRobin(nil)))
}
var err error
te.cc, err = grpc.Dial(te.srvAddr, opts...)
if err != nil {
Expand Down Expand Up @@ -2270,6 +2281,74 @@ func testClientRequestBodyError_Cancel_StreamingInput(t *testing.T, e env) {
})
}

const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails"

var clientAlwaysFailCredError = errors.New(clientAlwaysFailCredErrorMsg)

type clientAlwaysFailCred struct{}

func (c clientAlwaysFailCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return nil, nil, clientAlwaysFailCredError
}
func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, nil, nil
}
func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}

func TestDialWithBlockErrorOnBadCertificates(t *testing.T) {
te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: true})
te.startServer(&testServer{security: "clientAlwaysFailCred"})
defer te.tearDown()

var (
err error
opts []grpc.DialOption
)
opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{}), grpc.WithBlock())
te.cc, err = grpc.Dial(te.srvAddr, opts...)
if err != clientAlwaysFailCredError {
te.t.Fatalf("Dial(%q) = %v, want %v", te.srvAddr, err, clientAlwaysFailCredError)
}
}

func TestFailFastRPCErrorOnBadCertificates(t *testing.T) {
te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: true})
te.startServer(&testServer{security: "clientAlwaysFailCred"})
defer te.tearDown()

cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) {
te.t.Fatalf("Dial(%q) = %v, want err.Error() contains %q", te.srvAddr, err, clientAlwaysFailCredErrorMsg)
}
}

func TestFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) {
te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: false})
te.startServer(&testServer{security: "clientAlwaysFailCred"})
defer te.tearDown()

cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) {
te.t.Fatalf("Dial(%q) = %v, want err.Error() contains %q", te.srvAddr, err, clientAlwaysFailCredErrorMsg)
}
}

func TestNonFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) {
te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: false})
te.startServer(&testServer{security: "clientAlwaysFailCred"})
defer te.tearDown()

cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) {
te.t.Fatalf("Dial(%q) = %v, want err.Error() contains %q", te.srvAddr, err, clientAlwaysFailCredErrorMsg)
}
}

// interestingGoroutines returns all goroutines we care about for the purpose
// of leak checking. It excludes testing or runtime ones.
func interestingGoroutines() (gs []string) {
Expand Down
Loading

0 comments on commit 35896af

Please sign in to comment.