diff --git a/balancer_switching_test.go b/balancer_switching_test.go index 943f470b648c..abb403332dcf 100644 --- a/balancer_switching_test.go +++ b/balancer_switching_test.go @@ -19,17 +19,19 @@ package grpc import ( + "fmt" "math" "testing" "time" "golang.org/x/net/context" + _ "google.golang.org/grpc/grpclog/glogger" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" "google.golang.org/grpc/test/leakcheck" ) -func checkPickFirst(cc *ClientConn, servers []*server, t *testing.T) { +func checkPickFirst(cc *ClientConn, servers []*server) error { var ( req = "port" reply string @@ -38,14 +40,14 @@ func checkPickFirst(cc *ClientConn, servers []*server, t *testing.T) { // The second RPC should succeed with the first server. for i := 0; i < 1000; i++ { if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port { - return + return nil } time.Sleep(time.Millisecond) } - t.Fatalf("EmptyCall() = _, %v, want _, %v", err, servers[0].port) + return fmt.Errorf("EmptyCall() = _, %v, want _, %v", err, servers[0].port) } -func checkRoundRobin(cc *ClientConn, servers []*server, t *testing.T) { +func checkRoundRobin(cc *ClientConn, servers []*server) error { var ( req = "port" reply string @@ -53,12 +55,21 @@ func checkRoundRobin(cc *ClientConn, servers []*server, t *testing.T) { ) // Make sure connections to all servers are up. - for _, s := range servers { - for { - if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == s.port { - break + for i := 0; i < 2; i++ { + // Do this check twice, otherwise the first RPC's transport may still be + // picked by the closing pickfirst balancer, and the test becomes flaky. + for _, s := range servers { + var up bool + for i := 0; i < 1000; i++ { + if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == s.port { + up = true + break + } + time.Sleep(time.Millisecond) + } + if !up { + return fmt.Errorf("server %v is not up within 1 second", s.port) } - time.Sleep(time.Millisecond) } } @@ -66,9 +77,10 @@ func checkRoundRobin(cc *ClientConn, servers []*server, t *testing.T) { for i := 0; i < 3*serverCount; i++ { err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc) if ErrorDesc(err) != servers[i%serverCount].port { - t.Fatalf("Index %d: want peer %v, got peer %v", i, servers[i%serverCount].port, err) + return fmt.Errorf("Index %d: want peer %v, got peer %v", i, servers[i%serverCount].port, err) } } + return nil } func TestSwitchBalancer(t *testing.T) { @@ -87,11 +99,17 @@ func TestSwitchBalancer(t *testing.T) { defer cc.Close() r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}}) // The default balancer is pickfirst. - checkPickFirst(cc, servers, t) + if err := checkPickFirst(cc, servers); err != nil { + t.Fatalf("check pickfirst returned non-nil error: %v", err) + } // Switch to roundrobin. cc.switchBalancer("roundrobin") - checkRoundRobin(cc, servers, t) + if err := checkRoundRobin(cc, servers); err != nil { + t.Fatalf("check roundrobin returned non-nil error: %v", err) + } // Switch to pickfirst. cc.switchBalancer("pickfirst") - checkPickFirst(cc, servers, t) + if err := checkPickFirst(cc, servers); err != nil { + t.Fatalf("check pickfirst returned non-nil error: %v", err) + } }