Skip to content

Commit

Permalink
fix import cycle and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
menghanl committed Oct 3, 2017
1 parent d28c807 commit 88af955
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 20 deletions.
45 changes: 37 additions & 8 deletions balancer/roundrobin/roundrobin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*
*/

package roundrobin
package roundrobin_test

import (
"fmt"
Expand All @@ -27,6 +27,7 @@ import (

"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
_ "google.golang.org/grpc/grpclog/glogger"
"google.golang.org/grpc/peer"
Expand Down Expand Up @@ -99,7 +100,11 @@ func TestOneBackend(t *testing.T) {
}
defer test.cleanup()

cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder()))
rr := balancer.Get("roundrobin")
if rr == nil {
t.Fatalf("got nil when trying to get roundrobin balancer builder")
}
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
Expand Down Expand Up @@ -131,7 +136,11 @@ func TestBackendsRoundRobin(t *testing.T) {
}
defer test.cleanup()

cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder()))
rr := balancer.Get("roundrobin")
if rr == nil {
t.Fatalf("got nil when trying to get roundrobin balancer builder")
}
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
Expand Down Expand Up @@ -190,7 +199,11 @@ func TestAddressesRemoved(t *testing.T) {
}
defer test.cleanup()

cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder()))
rr := balancer.Get("roundrobin")
if rr == nil {
t.Fatalf("got nil when trying to get roundrobin balancer builder")
}
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
Expand Down Expand Up @@ -232,7 +245,11 @@ func TestCloseWithPendingRPC(t *testing.T) {
}
defer test.cleanup()

cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder()))
rr := balancer.Get("roundrobin")
if rr == nil {
t.Fatalf("got nil when trying to get roundrobin balancer builder")
}
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
Expand Down Expand Up @@ -266,7 +283,11 @@ func TestNewAddressWhileBlocking(t *testing.T) {
}
defer test.cleanup()

cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder()))
rr := balancer.Get("roundrobin")
if rr == nil {
t.Fatalf("got nil when trying to get roundrobin balancer builder")
}
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
Expand Down Expand Up @@ -315,7 +336,11 @@ func TestOneServerDown(t *testing.T) {
}
defer test.cleanup()

cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder()))
rr := balancer.Get("roundrobin")
if rr == nil {
t.Fatalf("got nil when trying to get roundrobin balancer builder")
}
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
Expand Down Expand Up @@ -408,7 +433,11 @@ func TestAllServersDown(t *testing.T) {
}
defer test.cleanup()

cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder()))
rr := balancer.Get("roundrobin")
if rr == nil {
t.Fatalf("got nil when trying to get roundrobin balancer builder")
}
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(rr))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
Expand Down
33 changes: 21 additions & 12 deletions balancer_v1_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package grpc

import (
"strings"
"sync"

"golang.org/x/net/context"
Expand All @@ -34,20 +35,27 @@ type balancerWrapperBuilder struct {
}

func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
bwb.b.Start(cc.Target(), BalancerConfig{
targetAddr := cc.Target()
targetSplitted := strings.Split(targetAddr, ":///")
if len(targetSplitted) >= 2 {
targetAddr = targetSplitted[1]
}

bwb.b.Start(targetAddr, BalancerConfig{
DialCreds: opts.DialCreds,
Dialer: opts.Dialer,
})
_, pickfirst := bwb.b.(*pickFirst)
bw := &balancerWrapper{
balancer: bwb.b,
pickfirst: pickfirst,
cc: cc,
startCh: make(chan struct{}),
conns: make(map[resolver.Address]balancer.SubConn),
connSt: make(map[balancer.SubConn]*scState),
csEvltr: &connectivityStateEvaluator{},
state: connectivity.Idle,
balancer: bwb.b,
pickfirst: pickfirst,
cc: cc,
targetAddr: targetAddr,
startCh: make(chan struct{}),
conns: make(map[resolver.Address]balancer.SubConn),
connSt: make(map[balancer.SubConn]*scState),
csEvltr: &connectivityStateEvaluator{},
state: connectivity.Idle,
}
cc.UpdateBalancerState(connectivity.Idle, bw)
go bw.lbWatcher()
Expand All @@ -68,7 +76,8 @@ type balancerWrapper struct {
balancer Balancer // The v1 balancer.
pickfirst bool

cc balancer.ClientConn
cc balancer.ClientConn
targetAddr string // Target without the scheme.

// To aggregate the connectivity state.
csEvltr *connectivityStateEvaluator
Expand All @@ -93,7 +102,7 @@ func (bw *balancerWrapper) lbWatcher() {
if notifyCh == nil {
// There's no resolver in the balancer. Connect directly.
a := resolver.Address{
Addr: bw.cc.Target(),
Addr: bw.targetAddr,
Type: resolver.Backend,
}
sc, err := bw.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{})
Expand All @@ -103,7 +112,7 @@ func (bw *balancerWrapper) lbWatcher() {
bw.mu.Lock()
bw.conns[a] = sc
bw.connSt[sc] = &scState{
addr: Address{Addr: bw.cc.Target()},
addr: Address{Addr: bw.targetAddr},
s: connectivity.Idle,
}
bw.mu.Unlock()
Expand Down

0 comments on commit 88af955

Please sign in to comment.