diff --git a/attributes/attributes.go b/attributes/attributes.go
index 3efca4591493..49712aca33ae 100644
--- a/attributes/attributes.go
+++ b/attributes/attributes.go
@@ -112,19 +112,31 @@ func (a *Attributes) String() string {
 	sb.WriteString("{")
 	first := true
 	for k, v := range a.m {
-		var key, val string
-		if str, ok := k.(interface{ String() string }); ok {
-			key = str.String()
-		}
-		if str, ok := v.(interface{ String() string }); ok {
-			val = str.String()
-		}
 		if !first {
 			sb.WriteString(", ")
 		}
-		sb.WriteString(fmt.Sprintf("%q: %q, ", key, val))
+		sb.WriteString(fmt.Sprintf("%q: %q ", str(k), str(v)))
 		first = false
 	}
 	sb.WriteString("}")
 	return sb.String()
 }
+
+func str(x interface{}) string {
+	if v, ok := x.(fmt.Stringer); ok {
+		return v.String()
+	} else if v, ok := x.(string); ok {
+		return v
+	}
+	return fmt.Sprintf("<%p>", x)
+}
+
+// MarshalJSON helps implement the json.Marshaler interface, thereby rendering
+// the Attributes correctly when printing (via pretty.JSON) structs containing
+// Attributes as fields.
+//
+// It is impossible to unmarshal attributes from a JSON representation and this
+// method is meant only for debugging purposes.
+func (a *Attributes) MarshalJSON() ([]byte, error) {
+	return []byte(a.String()), nil
+}
diff --git a/clientconn.go b/clientconn.go
index 04a12c57e42e..bfd7555a8bf2 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -37,6 +37,7 @@ import (
 	"google.golang.org/grpc/internal/backoff"
 	"google.golang.org/grpc/internal/channelz"
 	"google.golang.org/grpc/internal/grpcsync"
+	"google.golang.org/grpc/internal/pretty"
 	iresolver "google.golang.org/grpc/internal/resolver"
 	"google.golang.org/grpc/internal/transport"
 	"google.golang.org/grpc/keepalive"
@@ -867,6 +868,20 @@ func (cc *ClientConn) handleSubConnStateChange(sc balancer.SubConn, s connectivi
 	cc.balancerWrapper.updateSubConnState(sc, s, err)
 }
 
+// Makes a copy of the input addresses slice and clears out the balancer
+// attributes field. Addresses are passed during subconn creation and address
+// update operations. In both cases, we will clear the balancer attributes by
+// calling this function, and therefore we will be able to use the Equal method
+// provided by the resolver.Address type for comparison.
+func copyAddressesWithoutBalancerAttributes(in []resolver.Address) []resolver.Address {
+	out := make([]resolver.Address, len(in))
+	for i := range in {
+		out[i] = in[i]
+		out[i].BalancerAttributes = nil
+	}
+	return out
+}
+
 // newAddrConn creates an addrConn for addrs and adds it to cc.conns.
 //
 // Caller needs to make sure len(addrs) > 0.
@@ -874,7 +889,7 @@ func (cc *ClientConn) newAddrConn(addrs []resolver.Address, opts balancer.NewSub
 	ac := &addrConn{
 		state:  connectivity.Idle,
 		cc:     cc,
-		addrs:  addrs,
+		addrs:  copyAddressesWithoutBalancerAttributes(addrs),
 		scopts: opts,
 		dopts:  cc.dopts,
 		czData: new(channelzData),
@@ -995,8 +1010,9 @@ func equalAddresses(a, b []resolver.Address) bool {
 // connections or connection attempts.
 func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
 	ac.mu.Lock()
-	channelz.Infof(logger, ac.channelzID, "addrConn: updateAddrs curAddr: %v, addrs: %v", ac.curAddr, addrs)
+	channelz.Infof(logger, ac.channelzID, "addrConn: updateAddrs curAddr: %v, addrs: %v", pretty.ToJSON(ac.curAddr), pretty.ToJSON(addrs))
+	addrs = copyAddressesWithoutBalancerAttributes(addrs)
 	if equalAddresses(ac.addrs, addrs) {
 		ac.mu.Unlock()
 		return
diff --git a/internal/stubserver/stubserver.go b/internal/stubserver/stubserver.go
index 3c89ff6823bd..cfc4b0f2e82d 100644
--- a/internal/stubserver/stubserver.go
+++ b/internal/stubserver/stubserver.go
@@ -60,6 +60,10 @@ type StubServer struct {
 	Address string
 	Target  string
 
+	// Custom listener to use for serving. If unspecified, a new listener is
+	// created on a local port.
+	Listener net.Listener
+
 	cleanups []func() // Lambdas executed in Stop(); populated by Start().
 
 	// Set automatically if Target == ""
@@ -118,9 +122,13 @@ func (ss *StubServer) StartServer(sopts ...grpc.ServerOption) error {
 		ss.R = manual.NewBuilderWithScheme("whatever")
 	}
 
-	lis, err := net.Listen(ss.Network, ss.Address)
-	if err != nil {
-		return fmt.Errorf("net.Listen(%q, %q) = %v", ss.Network, ss.Address, err)
+	lis := ss.Listener
+	if lis == nil {
+		var err error
+		lis, err = net.Listen(ss.Network, ss.Address)
+		if err != nil {
+			return fmt.Errorf("net.Listen(%q, %q) = %v", ss.Network, ss.Address, err)
+		}
 	}
 	ss.Address = lis.Addr().String()
 	ss.cleanups = append(ss.cleanups, func() { lis.Close() })
diff --git a/resolver/resolver.go b/resolver/resolver.go
index 1b7f382e3a03..d8db6f5d34eb 100644
--- a/resolver/resolver.go
+++ b/resolver/resolver.go
@@ -142,6 +142,10 @@ type Address struct {
 
 // Equal returns whether a and o are identical. Metadata is compared directly,
 // not with any recursive introspection.
+//
+// This method compares all fields of the address. When used to tell apart
+// addresses during subchannel creation or connection establishment, it might be
+// more appropriate for the caller to implement custom equality logic.
 func (a Address) Equal(o Address) bool {
 	return a.Addr == o.Addr && a.ServerName == o.ServerName &&
 		a.Attributes.Equal(o.Attributes) &&
diff --git a/test/pickfirst_test.go b/test/pickfirst_test.go
index 55659b928a57..d41786a556a3 100644
--- a/test/pickfirst_test.go
+++ b/test/pickfirst_test.go
@@ -20,7 +20,7 @@ package test
 
 import (
 	"context"
-	"sync"
+	"fmt"
 	"testing"
 	"time"
 
@@ -299,6 +299,11 @@ func (s) TestPickFirst_NewAddressWhileBlocking(t *testing.T) {
 	}
 }
 
+// TestPickFirst_StickyTransientFailure tests the case where pick_first is
+// configured on a channel, and the backend is configured to close incoming
+// connections as soon as they are accepted. The test verifies that the channel
+// enters TransientFailure and stays there. The test also verifies that the
+// pick_first LB policy is constantly trying to reconnect to the backend.
 func (s) TestPickFirst_StickyTransientFailure(t *testing.T) {
 	// Spin up a local server which closes the connection as soon as it receives
 	// one. It also sends a signal on a channel whenver it received a connection.
@@ -346,40 +351,27 @@ func (s) TestPickFirst_StickyTransientFailure(t *testing.T) {
 	}
 	t.Cleanup(func() { cc.Close() })
 
-	var wg sync.WaitGroup
-	wg.Add(2)
-	// Spin up a goroutine that waits for the channel to move to
-	// TransientFailure. After that it checks that the channel stays in
-	// TransientFailure, until Shutdown.
-	go func() {
-		defer wg.Done()
-		for state := cc.GetState(); state != connectivity.TransientFailure; state = cc.GetState() {
-			if !cc.WaitForStateChange(ctx, state) {
-				t.Errorf("Timeout when waiting for state to change to TransientFailure. Current state is %s", state)
-				return
-			}
-		}
+	awaitState(ctx, t, cc, connectivity.TransientFailure)
 
-		// TODO(easwars): this waits for 10s. Need shorter deadline here. Basically once the second goroutine exits, we should exit from here too.
+	// Spawn a goroutine to ensure that the channel stays in TransientFailure.
+	// The call to cc.WaitForStateChange will return false when the main
+	// goroutine exits and the context is cancelled.
+	go func() {
 		if cc.WaitForStateChange(ctx, connectivity.TransientFailure) {
 			if state := cc.GetState(); state != connectivity.Shutdown {
 				t.Errorf("Unexpected state change from TransientFailure to %s", cc.GetState())
 			}
 		}
 	}()
-	// Spin up a goroutine which ensures that the pick_first LB policy is
-	// constantly trying to reconnect.
-	go func() {
-		defer wg.Done()
-		for i := 0; i < 10; i++ {
-			select {
-			case <-connCh:
-			case <-time.After(2 * defaultTestShortTimeout):
-				t.Error("Timeout when waiting for pick_first to reconnect")
-			}
+
+	// Ensures that the pick_first LB policy is constantly trying to reconnect.
+	for i := 0; i < 10; i++ {
+		select {
+		case <-connCh:
+		case <-time.After(2 * defaultTestShortTimeout):
+			t.Error("Timeout when waiting for pick_first to reconnect")
 		}
-	}()
-	wg.Wait()
+	}
 }
 
 // Tests the PF LB policy with shuffling enabled.
@@ -475,3 +467,221 @@ func (s) TestPickFirst_ShuffleAddressListDisabled(t *testing.T) {
 		t.Fatal(err)
 	}
 }
+
+// setupPickFirstWithListenerWrapper is very similar to setupPickFirst, but uses
+// a wrapped listener that the test can use to track accepted connections.
+func setupPickFirstWithListenerWrapper(t *testing.T, backendCount int, opts ...grpc.DialOption) (*grpc.ClientConn, *manual.Resolver, []*stubserver.StubServer, []*testutils.ListenerWrapper) {
+	t.Helper()
+
+	// Initialize channelz. Used to determine pending RPC count.
+	czCleanup := channelz.NewChannelzStorageForTesting()
+	t.Cleanup(func() { czCleanupWrapper(czCleanup, t) })
+
+	backends := make([]*stubserver.StubServer, backendCount)
+	addrs := make([]resolver.Address, backendCount)
+	listeners := make([]*testutils.ListenerWrapper, backendCount)
+	for i := 0; i < backendCount; i++ {
+		lis := testutils.NewListenerWrapper(t, nil)
+		backend := &stubserver.StubServer{
+			Listener: lis,
+			EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
+				return &testpb.Empty{}, nil
+			},
+		}
+		if err := backend.StartServer(); err != nil {
+			t.Fatalf("Failed to start backend: %v", err)
+		}
+		t.Logf("Started TestService backend at: %q", backend.Address)
+		t.Cleanup(func() { backend.Stop() })
+
+		backends[i] = backend
+		addrs[i] = resolver.Address{Addr: backend.Address}
+		listeners[i] = lis
+	}
+
+	r := manual.NewBuilderWithScheme("whatever")
+	dopts := []grpc.DialOption{
+		grpc.WithTransportCredentials(insecure.NewCredentials()),
+		grpc.WithResolvers(r),
+		grpc.WithDefaultServiceConfig(pickFirstServiceConfig),
+	}
+	dopts = append(dopts, opts...)
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", dopts...)
+	if err != nil {
+		t.Fatalf("grpc.Dial() failed: %v", err)
+	}
+	t.Cleanup(func() { cc.Close() })
+
+	// At this point, the resolver has not returned any addresses to the channel.
+	// This RPC must block until the context expires.
+	sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
+	defer sCancel()
+	client := testgrpc.NewTestServiceClient(cc)
+	if _, err := client.EmptyCall(sCtx, &testpb.Empty{}); status.Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("EmptyCall() = %s, want %s", status.Code(err), codes.DeadlineExceeded)
+	}
+	return cc, r, backends, listeners
+}
+
+// TestPickFirst_AddressUpdateWithAttributes tests the case where an address
+// update received by the pick_first LB policy differs in attributes. Addresses
+// which differ in attributes are considered different from the perspective of
+// subconn creation and connection establishment, and the test verifies that
+// new connections are created when attributes change.
+func (s) TestPickFirst_AddressUpdateWithAttributes(t *testing.T) {
+	cc, r, backends, listeners := setupPickFirstWithListenerWrapper(t, 2)
+
+	// Add a set of attributes to the addresses before pushing them to the
+	// pick_first LB policy through the manual resolver.
+	addrs := stubBackendsToResolverAddrs(backends)
+	for i := range addrs {
+		addrs[i].Attributes = addrs[i].Attributes.WithValue("test-attribute-1", fmt.Sprintf("%d", i))
+	}
+	r.UpdateState(resolver.State{Addresses: addrs})
+
+	// Ensure that RPCs succeed to the first backend in the list.
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
+		t.Fatal(err)
+	}
+
+	// Grab the wrapped connection from the listener wrapper. This will be used
+	// to verify the connection is closed.
+	val, err := listeners[0].NewConnCh.Receive(ctx)
+	if err != nil {
+		t.Fatalf("Failed to receive new connection from wrapped listener: %v", err)
+	}
+	conn := val.(*testutils.ConnWrapper)
+
+	// Add another set of attributes to the addresses, and push them to the
+	// pick_first LB policy through the manual resolver. Leave the order of the
+	// addresses unchanged.
+	for i := range addrs {
+		addrs[i].Attributes = addrs[i].Attributes.WithValue("test-attribute-2", fmt.Sprintf("%d", i))
+	}
+	r.UpdateState(resolver.State{Addresses: addrs})
+	if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
+		t.Fatal(err)
+	}
+
+	// A change in the address attributes results in the new address being
+	// considered different from the current address. This results in the old
+	// connection being closed and a new connection being established to the
+	// same backend (since the address order is not modified).
+	if _, err := conn.CloseCh.Receive(ctx); err != nil {
+		t.Fatalf("Timeout when expecting existing connection to be closed: %v", err)
+	}
+	val, err = listeners[0].NewConnCh.Receive(ctx)
+	if err != nil {
+		t.Fatalf("Failed to receive new connection from wrapped listener: %v", err)
+	}
+	conn = val.(*testutils.ConnWrapper)
+
+	// Add another set of attributes to the addresses, and push them to the
+	// pick_first LB policy through the manual resolver. Also reverse the
+	// order of the addresses.
+	for i := range addrs {
+		addrs[i].Attributes = addrs[i].Attributes.WithValue("test-attribute-3", fmt.Sprintf("%d", i))
+	}
+	addrs[0], addrs[1] = addrs[1], addrs[0]
+	r.UpdateState(resolver.State{Addresses: addrs})
+	if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
+		t.Fatal(err)
+	}
+
+	// Ensure that the old connection is closed and a new connection is
+	// established to the first address in the new list.
+	if _, err := conn.CloseCh.Receive(ctx); err != nil {
+		t.Fatalf("Timeout when expecting existing connection to be closed: %v", err)
+	}
+	_, err = listeners[1].NewConnCh.Receive(ctx)
+	if err != nil {
+		t.Fatalf("Failed to receive new connection from wrapped listener: %v", err)
+	}
+}
+
+// TestPickFirst_AddressUpdateWithBalancerAttributes tests the case where an
+// address update received by the pick_first LB policy differs in balancer
+// attributes, which are meant only for consumption by LB policies. In this
+// case, the test verifies that new connections are not created when the address
+// update only changes the balancer attributes.
+func (s) TestPickFirst_AddressUpdateWithBalancerAttributes(t *testing.T) {
+	cc, r, backends, listeners := setupPickFirstWithListenerWrapper(t, 2)
+
+	// Add a set of balancer attributes to the addresses before pushing them to
+	// the pick_first LB policy through the manual resolver.
+	addrs := stubBackendsToResolverAddrs(backends)
+	for i := range addrs {
+		addrs[i].BalancerAttributes = addrs[i].BalancerAttributes.WithValue("test-attribute-1", fmt.Sprintf("%d", i))
+	}
+	r.UpdateState(resolver.State{Addresses: addrs})
+
+	// Ensure that RPCs succeed to the expected backend.
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
+		t.Fatal(err)
+	}
+
+	// Grab the wrapped connection from the listener wrapper. This will be used
+	// to verify the connection is not closed.
+	val, err := listeners[0].NewConnCh.Receive(ctx)
+	if err != nil {
+		t.Fatalf("Failed to receive new connection from wrapped listener: %v", err)
+	}
+	conn := val.(*testutils.ConnWrapper)
+
+	// Add a set of balancer attributes to the addresses before pushing them to
+	// the pick_first LB policy through the manual resolver. Leave the order of
+	// the addresses unchanged.
+	for i := range addrs {
+		addrs[i].BalancerAttributes = addrs[i].BalancerAttributes.WithValue("test-attribute-2", fmt.Sprintf("%d", i))
+	}
+	r.UpdateState(resolver.State{Addresses: addrs})
+
+	// Ensure that no new connection is established, and ensure that the old
+	// connection is not closed.
+	for i := range listeners {
+		sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
+		defer sCancel()
+		if _, err := listeners[i].NewConnCh.Receive(sCtx); err != context.DeadlineExceeded {
+			t.Fatalf("Unexpected error when expecting no new connection: %v", err)
+		}
+	}
+	sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
+	defer sCancel()
+	if _, err := conn.CloseCh.Receive(sCtx); err != context.DeadlineExceeded {
+		t.Fatalf("Unexpected error when expecting existing connection to stay active: %v", err)
+	}
+	if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
+		t.Fatal(err)
+	}
+
+	// Add a set of balancer attributes to the addresses before pushing them to
+	// the pick_first LB policy through the manual resolver. Also reverse the
+	// order of the addresses.
+	for i := range addrs {
+		addrs[i].BalancerAttributes = addrs[i].BalancerAttributes.WithValue("test-attribute-3", fmt.Sprintf("%d", i))
+	}
+	addrs[0], addrs[1] = addrs[1], addrs[0]
+	r.UpdateState(resolver.State{Addresses: addrs})
+
+	// Ensure that no new connection is established, and ensure that the old
+	// connection is not closed.
+	for i := range listeners {
+		sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
+		defer sCancel()
+		if _, err := listeners[i].NewConnCh.Receive(sCtx); err != context.DeadlineExceeded {
+			t.Fatalf("Unexpected error when expecting no new connection: %v", err)
+		}
+	}
+	sCtx, sCancel = context.WithTimeout(ctx, defaultTestShortTimeout)
+	defer sCancel()
+	if _, err := conn.CloseCh.Receive(sCtx); err != context.DeadlineExceeded {
+		t.Fatalf("Unexpected error when expecting existing connection to stay active: %v", err)
+	}
+	if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
+		t.Fatal(err)
+	}
+}
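
Note (not part of the patch): a minimal, self-contained sketch of the comparison behavior the change above relies on. resolver.Address.Equal compares BalancerAttributes along with the other fields, so the channel clears that field before storing and comparing addresses; an update that differs only in balancer attributes then compares equal and does not trigger a reconnect. The helper below simply mirrors the one added in clientconn.go, and the "weight" key is an arbitrary example, not a real grpc-go attribute.

package main

import (
	"fmt"

	"google.golang.org/grpc/resolver"
)

// Mirrors the helper added in clientconn.go: copy the addresses and drop the
// balancer attributes so that resolver.Address.Equal ignores them.
func copyAddressesWithoutBalancerAttributes(in []resolver.Address) []resolver.Address {
	out := make([]resolver.Address, len(in))
	for i := range in {
		out[i] = in[i]
		out[i].BalancerAttributes = nil
	}
	return out
}

func main() {
	a := resolver.Address{Addr: "localhost:50051"}
	b := a
	// b differs from a only in balancer attributes, which are meant solely for
	// consumption by LB policies.
	b.BalancerAttributes = b.BalancerAttributes.WithValue("weight", "10")

	fmt.Println(a.Equal(b)) // false: Equal also compares BalancerAttributes.

	ca := copyAddressesWithoutBalancerAttributes([]resolver.Address{a})[0]
	cb := copyAddressesWithoutBalancerAttributes([]resolver.Address{b})[0]
	fmt.Println(ca.Equal(cb)) // true: only the balancer attributes differed.
}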