diff --git a/balancer/balancer.go b/balancer/balancer.go index 84e10b630e75..2ce9a346ae86 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -33,8 +33,6 @@ import ( var ( // m is a map from name to balancer builder. m = make(map[string]Builder) - // defaultBuilder is the default balancer to use. - defaultBuilder Builder // TODO(bar) install pickfirst as default. ) // Register registers the balancer builder to the balancer map. @@ -44,13 +42,12 @@ func Register(b Builder) { } // Get returns the resolver builder registered with the given name. -// If no builder is register with the name, the default pickfirst will -// be used. +// If no builder is register with the name, nil will be returned. func Get(name string) Builder { if b, ok := m[name]; ok { return b } - return defaultBuilder + return nil } // SubConn represents a gRPC sub connection. diff --git a/balancer/roundrobin/roundrobin_test.go b/balancer/roundrobin/roundrobin_test.go index 3b4e1305dadd..0335e4865af2 100644 --- a/balancer/roundrobin/roundrobin_test.go +++ b/balancer/roundrobin/roundrobin_test.go @@ -16,7 +16,7 @@ * */ -package roundrobin +package roundrobin_test import ( "fmt" @@ -27,6 +27,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc" + "google.golang.org/grpc/balancer" "google.golang.org/grpc/codes" _ "google.golang.org/grpc/grpclog/glogger" "google.golang.org/grpc/peer" @@ -36,6 +37,8 @@ import ( "google.golang.org/grpc/test/leakcheck" ) +var rr = balancer.Get("roundrobin") + type testServer struct { testpb.TestServiceServer } @@ -99,7 +102,7 @@ func TestOneBackend(t *testing.T) { } defer test.cleanup() - cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder())) + cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -131,7 +134,7 @@ func TestBackendsRoundRobin(t *testing.T) { } defer test.cleanup() - cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder())) + cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -190,7 +193,7 @@ func TestAddressesRemoved(t *testing.T) { } defer test.cleanup() - cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder())) + cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -232,7 +235,7 @@ func TestCloseWithPendingRPC(t *testing.T) { } defer test.cleanup() - cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder())) + cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -266,7 +269,7 @@ func TestNewAddressWhileBlocking(t *testing.T) { } defer test.cleanup() - cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder())) + cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -315,7 +318,7 @@ func TestOneServerDown(t *testing.T) { } defer test.cleanup() - cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder())) + cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) if err != nil { t.Fatalf("failed to dial: %v", err) } @@ -408,7 +411,7 @@ func TestAllServersDown(t *testing.T) { } defer test.cleanup() - cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder())) + cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr)) if err != nil { t.Fatalf("failed to dial: %v", err) } diff --git a/balancer_conn_wrappers.go b/balancer_conn_wrappers.go index e4a95fd5cd25..8ec74920f651 100644 --- a/balancer_conn_wrappers.go +++ b/balancer_conn_wrappers.go @@ -73,7 +73,7 @@ func (b *scStateUpdateBuffer) load() { } } -// get returns the channel that receives a recvMsg in the buffer. +// get returns the channel that the scStateUpdate will be sent to. // // Upon receiving, the caller should call load to send another // scStateChangeTuple onto the channel if there is any. @@ -96,6 +96,8 @@ type ccBalancerWrapper struct { stateChangeQueue *scStateUpdateBuffer resolverUpdateCh chan *resolverUpdate done chan struct{} + + subConns map[*acBalancerWrapper]struct{} } func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.BuildOptions) *ccBalancerWrapper { @@ -104,6 +106,7 @@ func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.Bui stateChangeQueue: newSCStateUpdateBuffer(), resolverUpdateCh: make(chan *resolverUpdate, 1), done: make(chan struct{}), + subConns: make(map[*acBalancerWrapper]struct{}), } go ccb.watcher() ccb.balancer = b.Build(ccb, bopts) @@ -117,8 +120,20 @@ func (ccb *ccBalancerWrapper) watcher() { select { case t := <-ccb.stateChangeQueue.get(): ccb.stateChangeQueue.load() + select { + case <-ccb.done: + ccb.balancer.Close() + return + default: + } ccb.balancer.HandleSubConnStateChange(t.sc, t.state) case t := <-ccb.resolverUpdateCh: + select { + case <-ccb.done: + ccb.balancer.Close() + return + default: + } ccb.balancer.HandleResolvedAddrs(t.addrs, t.err) case <-ccb.done: } @@ -126,6 +141,9 @@ func (ccb *ccBalancerWrapper) watcher() { select { case <-ccb.done: ccb.balancer.Close() + for acbw := range ccb.subConns { + ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) + } return default: } @@ -171,7 +189,10 @@ func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer return nil, err } acbw := &acBalancerWrapper{ac: ac} + acbw.ac.mu.Lock() ac.acbw = acbw + acbw.ac.mu.Unlock() + ccb.subConns[acbw] = struct{}{} return acbw, nil } @@ -181,6 +202,7 @@ func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) { if !ok { return } + delete(ccb.subConns, acbw) ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) } diff --git a/balancer_switching_test.go b/balancer_switching_test.go new file mode 100644 index 000000000000..d185c1653f15 --- /dev/null +++ b/balancer_switching_test.go @@ -0,0 +1,133 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "fmt" + "math" + "testing" + "time" + + "golang.org/x/net/context" + _ "google.golang.org/grpc/grpclog/glogger" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/test/leakcheck" +) + +func checkPickFirst(cc *ClientConn, servers []*server) error { + var ( + req = "port" + reply string + err error + ) + connected := false + for i := 0; i < 1000; i++ { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); ErrorDesc(err) == servers[0].port { + if connected { + // connected is set to false if peer is not server[0]. So if + // connected is true here, this is the second time we saw + // server[0] in a row. Break because pickfirst is in effect. + break + } + connected = true + } else { + connected = false + } + time.Sleep(time.Millisecond) + } + if !connected { + return fmt.Errorf("pickfirst is not in effect after 1 second, EmptyCall() = _, %v, want _, %v", err, servers[0].port) + } + // The following RPCs should all succeed with the first server. + for i := 0; i < 3; i++ { + err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc) + if ErrorDesc(err) != servers[0].port { + return fmt.Errorf("Index %d: want peer %v, got peer %v", i, servers[0].port, err) + } + } + return nil +} + +func checkRoundRobin(cc *ClientConn, servers []*server) error { + var ( + req = "port" + reply string + err error + ) + + // Make sure connections to all servers are up. + for i := 0; i < 2; i++ { + // Do this check twice, otherwise the first RPC's transport may still be + // picked by the closing pickfirst balancer, and the test becomes flaky. + for _, s := range servers { + var up bool + for i := 0; i < 1000; i++ { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); ErrorDesc(err) == s.port { + up = true + break + } + time.Sleep(time.Millisecond) + } + if !up { + return fmt.Errorf("server %v is not up within 1 second", s.port) + } + } + } + + serverCount := len(servers) + for i := 0; i < 3*serverCount; i++ { + err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc) + if ErrorDesc(err) != servers[i%serverCount].port { + return fmt.Errorf("Index %d: want peer %v, got peer %v", i, servers[i%serverCount].port, err) + } + } + return nil +} + +func TestSwitchBalancer(t *testing.T) { + defer leakcheck.Check(t) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + numServers := 2 + servers, _, scleanup := startServers(t, numServers, math.MaxInt32) + defer scleanup() + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer cc.Close() + r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}}) + // The default balancer is pickfirst. + if err := checkPickFirst(cc, servers); err != nil { + t.Fatalf("check pickfirst returned non-nil error: %v", err) + } + // Switch to roundrobin. + cc.switchBalancer("roundrobin") + if err := checkRoundRobin(cc, servers); err != nil { + t.Fatalf("check roundrobin returned non-nil error: %v", err) + } + // Switch to pickfirst. + cc.switchBalancer("pickfirst") + if err := checkPickFirst(cc, servers); err != nil { + t.Fatalf("check pickfirst returned non-nil error: %v", err) + } +} diff --git a/balancer_test.go b/balancer_test.go index 29dbe0a67671..a1558f027a59 100644 --- a/balancer_test.go +++ b/balancer_test.go @@ -31,6 +31,10 @@ import ( _ "google.golang.org/grpc/grpclog/glogger" "google.golang.org/grpc/naming" "google.golang.org/grpc/test/leakcheck" + + // V1 balancer tests use passthrough resolver instead of dns. + // TODO(bar) remove this when removing v1 balaner entirely. + _ "google.golang.org/grpc/resolver/passthrough" ) type testWatcher struct { @@ -117,7 +121,7 @@ func TestNameDiscovery(t *testing.T) { numServers := 2 servers, r, cleanup := startServers(t, numServers, math.MaxUint32) defer cleanup() - cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -151,7 +155,7 @@ func TestEmptyAddrs(t *testing.T) { defer leakcheck.Check(t) servers, r, cleanup := startServers(t, 1, math.MaxUint32) defer cleanup() - cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -185,7 +189,7 @@ func TestRoundRobin(t *testing.T) { numServers := 3 servers, r, cleanup := startServers(t, numServers, math.MaxUint32) defer cleanup() - cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -230,7 +234,7 @@ func TestCloseWithPendingRPC(t *testing.T) { defer leakcheck.Check(t) servers, r, cleanup := startServers(t, 1, math.MaxUint32) defer cleanup() - cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -282,7 +286,7 @@ func TestGetOnWaitChannel(t *testing.T) { defer leakcheck.Check(t) servers, r, cleanup := startServers(t, 1, math.MaxUint32) defer cleanup() - cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -328,7 +332,7 @@ func TestOneServerDown(t *testing.T) { numServers := 2 servers, r, cleanup := startServers(t, numServers, math.MaxUint32) defer cleanup() - cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -381,7 +385,7 @@ func TestOneAddressRemoval(t *testing.T) { numServers := 2 servers, r, cleanup := startServers(t, numServers, math.MaxUint32) defer cleanup() - cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -439,7 +443,7 @@ func TestOneAddressRemoval(t *testing.T) { func checkServerUp(t *testing.T, currentServer *server) { req := "port" port := currentServer.port - cc, err := Dial("localhost:"+port, WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("passthrough:///localhost:"+port, WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -457,7 +461,7 @@ func TestPickFirstEmptyAddrs(t *testing.T) { defer leakcheck.Check(t) servers, r, cleanup := startServers(t, 1, math.MaxUint32) defer cleanup() - cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -489,7 +493,7 @@ func TestPickFirstCloseWithPendingRPC(t *testing.T) { defer leakcheck.Check(t) servers, r, cleanup := startServers(t, 1, math.MaxUint32) defer cleanup() - cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -543,7 +547,7 @@ func TestPickFirstOrderAllServerUp(t *testing.T) { numServers := 3 servers, r, cleanup := startServers(t, numServers, math.MaxUint32) defer cleanup() - cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -656,7 +660,7 @@ func TestPickFirstOrderOneServerDown(t *testing.T) { numServers := 3 servers, r, cleanup := startServers(t, numServers, math.MaxUint32) defer cleanup() - cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -747,7 +751,7 @@ func TestPickFirstOneAddressRemoval(t *testing.T) { numServers := 2 servers, r, cleanup := startServers(t, numServers, math.MaxUint32) defer cleanup() - cc, err := Dial("localhost:"+servers[0].port, WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("passthrough:///localhost:"+servers[0].port, WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } diff --git a/balancer_v1_wrapper.go b/balancer_v1_wrapper.go index 9d0616080a1b..b44c989cd3c1 100644 --- a/balancer_v1_wrapper.go +++ b/balancer_v1_wrapper.go @@ -19,6 +19,7 @@ package grpc import ( + "strings" "sync" "golang.org/x/net/context" @@ -34,20 +35,27 @@ type balancerWrapperBuilder struct { } func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { - bwb.b.Start(cc.Target(), BalancerConfig{ + targetAddr := cc.Target() + targetSplitted := strings.Split(targetAddr, ":///") + if len(targetSplitted) >= 2 { + targetAddr = targetSplitted[1] + } + + bwb.b.Start(targetAddr, BalancerConfig{ DialCreds: opts.DialCreds, Dialer: opts.Dialer, }) _, pickfirst := bwb.b.(*pickFirst) bw := &balancerWrapper{ - balancer: bwb.b, - pickfirst: pickfirst, - cc: cc, - startCh: make(chan struct{}), - conns: make(map[resolver.Address]balancer.SubConn), - connSt: make(map[balancer.SubConn]*scState), - csEvltr: &connectivityStateEvaluator{}, - state: connectivity.Idle, + balancer: bwb.b, + pickfirst: pickfirst, + cc: cc, + targetAddr: targetAddr, + startCh: make(chan struct{}), + conns: make(map[resolver.Address]balancer.SubConn), + connSt: make(map[balancer.SubConn]*scState), + csEvltr: &connectivityStateEvaluator{}, + state: connectivity.Idle, } cc.UpdateBalancerState(connectivity.Idle, bw) go bw.lbWatcher() @@ -68,7 +76,8 @@ type balancerWrapper struct { balancer Balancer // The v1 balancer. pickfirst bool - cc balancer.ClientConn + cc balancer.ClientConn + targetAddr string // Target without the scheme. // To aggregate the connectivity state. csEvltr *connectivityStateEvaluator @@ -93,7 +102,7 @@ func (bw *balancerWrapper) lbWatcher() { if notifyCh == nil { // There's no resolver in the balancer. Connect directly. a := resolver.Address{ - Addr: bw.cc.Target(), + Addr: bw.targetAddr, Type: resolver.Backend, } sc, err := bw.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{}) @@ -103,7 +112,7 @@ func (bw *balancerWrapper) lbWatcher() { bw.mu.Lock() bw.conns[a] = sc bw.connSt[sc] = &scState{ - addr: Address{Addr: bw.cc.Target()}, + addr: Address{Addr: bw.targetAddr}, s: connectivity.Idle, } bw.mu.Unlock() diff --git a/clientconn.go b/clientconn.go index a34bd987cb8c..83563e94b3fa 100644 --- a/clientconn.go +++ b/clientconn.go @@ -31,11 +31,13 @@ import ( "golang.org/x/net/context" "golang.org/x/net/trace" "google.golang.org/grpc/balancer" + _ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin. "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/resolver" + _ "google.golang.org/grpc/resolver/dns" // To register dns resolver. "google.golang.org/grpc/stats" "google.golang.org/grpc/transport" ) @@ -435,42 +437,6 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * cc.authority = target } - if cc.dopts.balancerBuilder != nil { - var credsClone credentials.TransportCredentials - if creds != nil { - credsClone = creds.Clone() - } - buildOpts := balancer.BuildOptions{ - DialCreds: credsClone, - Dialer: cc.dopts.copts.Dialer, - } - // Build should not take long time. So it's ok to not have a goroutine for it. - // TODO(bar) init balancer after first resolver result to support service config balancer. - cc.balancerWrapper = newCCBalancerWrapper(cc, cc.dopts.balancerBuilder, buildOpts) - } else { - waitC := make(chan error, 1) - go func() { - defer close(waitC) - // No balancer, or no resolver within the balancer. Connect directly. - ac, err := cc.newAddrConn([]resolver.Address{{Addr: target}}) - if err != nil { - waitC <- err - return - } - if err := ac.connect(cc.dopts.block); err != nil { - waitC <- err - return - } - }() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case err := <-waitC: - if err != nil { - return nil, err - } - } - } if cc.dopts.scChan != nil && !scSet { // Blocking wait for the initial service config. select { @@ -486,20 +452,27 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * go cc.scWatcher() } + var credsClone credentials.TransportCredentials + if creds := cc.dopts.copts.TransportCredentials; creds != nil { + credsClone = creds.Clone() + } + cc.balancerBuildOpts = balancer.BuildOptions{ + DialCreds: credsClone, + Dialer: cc.dopts.copts.Dialer, + } + + if cc.dopts.balancerBuilder != nil { + cc.customBalancer = true + // Build should not take long time. So it's ok to not have a goroutine for it. + cc.balancerWrapper = newCCBalancerWrapper(cc, cc.dopts.balancerBuilder, cc.balancerBuildOpts) + } + // Build the resolver. cc.resolverWrapper, err = newCCResolverWrapper(cc) if err != nil { return nil, fmt.Errorf("failed to build resolver: %v", err) } - if cc.balancerWrapper != nil && cc.resolverWrapper == nil { - // TODO(bar) there should always be a resolver (DNS as the default). - // Unblock balancer initialization with a fake resolver update if there's no resolver. - // The balancer wrapper will not read the addresses, so an empty list works. - // TODO(bar) remove this after the real resolver is started. - cc.balancerWrapper.handleResolvedAddrs([]resolver.Address{}, nil) - } - // A blocking dial blocks until the clientConn is ready. if cc.dopts.block { for { @@ -570,16 +543,19 @@ type ClientConn struct { dopts dialOptions csMgr *connectivityStateManager - balancerWrapper *ccBalancerWrapper - resolverWrapper *ccResolverWrapper - - blockingpicker *pickerWrapper + customBalancer bool // If this is true, switching balancer will be disabled. + balancerBuildOpts balancer.BuildOptions + resolverWrapper *ccResolverWrapper + blockingpicker *pickerWrapper mu sync.RWMutex sc ServiceConfig conns map[*addrConn]struct{} // Keepalive parameter can be updated if a GoAway is received. - mkp keepalive.ClientParameters + mkp keepalive.ClientParameters + curBalancerName string + curAddresses []resolver.Address + balancerWrapper *ccBalancerWrapper } // WaitForStateChange waits until the connectivity.State of ClientConn changes from sourceState or @@ -622,6 +598,71 @@ func (cc *ClientConn) scWatcher() { } } +func (cc *ClientConn) handleResolvedAddrs(addrs []resolver.Address, err error) { + cc.mu.Lock() + defer cc.mu.Unlock() + if cc.conns == nil { + return + } + + // TODO(bar switching) when grpclb is submitted, check address type and start grpclb. + if !cc.customBalancer && cc.balancerWrapper == nil { + // No customBalancer was specified by DialOption, and this is the first + // time handling resolved addresses, create a pickfirst balancer. + builder := newPickfirstBuilder() + cc.curBalancerName = builder.Name() + cc.balancerWrapper = newCCBalancerWrapper(cc, builder, cc.balancerBuildOpts) + } + + // TODO(bar switching) compare addresses, if there's no update, don't notify balancer. + cc.curAddresses = addrs + cc.balancerWrapper.handleResolvedAddrs(addrs, nil) +} + +// switchBalancer starts the switching from current balancer to the balancer with name. +func (cc *ClientConn) switchBalancer(name string) { + cc.mu.Lock() + defer cc.mu.Unlock() + if cc.conns == nil { + return + } + grpclog.Infof("ClientConn switching balancer to %q", name) + + if cc.customBalancer { + grpclog.Infoln("ignoring service config balancer configuration: WithBalancer DialOption used instead") + return + } + + if cc.curBalancerName == name { + return + } + + // TODO(bar switching) change this to two steps: drain and close. + // Keep track of sc in wrapper. + cc.balancerWrapper.close() + + builder := balancer.Get(name) + if builder == nil { + grpclog.Infof("failed to get balancer builder for: %v (this should never happen...)", name) + builder = newPickfirstBuilder() + } + cc.curBalancerName = builder.Name() + cc.balancerWrapper = newCCBalancerWrapper(cc, builder, cc.balancerBuildOpts) + cc.balancerWrapper.handleResolvedAddrs(cc.curAddresses, nil) +} + +func (cc *ClientConn) handleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { + cc.mu.Lock() + if cc.conns == nil { + cc.mu.Unlock() + return + } + // TODO(bar switching) send updates to all balancer wrappers when balancer + // gracefully switching is supported. + cc.balancerWrapper.handleSubConnStateChange(sc, s) + cc.mu.Unlock() +} + // newAddrConn creates an addrConn for addrs and adds it to cc.conns. func (cc *ClientConn) newAddrConn(addrs []resolver.Address) (*addrConn, error) { ac := &addrConn{ @@ -670,11 +711,7 @@ func (ac *addrConn) connect(block bool) error { return nil } ac.state = connectivity.Connecting - if ac.cc.balancerWrapper != nil { - ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) - } else { - ac.cc.csMgr.updateState(ac.state) - } + ac.cc.handleSubConnStateChange(ac.acbw, ac.state) ac.mu.Unlock() if block { @@ -756,31 +793,6 @@ func (cc *ClientConn) GetMethodConfig(method string) MethodConfig { } func (cc *ClientConn) getTransport(ctx context.Context, failfast bool) (transport.ClientTransport, func(balancer.DoneInfo), error) { - if cc.balancerWrapper == nil { - // If balancer is nil, there should be only one addrConn available. - cc.mu.RLock() - if cc.conns == nil { - cc.mu.RUnlock() - // TODO this function returns toRPCErr and non-toRPCErr. Clean up - // the errors in ClientConn. - return nil, nil, toRPCErr(ErrClientConnClosing) - } - var ac *addrConn - for ac = range cc.conns { - // Break after the first iteration to get the first addrConn. - break - } - cc.mu.RUnlock() - if ac == nil { - return nil, nil, errConnClosing - } - t, err := ac.wait(ctx, false /*hasBalancer*/, failfast) - if err != nil { - return nil, nil, err - } - return t, nil, nil - } - t, done, err := cc.blockingpicker.pick(ctx, failfast, balancer.PickOptions{}) if err != nil { return nil, nil, toRPCErr(err) @@ -800,13 +812,18 @@ func (cc *ClientConn) Close() error { conns := cc.conns cc.conns = nil cc.csMgr.updateState(connectivity.Shutdown) + + rWrapper := cc.resolverWrapper + cc.resolverWrapper = nil + bWrapper := cc.balancerWrapper + cc.balancerWrapper = nil cc.mu.Unlock() cc.blockingpicker.close() - if cc.resolverWrapper != nil { - cc.resolverWrapper.close() + if rWrapper != nil { + rWrapper.close() } - if cc.balancerWrapper != nil { - cc.balancerWrapper.close() + if bWrapper != nil { + bWrapper.close() } for ac := range conns { ac.tearDown(ErrClientConnClosing) @@ -877,11 +894,7 @@ func (ac *addrConn) resetTransport() error { return errConnClosing } ac.state = connectivity.TransientFailure - if ac.cc.balancerWrapper != nil { - ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) - } else { - ac.cc.csMgr.updateState(ac.state) - } + ac.cc.handleSubConnStateChange(ac.acbw, ac.state) if ac.ready != nil { close(ac.ready) ac.ready = nil @@ -906,12 +919,7 @@ func (ac *addrConn) resetTransport() error { } ac.printf("connecting") ac.state = connectivity.Connecting - // TODO(bar) remove condition once we always have a balancer. - if ac.cc.balancerWrapper != nil { - ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) - } else { - ac.cc.csMgr.updateState(ac.state) - } + ac.cc.handleSubConnStateChange(ac.acbw, ac.state) // copy ac.addrs in case of race addrsIter := make([]resolver.Address, len(ac.addrs)) copy(addrsIter, ac.addrs) @@ -953,11 +961,7 @@ func (ac *addrConn) resetTransport() error { return errConnClosing } ac.state = connectivity.Ready - if ac.cc.balancerWrapper != nil { - ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) - } else { - ac.cc.csMgr.updateState(ac.state) - } + ac.cc.handleSubConnStateChange(ac.acbw, ac.state) t := ac.transport ac.transport = newTransport if t != nil { @@ -973,11 +977,7 @@ func (ac *addrConn) resetTransport() error { } ac.mu.Lock() ac.state = connectivity.TransientFailure - if ac.cc.balancerWrapper != nil { - ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) - } else { - ac.cc.csMgr.updateState(ac.state) - } + ac.cc.handleSubConnStateChange(ac.acbw, ac.state) if ac.ready != nil { close(ac.ready) ac.ready = nil @@ -1111,11 +1111,7 @@ func (ac *addrConn) tearDown(err error) { } ac.state = connectivity.Shutdown ac.tearDownErr = err - if ac.cc.balancerWrapper != nil { - ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state) - } else { - ac.cc.csMgr.updateState(ac.state) - } + ac.cc.handleSubConnStateChange(ac.acbw, ac.state) if ac.events != nil { ac.events.Finish() ac.events = nil diff --git a/clientconn_test.go b/clientconn_test.go index 47801e9625f8..c0b0ba436643 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -30,6 +30,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/naming" + _ "google.golang.org/grpc/resolver/passthrough" "google.golang.org/grpc/test/leakcheck" "google.golang.org/grpc/testdata" ) @@ -47,7 +48,7 @@ func TestConnectivityStates(t *testing.T) { defer leakcheck.Check(t) servers, resolver, cleanup := startServers(t, 2, math.MaxUint32) defer cleanup() - cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(resolver)), WithInsecure()) + cc, err := Dial("passthrough:///foo.bar.com", WithBalancer(RoundRobin(resolver)), WithInsecure()) if err != nil { t.Fatalf("Dial(\"foo.bar.com\", WithBalancer(_)) = _, %v, want _ ", err) } @@ -82,7 +83,7 @@ func TestConnectivityStates(t *testing.T) { func TestDialTimeout(t *testing.T) { defer leakcheck.Check(t) - conn, err := Dial("Non-Existent.Server:80", WithTimeout(time.Millisecond), WithBlock(), WithInsecure()) + conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTimeout(time.Millisecond), WithBlock(), WithInsecure()) if err == nil { conn.Close() } @@ -97,7 +98,7 @@ func TestTLSDialTimeout(t *testing.T) { if err != nil { t.Fatalf("Failed to create credentials %v", err) } - conn, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock()) + conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock()) if err == nil { conn.Close() } @@ -113,7 +114,7 @@ func TestDefaultAuthority(t *testing.T) { if err != nil { t.Fatalf("Dial(_, _) = _, %v, want _, ", err) } - conn.Close() + defer conn.Close() if conn.authority != target { t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, target) } @@ -126,11 +127,11 @@ func TestTLSServerNameOverwrite(t *testing.T) { if err != nil { t.Fatalf("Failed to create credentials %v", err) } - conn, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds)) + conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds)) if err != nil { t.Fatalf("Dial(_, _) = _, %v, want _, ", err) } - conn.Close() + defer conn.Close() if conn.authority != overwriteServerName { t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) } @@ -139,11 +140,11 @@ func TestTLSServerNameOverwrite(t *testing.T) { func TestWithAuthority(t *testing.T) { defer leakcheck.Check(t) overwriteServerName := "over.write.server.name" - conn, err := Dial("Non-Existent.Server:80", WithInsecure(), WithAuthority(overwriteServerName)) + conn, err := Dial("passthrough:///Non-Existent.Server:80", WithInsecure(), WithAuthority(overwriteServerName)) if err != nil { t.Fatalf("Dial(_, _) = _, %v, want _, ", err) } - conn.Close() + defer conn.Close() if conn.authority != overwriteServerName { t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) } @@ -156,11 +157,11 @@ func TestWithAuthorityAndTLS(t *testing.T) { if err != nil { t.Fatalf("Failed to create credentials %v", err) } - conn, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithAuthority("no.effect.authority")) + conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds), WithAuthority("no.effect.authority")) if err != nil { t.Fatalf("Dial(_, _) = _, %v, want _, ", err) } - conn.Close() + defer conn.Close() if conn.authority != overwriteServerName { t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) } @@ -231,11 +232,11 @@ func TestCredentialsMisuse(t *testing.T) { t.Fatalf("Failed to create authenticator %v", err) } // Two conflicting credential configurations - if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithBlock(), WithInsecure()); err != errCredentialsConflict { + if _, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithBlock(), WithInsecure()); err != errCredentialsConflict { t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsConflict) } // security info on insecure connection - if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(securePerRPCCredentials{}), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing { + if _, err := Dial("passthrough:///Non-Existent.Server:80", WithPerRPCCredentials(securePerRPCCredentials{}), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing { t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errTransportCredentialsMissing) } } @@ -263,10 +264,11 @@ func TestWithBackoffMaxDelay(t *testing.T) { func testBackoffConfigSet(t *testing.T, expected *BackoffConfig, opts ...DialOption) { opts = append(opts, WithInsecure()) - conn, err := Dial("foo:80", opts...) + conn, err := Dial("passthrough:///foo:80", opts...) if err != nil { t.Fatalf("unexpected error dialing connection: %v", err) } + defer conn.Close() if conn.dopts.bs == nil { t.Fatalf("backoff config not set") @@ -280,39 +282,6 @@ func testBackoffConfigSet(t *testing.T, expected *BackoffConfig, opts ...DialOpt if actual != *expected { t.Fatalf("unexpected backoff config on connection: %v, want %v", actual, expected) } - conn.Close() -} - -type testErr struct { - temp bool -} - -func (e *testErr) Error() string { - return "test error" -} - -func (e *testErr) Temporary() bool { - return e.temp -} - -var nonTemporaryError = &testErr{false} - -func nonTemporaryErrorDialer(addr string, timeout time.Duration) (net.Conn, error) { - return nil, nonTemporaryError -} - -func TestDialWithBlockErrorOnNonTemporaryErrorDialer(t *testing.T) { - defer leakcheck.Check(t) - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - if _, err := DialContext(ctx, "", WithInsecure(), WithDialer(nonTemporaryErrorDialer), WithBlock(), FailOnNonTempDialError(true)); err != nonTemporaryError { - t.Fatalf("Dial(%q) = %v, want %v", "", err, nonTemporaryError) - } - - // Without FailOnNonTempDialError, gRPC will retry to connect, and dial should exit with time out error. - if _, err := DialContext(ctx, "", WithInsecure(), WithDialer(nonTemporaryErrorDialer), WithBlock()); err != context.DeadlineExceeded { - t.Fatalf("Dial(%q) = %v, want %v", "", err, context.DeadlineExceeded) - } } // emptyBalancer returns an empty set of servers. diff --git a/pickfirst.go b/pickfirst.go index 7f993ef5a381..e4597cb86c75 100644 --- a/pickfirst.go +++ b/pickfirst.go @@ -57,14 +57,20 @@ func (b *pickfirstBalancer) HandleResolvedAddrs(addrs []resolver.Address, err er return } b.cc.UpdateBalancerState(connectivity.Idle, &picker{sc: b.sc}) + b.sc.Connect() } else { b.sc.UpdateAddresses(addrs) + b.sc.Connect() } } func (b *pickfirstBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { grpclog.Infof("pickfirstBalancer: HandleSubConnStateChange: %p, %v", sc, s) - if b.sc != sc || s == connectivity.Shutdown { + if b.sc != sc { + grpclog.Infof("pickfirstBalancer: ignored state change because sc is not recognized") + return + } + if s == connectivity.Shutdown { b.sc = nil return } @@ -93,3 +99,7 @@ func (p *picker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer. } return p.sc, nil, nil } + +func init() { + balancer.Register(newPickfirstBuilder()) +} diff --git a/resolver/resolver.go b/resolver/resolver.go index 49307e8fe9e9..6e822b56bef6 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -24,7 +24,7 @@ var ( // m is a map from scheme to resolver builder. m = make(map[string]Builder) // defaultScheme is the default scheme to use. - defaultScheme string + defaultScheme = "dns" ) // TODO(bar) install dns resolver in init(){}. diff --git a/resolver_conn_wrapper.go b/resolver_conn_wrapper.go index 7d53964d094d..2f61714860bc 100644 --- a/resolver_conn_wrapper.go +++ b/resolver_conn_wrapper.go @@ -19,6 +19,7 @@ package grpc import ( + "fmt" "strings" "google.golang.org/grpc/grpclog" @@ -56,19 +57,13 @@ func parseTarget(target string) (ret resolver.Target) { // newCCResolverWrapper parses cc.target for scheme and gets the resolver // builder for this scheme. It then builds the resolver and starts the // monitoring goroutine for it. -// -// This function could return nil, nil, in tests for old behaviors. -// TODO(bar) never return nil, nil when DNS becomes the default resolver. func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) { target := parseTarget(cc.target) grpclog.Infof("dialing to target with scheme: %q", target.Scheme) rb := resolver.Get(target.Scheme) if rb == nil { - // TODO(bar) return error when DNS becomes the default (implemented and - // registered by DNS package). - grpclog.Infof("could not get resolver for scheme: %q", target.Scheme) - return nil, nil + return nil, fmt.Errorf("could not get resolver for scheme: %q", target.Scheme) } ccr := &ccResolverWrapper{ @@ -100,13 +95,19 @@ func (ccr *ccResolverWrapper) watcher() { select { case addrs := <-ccr.addrCh: - grpclog.Infof("ccResolverWrapper: sending new addresses to balancer wrapper: %v", addrs) - // TODO(bar switching) this should never be nil. Pickfirst should be default. - if ccr.cc.balancerWrapper != nil { - // TODO(bar switching) create balancer if it's nil? - ccr.cc.balancerWrapper.handleResolvedAddrs(addrs, nil) + select { + case <-ccr.done: + return + default: } + grpclog.Infof("ccResolverWrapper: sending new addresses to cc: %v", addrs) + ccr.cc.handleResolvedAddrs(addrs, nil) case sc := <-ccr.scCh: + select { + case <-ccr.done: + return + default: + } grpclog.Infof("ccResolverWrapper: got new service config: %v", sc) case <-ccr.done: return diff --git a/test/end2end_test.go b/test/end2end_test.go index ab151209a4ae..1e3c4be23217 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -559,8 +559,6 @@ func (te *test) startServer(ts testpb.TestServiceServer) { te.t.Fatalf("Failed to generate credentials %v", err) } sopts = append(sopts, grpc.Creds(creds)) - case "clientAlwaysFailCred": - sopts = append(sopts, grpc.Creds(clientAlwaysFailCred{})) case "clientTimeoutCreds": sopts = append(sopts, grpc.Creds(&clientTimeoutCreds{})) } @@ -634,15 +632,13 @@ func (te *test) clientConn() *grpc.ClientConn { te.t.Fatalf("Failed to load credentials: %v", err) } opts = append(opts, grpc.WithTransportCredentials(creds)) - case "clientAlwaysFailCred": - opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{})) case "clientTimeoutCreds": opts = append(opts, grpc.WithTransportCredentials(&clientTimeoutCreds{})) default: opts = append(opts, grpc.WithInsecure()) } // TODO(bar) switch balancer case "pickfirst". - var scheme string + scheme := "passthrough:///" switch te.e.balancer { case "v1": opts = append(opts, grpc.WithBalancer(grpc.RoundRobin(nil))) @@ -652,7 +648,6 @@ func (te *test) clientConn() *grpc.ClientConn { te.t.Fatalf("got nil when trying to get roundrobin balancer builder") } opts = append(opts, grpc.WithBalancerBuilder(rr)) - scheme = "passthrough:///" } if te.clientInitialWindowSize > 0 { opts = append(opts, grpc.WithInitialWindowSize(te.clientInitialWindowSize)) @@ -670,6 +665,9 @@ func (te *test) clientConn() *grpc.ClientConn { // Only do a blocking dial if server is up. opts = append(opts, grpc.WithBlock()) } + if te.srvAddr == "" { + te.srvAddr = "client.side.only.test" + } var err error te.cc, err = grpc.Dial(scheme+te.srvAddr, opts...) if err != nil { @@ -4068,44 +4066,6 @@ func testClientRequestBodyErrorCancelStreamingInput(t *testing.T, e env) { }) } -const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails" - -var errClientAlwaysFailCred = errors.New(clientAlwaysFailCredErrorMsg) - -type clientAlwaysFailCred struct{} - -func (c clientAlwaysFailCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return nil, nil, errClientAlwaysFailCred -} -func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return rawConn, nil, nil -} -func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo { - return credentials.ProtocolInfo{} -} -func (c clientAlwaysFailCred) Clone() credentials.TransportCredentials { - return nil -} -func (c clientAlwaysFailCred) OverrideServerName(s string) error { - return nil -} - -func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { - te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: "v1"}) - te.startServer(&testServer{security: te.e.security}) - defer te.tearDown() - - var ( - err error - opts []grpc.DialOption - ) - opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{}), grpc.WithBlock()) - te.cc, err = grpc.Dial(te.srvAddr, opts...) - if err != errClientAlwaysFailCred { - te.t.Fatalf("Dial(%q) = %v, want %v", te.srvAddr, err, errClientAlwaysFailCred) - } -} - type clientTimeoutCreds struct { timeoutReturned bool }