diff --git a/go/vt/grpcclient/client.go b/go/vt/grpcclient/client.go index b000a542a41..6ad54ae8dea 100644 --- a/go/vt/grpcclient/client.go +++ b/go/vt/grpcclient/client.go @@ -19,6 +19,7 @@ limitations under the License. package grpcclient import ( + "context" "flag" "time" @@ -58,6 +59,16 @@ func RegisterGRPCDialOptions(grpcDialOptionsFunc func(opts []grpc.DialOption) ([ // failFast is a non-optional parameter because callers are required to specify // what that should be. func Dial(target string, failFast FailFast, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + return DialContext(context.Background(), target, failFast, opts...) +} + +// DialContext creates a grpc connection to the given target. Setup steps are +// covered by the context deadline, and, if WithBlock is specified in the dial +// options, connection establishment steps are covered by the context as well. +// +// failFast is a non-optional parameter because callers are required to specify +// what that should be. +func DialContext(ctx context.Context, target string, failFast FailFast, opts ...grpc.DialOption) (*grpc.ClientConn, error) { grpccommon.EnableTracingOpt() newopts := []grpc.DialOption{ grpc.WithDefaultCallOptions( @@ -98,7 +109,7 @@ func Dial(target string, failFast FailFast, opts ...grpc.DialOption) (*grpc.Clie newopts = append(newopts, interceptors()...) - return grpc.Dial(target, newopts...) + return grpc.DialContext(ctx, target, newopts...) } func interceptors() []grpc.DialOption { diff --git a/go/vt/vttablet/grpctmclient/cached_client.go b/go/vt/vttablet/grpctmclient/cached_client.go new file mode 100644 index 00000000000..2e55e62a79f --- /dev/null +++ b/go/vt/vttablet/grpctmclient/cached_client.go @@ -0,0 +1,331 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package grpctmclient + +import ( + "context" + "flag" + "io" + "sort" + "sync" + "time" + + "google.golang.org/grpc" + + "vitess.io/vitess/go/netutil" + "vitess.io/vitess/go/stats" + "vitess.io/vitess/go/sync2" + "vitess.io/vitess/go/vt/grpcclient" + "vitess.io/vitess/go/vt/vttablet/tmclient" + + tabletmanagerservicepb "vitess.io/vitess/go/vt/proto/tabletmanagerservice" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" +) + +var ( + defaultPoolCapacity = flag.Int("tablet_manager_grpc_connpool_size", 100, "number of tablets to keep tmclient connections open to") +) + +func init() { + tmclient.RegisterTabletManagerClientFactory("grpc-cached", func() tmclient.TabletManagerClient { + return NewCachedConnClient(*defaultPoolCapacity) + }) +} + +// closeFunc allows a standalone function to implement io.Closer, similar to +// how http.HandlerFunc allows standalone functions to implement http.Handler. +type closeFunc func() error + +func (fn closeFunc) Close() error { + return fn() +} + +var _ io.Closer = (*closeFunc)(nil) + +type cachedConn struct { + tabletmanagerservicepb.TabletManagerClient + cc *grpc.ClientConn + + addr string + lastAccessTime time.Time + refs int +} + +type cachedConnDialer struct { + m sync.Mutex + conns map[string]*cachedConn + evict []*cachedConn + evictSorted bool + connWaitSema *sync2.Semaphore + capacity int +} + +var dialerStats = struct { + ConnReuse *stats.Gauge + ConnNew *stats.Gauge + DialTimeouts *stats.Gauge + DialTimings *stats.Timings +}{ + ConnReuse: stats.NewGauge("tabletmanagerclient_cachedconn_reuse", "number of times a call to dial() was able to reuse an existing connection"), + ConnNew: stats.NewGauge("tabletmanagerclient_cachedconn_new", "number of times a call to dial() resulted in a dialing a new grpc clientconn"), + DialTimeouts: stats.NewGauge("tabletmanagerclient_cachedconn_dial_timeouts", "number of context timeouts during dial()"), + DialTimings: stats.NewTimings("tabletmanagerclient_cachedconn_dial_timings", "timings for various dial paths", "path", "cache_fast", "sema_fast", "sema_poll"), +} + +// NewCachedConnClient returns a grpc Client that caches connections to the +// different tablets. +func NewCachedConnClient(capacity int) *Client { + dialer := &cachedConnDialer{ + conns: make(map[string]*cachedConn, capacity), + evict: make([]*cachedConn, 0, capacity), + connWaitSema: sync2.NewSemaphore(capacity, 0), + capacity: capacity, + } + return &Client{dialer} +} + +var _ dialer = (*cachedConnDialer)(nil) + +func (dialer *cachedConnDialer) sortEvictionsLocked() { + if !dialer.evictSorted { + sort.Slice(dialer.evict, func(i, j int) bool { + left, right := dialer.evict[i], dialer.evict[j] + if left.refs == right.refs { + return right.lastAccessTime.After(left.lastAccessTime) + } + return right.refs > left.refs + }) + dialer.evictSorted = true + } +} + +func (dialer *cachedConnDialer) dial(ctx context.Context, tablet *topodatapb.Tablet) (tabletmanagerservicepb.TabletManagerClient, io.Closer, error) { + start := time.Now() + addr := getTabletAddr(tablet) + + if client, closer, found, err := dialer.tryFromCache(addr, &dialer.m); found { + dialerStats.DialTimings.Add("cache_fast", time.Since(start)) + return client, closer, err + } + + if dialer.connWaitSema.TryAcquire() { + defer func() { + dialerStats.DialTimings.Add("sema_fast", time.Since(start)) + }() + + // Check if another goroutine managed to dial a conn for the same addr + // while we were waiting for the write lock. This is identical to the + // read-lock section above, except we release the connWaitSema if we + // are able to use the cache, allowing another goroutine to dial a new + // conn instead. + if client, closer, found, err := dialer.tryFromCache(addr, &dialer.m); found { + dialer.connWaitSema.Release() + return client, closer, err + } + return dialer.newdial(ctx, addr) + } + + defer func() { + dialerStats.DialTimings.Add("sema_poll", time.Since(start)) + }() + + for { + select { + case <-ctx.Done(): + dialerStats.DialTimeouts.Add(1) + return nil, nil, ctx.Err() + default: + if client, closer, found, err := dialer.pollOnce(ctx, addr); found { + return client, closer, err + } + } + } +} + +// tryFromCache tries to get a connection from the cache, performing a redial +// on that connection if it exists. It returns a TabletManagerClient impl, an +// io.Closer, a flag to indicate whether a connection was found in the cache, +// and an error, which is always nil. +// +// In addition to the addr being dialed, tryFromCache takes a sync.Locker which, +// if not nil, will be used to wrap the lookup and redial in that lock. This +// function can be called in situations where the conns map is locked +// externally (like in pollOnce), so we do not want to manage the locks here. In +// other cases (like in the cache_fast path of dial()), we pass in the dialer.m +// to ensure we have a lock on the cache for the duration of the call. +func (dialer *cachedConnDialer) tryFromCache(addr string, locker sync.Locker) (client tabletmanagerservicepb.TabletManagerClient, closer io.Closer, found bool, err error) { + if locker != nil { + locker.Lock() + defer locker.Unlock() + } + + if conn, ok := dialer.conns[addr]; ok { + client, closer, err := dialer.redialLocked(conn) + return client, closer, ok, err + } + + return nil, nil, false, nil +} + +// pollOnce is called on each iteration of the polling loop in dial(). It: +// - locks the conns cache for writes +// - attempts to get a connection from the cache. If found, redial() it and exit. +// - peeks at the head of the eviction queue. if the peeked conn has no refs, it +// is unused, and can be evicted to make room for the new connection to addr. +// If the peeked conn has refs, exit. +// - pops the conn we just peeked from the queue, deletes it from the cache, and +// close the underlying ClientConn for that conn. +// - attempt a newdial. if the newdial fails, it will release a slot on the +// connWaitSema, so another dial() call can successfully acquire it to dial +// a new conn. if the newdial succeeds, we will have evicted one conn, but +// added another, so the net change is 0, and no changes to the connWaitSema +// are made. +// +// It returns a TabletManagerClient impl, an io.Closer, a flag to indicate +// whether the dial() poll loop should exit, and an error. +func (dialer *cachedConnDialer) pollOnce(ctx context.Context, addr string) (client tabletmanagerservicepb.TabletManagerClient, closer io.Closer, found bool, err error) { + dialer.m.Lock() + + if client, closer, found, err := dialer.tryFromCache(addr, nil); found { + dialer.m.Unlock() + return client, closer, found, err + } + + dialer.sortEvictionsLocked() + + conn := dialer.evict[0] + if conn.refs != 0 { + dialer.m.Unlock() + return nil, nil, false, nil + } + + dialer.evict = dialer.evict[1:] + delete(dialer.conns, conn.addr) + conn.cc.Close() + dialer.m.Unlock() + + client, closer, err = dialer.newdial(ctx, addr) + return client, closer, true, err +} + +// newdial creates a new cached connection, and updates the cache and eviction +// queue accordingly. If newdial fails to create the underlying +// gRPC connection, it will make a call to Release the connWaitSema for other +// newdial calls. +// +// It returns the three-tuple of client-interface, closer, and error that the +// main dial func returns. +func (dialer *cachedConnDialer) newdial(ctx context.Context, addr string) (tabletmanagerservicepb.TabletManagerClient, io.Closer, error) { + opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name) + if err != nil { + dialer.connWaitSema.Release() + return nil, nil, err + } + + cc, err := grpcclient.DialContext(ctx, addr, grpcclient.FailFast(false), opt) + if err != nil { + dialer.connWaitSema.Release() + return nil, nil, err + } + + dialer.m.Lock() + defer dialer.m.Unlock() + + if conn, existing := dialer.conns[addr]; existing { + // race condition: some other goroutine has dialed our tablet before we have; + // this is not great, but shouldn't happen often (if at all), so we're going to + // close this connection and reuse the existing one. by doing this, we can keep + // the actual Dial out of the global lock and significantly increase throughput + cc.Close() + dialer.connWaitSema.Release() + return dialer.redialLocked(conn) + } + + dialerStats.ConnNew.Add(1) + + conn := &cachedConn{ + TabletManagerClient: tabletmanagerservicepb.NewTabletManagerClient(cc), + cc: cc, + lastAccessTime: time.Now(), + refs: 1, + addr: addr, + } + + // NOTE: we deliberately do not set dialer.evictSorted=false here. Since + // cachedConns are evicted from the front of the queue, and we are appending + // to the end, if there is already a second evictable connection, it will be + // at the front of the queue, so we can speed up the edge case where we need + // to evict multiple connections in a row. + dialer.evict = append(dialer.evict, conn) + dialer.conns[addr] = conn + + return dialer.connWithCloser(conn) +} + +// redialLocked takes an already-dialed connection in the cache does all the +// work of lending that connection out to one more caller. It returns the +// three-tuple of client-interface, closer, and error that the main dial func +// returns. +func (dialer *cachedConnDialer) redialLocked(conn *cachedConn) (tabletmanagerservicepb.TabletManagerClient, io.Closer, error) { + dialerStats.ConnReuse.Add(1) + conn.lastAccessTime = time.Now() + conn.refs++ + dialer.evictSorted = false + return dialer.connWithCloser(conn) +} + +// connWithCloser returns the three-tuple expected by the main dial func, where +// the closer handles the correct state management for updating the conns place +// in the eviction queue. +func (dialer *cachedConnDialer) connWithCloser(conn *cachedConn) (tabletmanagerservicepb.TabletManagerClient, io.Closer, error) { + return conn, closeFunc(func() error { + dialer.m.Lock() + defer dialer.m.Unlock() + conn.refs-- + dialer.evictSorted = false + return nil + }), nil +} + +// Close closes all currently cached connections, ***regardless of whether +// those connections are in use***. Calling Close therefore will fail any RPCs +// using currently lent-out connections, and, furthermore, will invalidate the +// io.Closer that was returned for that connection from dialer.dial(). When +// calling those io.Closers, they will still lock the dialer's mutex, and then +// perform needless operations that will slow down dial throughput, but not +// actually impact the correctness of the internal state of the dialer. +// +// As a result, while it is safe to reuse a cachedConnDialer after calling Close, +// it will be less performant than getting a new one, either by calling +// tmclient.TabletManagerClient() with +// TabletManagerProtocol set to "grpc-cached", or by calling +// grpctmclient.NewCachedConnClient directly. +func (dialer *cachedConnDialer) Close() { + dialer.m.Lock() + defer dialer.m.Unlock() + + for _, conn := range dialer.evict { + conn.cc.Close() + delete(dialer.conns, conn.addr) + dialer.connWaitSema.Release() + } + dialer.evict = make([]*cachedConn, 0, dialer.capacity) +} + +func getTabletAddr(tablet *topodatapb.Tablet) string { + return netutil.JoinHostPort(tablet.Hostname, int32(tablet.PortMap["grpc"])) +} diff --git a/go/vt/vttablet/grpctmclient/cached_client_test.go b/go/vt/vttablet/grpctmclient/cached_client_test.go new file mode 100644 index 00000000000..096e0278150 --- /dev/null +++ b/go/vt/vttablet/grpctmclient/cached_client_test.go @@ -0,0 +1,444 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package grpctmclient + +import ( + "context" + "fmt" + "io" + "math/rand" + "net" + "runtime" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/nettest" + "google.golang.org/grpc" + + "vitess.io/vitess/go/sync2" + "vitess.io/vitess/go/vt/vttablet/grpctmserver" + "vitess.io/vitess/go/vt/vttablet/tabletmanager" + "vitess.io/vitess/go/vt/vttablet/tmrpctest" + + topodatapb "vitess.io/vitess/go/vt/proto/topodata" +) + +func grpcTestServer(t testing.TB, tm tabletmanager.RPCTM) (*net.TCPAddr, func()) { + t.Helper() + + lis, err := nettest.NewLocalListener("tcp") + if err != nil { + t.Fatalf("Cannot listen: %v", err) + } + + s := grpc.NewServer() + grpctmserver.RegisterForTest(s, tm) + go s.Serve(lis) + + var shutdownOnce sync.Once + + return lis.Addr().(*net.TCPAddr), func() { + shutdownOnce.Do(func() { + s.Stop() + lis.Close() + }) + } +} + +func BenchmarkCachedConnClientSteadyState(b *testing.B) { + tmserv := tmrpctest.NewFakeRPCTM(b) + tablets := make([]*topodatapb.Tablet, 1000) + for i := 0; i < len(tablets); i++ { + addr, shutdown := grpcTestServer(b, tmserv) + defer shutdown() + + tablets[i] = &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "test", + Uid: uint32(addr.Port), + }, + Hostname: addr.IP.String(), + PortMap: map[string]int32{ + "grpc": int32(addr.Port), + }, + } + } + + client := NewCachedConnClient(100) + defer client.Close() + + // fill the pool + for i := 0; i < 100; i++ { + err := client.Ping(context.Background(), tablets[i]) + require.NoError(b, err) + } + + procs := runtime.GOMAXPROCS(0) / 4 + if procs == 0 { + procs = 2 + } + + pingsPerProc := len(tablets) / procs + if pingsPerProc == 0 { + pingsPerProc = 2 + } + + b.ResetTimer() + + // Begin the benchmark + for i := 0; i < b.N; i++ { + ctx, cancel := context.WithCancel(context.Background()) + + var wg sync.WaitGroup + for j := 0; j < procs; j++ { + wg.Add(1) + go func() { + defer wg.Done() + + for k := 0; k < pingsPerProc; k++ { + func() { + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + x := rand.Intn(len(tablets)) + err := client.Ping(ctx, tablets[x]) + assert.NoError(b, err) + }() + } + }() + } + + wg.Wait() + cancel() + } +} + +func BenchmarkCachedConnClientSteadyStateRedials(b *testing.B) { + tmserv := tmrpctest.NewFakeRPCTM(b) + tablets := make([]*topodatapb.Tablet, 1000) + for i := 0; i < len(tablets); i++ { + addr, shutdown := grpcTestServer(b, tmserv) + defer shutdown() + + tablets[i] = &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "test", + Uid: uint32(addr.Port), + }, + Hostname: addr.IP.String(), + PortMap: map[string]int32{ + "grpc": int32(addr.Port), + }, + } + } + + client := NewCachedConnClient(1000) + defer client.Close() + + // fill the pool + for i := 0; i < 1000; i++ { + err := client.Ping(context.Background(), tablets[i]) + require.NoError(b, err) + } + + procs := runtime.GOMAXPROCS(0) / 4 + if procs == 0 { + procs = 2 + } + + pingsPerProc := len(tablets) / procs + if pingsPerProc == 0 { + pingsPerProc = 2 + } + + b.ResetTimer() + + // Begin the benchmark + for i := 0; i < b.N; i++ { + ctx, cancel := context.WithCancel(context.Background()) + + var wg sync.WaitGroup + for j := 0; j < procs; j++ { + wg.Add(1) + go func() { + defer wg.Done() + + for k := 0; k < pingsPerProc; k++ { + func() { + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + x := rand.Intn(len(tablets)) + err := client.Ping(ctx, tablets[x]) + assert.NoError(b, err) + }() + } + }() + } + + wg.Wait() + cancel() + } +} + +func BenchmarkCachedConnClientSteadyStateEvictions(b *testing.B) { + tmserv := tmrpctest.NewFakeRPCTM(b) + tablets := make([]*topodatapb.Tablet, 1000) + for i := 0; i < len(tablets); i++ { + addr, shutdown := grpcTestServer(b, tmserv) + defer shutdown() + + tablets[i] = &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "test", + Uid: uint32(addr.Port), + }, + Hostname: addr.IP.String(), + PortMap: map[string]int32{ + "grpc": int32(addr.Port), + }, + } + } + + client := NewCachedConnClient(100) + defer client.Close() + + // fill the pool + for i := 0; i < 100; i++ { + err := client.Ping(context.Background(), tablets[i]) + require.NoError(b, err) + } + + assert.Equal(b, len(client.dialer.(*cachedConnDialer).conns), 100) + + procs := runtime.GOMAXPROCS(0) / 4 + if procs == 0 { + procs = 2 + } + + start := 100 + b.ResetTimer() + + // Begin the benchmark + for i := 0; i < b.N; i++ { + ctx, cancel := context.WithCancel(context.Background()) + ch := make(chan int, 100) // 100 dials per iteration + + var wg sync.WaitGroup + for j := 0; j < procs; j++ { + wg.Add(1) + go func() { + defer wg.Done() + + for idx := range ch { + func() { + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + err := client.Ping(ctx, tablets[idx]) + assert.NoError(b, err) + }() + } + }() + } + + for j := 0; j < cap(ch); j++ { + start = (start + j) % 1000 // go in increasing order, wrapping around + ch <- start + } + + close(ch) + wg.Wait() + cancel() + } +} + +func TestCachedConnClient(t *testing.T) { + t.Parallel() + + testCtx, testCancel := context.WithCancel(context.Background()) + wg := sync.WaitGroup{} + procs := 0 + + wg.Add(1) + go func() { + defer wg.Done() + procs = runtime.NumGoroutine() + + for { + select { + case <-testCtx.Done(): + return + case <-time.After(time.Millisecond * 100): + newProcs := runtime.NumGoroutine() + if newProcs > procs { + procs = newProcs + } + } + } + }() + + numTablets := 100 + numGoroutines := 8 + + tmserv := tmrpctest.NewFakeRPCTM(t) + tablets := make([]*topodatapb.Tablet, numTablets) + for i := 0; i < len(tablets); i++ { + addr, shutdown := grpcTestServer(t, tmserv) + defer shutdown() + + tablets[i] = &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "test", + Uid: uint32(addr.Port), + }, + Hostname: addr.IP.String(), + PortMap: map[string]int32{ + "grpc": int32(addr.Port), + }, + } + } + + poolSize := int(float64(numTablets) * 0.5) + client := NewCachedConnClient(poolSize) + defer client.Close() + + dialAttempts := sync2.NewAtomicInt64(0) + dialErrors := sync2.NewAtomicInt64(0) + + longestDials := make(chan time.Duration, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + attempts := 0 + jitter := time.Second * 0 + longestDial := time.Duration(0) + + for { + select { + case <-testCtx.Done(): + dialAttempts.Add(int64(attempts)) + longestDials <- longestDial + return + case <-time.After(jitter): + jitter = time.Millisecond * (time.Duration(rand.Intn(11) + 50)) + attempts++ + + tablet := tablets[rand.Intn(len(tablets))] + start := time.Now() + _, closer, err := client.dialer.dial(context.Background(), tablet) + if err != nil { + dialErrors.Add(1) + continue + } + + dialDuration := time.Since(start) + if dialDuration > longestDial { + longestDial = dialDuration + } + + closer.Close() + } + } + }() + } + + time.Sleep(time.Minute) + testCancel() + wg.Wait() + close(longestDials) + + longestDial := time.Duration(0) + for dialDuration := range longestDials { + if dialDuration > longestDial { + longestDial = dialDuration + } + } + + attempts, errors := dialAttempts.Get(), dialErrors.Get() + assert.Less(t, float64(errors)/float64(attempts), 0.001, fmt.Sprintf("fewer than 0.1%% of dial attempts should fail (attempts = %d, errors = %d, max running procs = %d)", attempts, errors, procs)) + assert.Less(t, errors, int64(1), "at least one dial attempt failed (attempts = %d, errors = %d)", attempts, errors) + assert.Less(t, longestDial.Milliseconds(), int64(50)) +} + +func TestCachedConnClient_evictions(t *testing.T) { + tmserv := tmrpctest.NewFakeRPCTM(t) + tablets := make([]*topodatapb.Tablet, 5) + for i := 0; i < len(tablets); i++ { + addr, shutdown := grpcTestServer(t, tmserv) + defer shutdown() + + tablets[i] = &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "test", + Uid: uint32(addr.Port), + }, + Hostname: addr.IP.String(), + PortMap: map[string]int32{ + "grpc": int32(addr.Port), + }, + } + } + + testCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + connHoldContext, connHoldCancel := context.WithCancel(testCtx) + + client := NewCachedConnClient(len(tablets) - 1) + for i := 0; i < len(tablets)-1; i++ { + _, closer, err := client.dialer.dial(context.Background(), tablets[i]) + t.Logf("holding connection open to %d", tablets[i].Alias.Uid) + require.NoError(t, err) + + ctx := testCtx + if i == 0 { + ctx = connHoldContext + } + go func(ctx context.Context, closer io.Closer) { + // Hold on to one connection until the test is done. + // In the case of tablets[0], hold on to the connection until we + // signal to close it. + <-ctx.Done() + closer.Close() + }(ctx, closer) + } + + dialCtx, dialCancel := context.WithTimeout(testCtx, time.Millisecond*50) + defer dialCancel() + + err := client.Ping(dialCtx, tablets[0]) // this should take the rlock_fast path + assert.NoError(t, err, "could not redial on inuse cached connection") + + err = client.Ping(dialCtx, tablets[4]) // this will enter the poll loop until context timeout + assert.Error(t, err, "should have timed out waiting for an eviction, while all conns were held") + + // free up a connection + connHoldCancel() + + dialCtx, dialCancel = context.WithTimeout(testCtx, time.Millisecond*100) + defer dialCancel() + + err = client.Ping(dialCtx, tablets[4]) // this will enter the poll loop and evict a connection + assert.NoError(t, err, "should have evicted a conn and succeeded to dial") +} diff --git a/go/vt/vttablet/grpctmclient/client.go b/go/vt/vttablet/grpctmclient/client.go index ee7d7e7c004..c7dd4e3fcd5 100644 --- a/go/vt/vttablet/grpctmclient/client.go +++ b/go/vt/vttablet/grpctmclient/client.go @@ -19,6 +19,7 @@ package grpctmclient import ( "flag" "fmt" + "io" "sync" "time" @@ -54,6 +55,9 @@ func init() { tmclient.RegisterTabletManagerClientFactory("grpc", func() tmclient.TabletManagerClient { return NewClient() }) + tmclient.RegisterTabletManagerClientFactory("grpc-oneshot", func() tmclient.TabletManagerClient { + return NewClient() + }) } type tmc struct { @@ -61,8 +65,8 @@ type tmc struct { client tabletmanagerservicepb.TabletManagerClient } -// Client implements tmclient.TabletManagerClient -type Client struct { +// grpcClient implements both dialer and poolDialer. +type grpcClient struct { // This cache of connections is to maximize QPS for ExecuteFetch. // Note we'll keep the clients open and close them upon Close() only. // But that's OK because usually the tasks that use them are @@ -72,13 +76,40 @@ type Client struct { rpcClientMap map[string]chan *tmc } +type dialer interface { + dial(ctx context.Context, tablet *topodatapb.Tablet) (tabletmanagerservicepb.TabletManagerClient, io.Closer, error) + Close() +} + +type poolDialer interface { + dialPool(ctx context.Context, tablet *topodatapb.Tablet) (tabletmanagerservicepb.TabletManagerClient, error) +} + +// Client implements tmclient.TabletManagerClient. +// +// Connections are produced by the dialer implementation, which is either the +// grpcClient implementation, which reuses connections only for ExecuteFetch and +// otherwise makes single-purpose connections that are closed after use. +// +// In order to more efficiently use the underlying tcp connections, you can +// instead use the cachedConnDialer implementation by specifying +// -tablet_manager_protocol "grpc-cached" +// The cachedConnDialer keeps connections to up to -tablet_manager_grpc_connpool_size distinct +// tablets open at any given time, for faster per-RPC call time, and less +// connection churn. +type Client struct { + dialer dialer +} + // NewClient returns a new gRPC client. func NewClient() *Client { - return &Client{} + return &Client{ + dialer: &grpcClient{}, + } } // dial returns a client to use -func (client *Client) dial(tablet *topodatapb.Tablet) (*grpc.ClientConn, tabletmanagerservicepb.TabletManagerClient, error) { +func (client *grpcClient) dial(ctx context.Context, tablet *topodatapb.Tablet) (tabletmanagerservicepb.TabletManagerClient, io.Closer, error) { addr := netutil.JoinHostPort(tablet.Hostname, int32(tablet.PortMap["grpc"])) opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name) if err != nil { @@ -88,10 +119,11 @@ func (client *Client) dial(tablet *topodatapb.Tablet) (*grpc.ClientConn, tabletm if err != nil { return nil, nil, err } - return cc, tabletmanagerservicepb.NewTabletManagerClient(cc), nil + + return tabletmanagerservicepb.NewTabletManagerClient(cc), cc, nil } -func (client *Client) dialPool(tablet *topodatapb.Tablet) (tabletmanagerservicepb.TabletManagerClient, error) { +func (client *grpcClient) dialPool(ctx context.Context, tablet *topodatapb.Tablet) (tabletmanagerservicepb.TabletManagerClient, error) { addr := netutil.JoinHostPort(tablet.Hostname, int32(tablet.PortMap["grpc"])) opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name) if err != nil { @@ -127,17 +159,30 @@ func (client *Client) dialPool(tablet *topodatapb.Tablet) (tabletmanagerservicep return result.client, nil } +// Close is part of the tmclient.TabletManagerClient interface. +func (client *grpcClient) Close() { + client.mu.Lock() + defer client.mu.Unlock() + for _, c := range client.rpcClientMap { + close(c) + for ch := range c { + ch.cc.Close() + } + } + client.rpcClientMap = nil +} + // // Various read-only methods // // Ping is part of the tmclient.TabletManagerClient interface. func (client *Client) Ping(ctx context.Context, tablet *topodatapb.Tablet) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() result, err := c.Ping(ctx, &tabletmanagerdatapb.PingRequest{ Payload: "payload", }) @@ -152,11 +197,11 @@ func (client *Client) Ping(ctx context.Context, tablet *topodatapb.Tablet) error // Sleep is part of the tmclient.TabletManagerClient interface. func (client *Client) Sleep(ctx context.Context, tablet *topodatapb.Tablet, duration time.Duration) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.Sleep(ctx, &tabletmanagerdatapb.SleepRequest{ Duration: int64(duration), }) @@ -165,11 +210,11 @@ func (client *Client) Sleep(ctx context.Context, tablet *topodatapb.Tablet, dura // ExecuteHook is part of the tmclient.TabletManagerClient interface. func (client *Client) ExecuteHook(ctx context.Context, tablet *topodatapb.Tablet, hk *hook.Hook) (*hook.HookResult, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, err } - defer cc.Close() + defer closer.Close() hr, err := c.ExecuteHook(ctx, &tabletmanagerdatapb.ExecuteHookRequest{ Name: hk.Name, Parameters: hk.Parameters, @@ -187,11 +232,11 @@ func (client *Client) ExecuteHook(ctx context.Context, tablet *topodatapb.Tablet // GetSchema is part of the tmclient.TabletManagerClient interface. func (client *Client) GetSchema(ctx context.Context, tablet *topodatapb.Tablet, tables, excludeTables []string, includeViews bool) (*tabletmanagerdatapb.SchemaDefinition, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, err } - defer cc.Close() + defer closer.Close() response, err := c.GetSchema(ctx, &tabletmanagerdatapb.GetSchemaRequest{ Tables: tables, ExcludeTables: excludeTables, @@ -205,11 +250,11 @@ func (client *Client) GetSchema(ctx context.Context, tablet *topodatapb.Tablet, // GetPermissions is part of the tmclient.TabletManagerClient interface. func (client *Client) GetPermissions(ctx context.Context, tablet *topodatapb.Tablet) (*tabletmanagerdatapb.Permissions, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, err } - defer cc.Close() + defer closer.Close() response, err := c.GetPermissions(ctx, &tabletmanagerdatapb.GetPermissionsRequest{}) if err != nil { return nil, err @@ -223,33 +268,33 @@ func (client *Client) GetPermissions(ctx context.Context, tablet *topodatapb.Tab // SetReadOnly is part of the tmclient.TabletManagerClient interface. func (client *Client) SetReadOnly(ctx context.Context, tablet *topodatapb.Tablet) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.SetReadOnly(ctx, &tabletmanagerdatapb.SetReadOnlyRequest{}) return err } // SetReadWrite is part of the tmclient.TabletManagerClient interface. func (client *Client) SetReadWrite(ctx context.Context, tablet *topodatapb.Tablet) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.SetReadWrite(ctx, &tabletmanagerdatapb.SetReadWriteRequest{}) return err } // ChangeType is part of the tmclient.TabletManagerClient interface. func (client *Client) ChangeType(ctx context.Context, tablet *topodatapb.Tablet, dbType topodatapb.TabletType) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.ChangeType(ctx, &tabletmanagerdatapb.ChangeTypeRequest{ TabletType: dbType, }) @@ -258,33 +303,33 @@ func (client *Client) ChangeType(ctx context.Context, tablet *topodatapb.Tablet, // RefreshState is part of the tmclient.TabletManagerClient interface. func (client *Client) RefreshState(ctx context.Context, tablet *topodatapb.Tablet) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.RefreshState(ctx, &tabletmanagerdatapb.RefreshStateRequest{}) return err } // RunHealthCheck is part of the tmclient.TabletManagerClient interface. func (client *Client) RunHealthCheck(ctx context.Context, tablet *topodatapb.Tablet) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.RunHealthCheck(ctx, &tabletmanagerdatapb.RunHealthCheckRequest{}) return err } // IgnoreHealthError is part of the tmclient.TabletManagerClient interface. func (client *Client) IgnoreHealthError(ctx context.Context, tablet *topodatapb.Tablet, pattern string) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.IgnoreHealthError(ctx, &tabletmanagerdatapb.IgnoreHealthErrorRequest{ Pattern: pattern, }) @@ -293,11 +338,11 @@ func (client *Client) IgnoreHealthError(ctx context.Context, tablet *topodatapb. // ReloadSchema is part of the tmclient.TabletManagerClient interface. func (client *Client) ReloadSchema(ctx context.Context, tablet *topodatapb.Tablet, waitPosition string) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.ReloadSchema(ctx, &tabletmanagerdatapb.ReloadSchemaRequest{ WaitPosition: waitPosition, }) @@ -306,11 +351,11 @@ func (client *Client) ReloadSchema(ctx context.Context, tablet *topodatapb.Table // PreflightSchema is part of the tmclient.TabletManagerClient interface. func (client *Client) PreflightSchema(ctx context.Context, tablet *topodatapb.Tablet, changes []string) ([]*tabletmanagerdatapb.SchemaChangeResult, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, err } - defer cc.Close() + defer closer.Close() response, err := c.PreflightSchema(ctx, &tabletmanagerdatapb.PreflightSchemaRequest{ Changes: changes, @@ -324,11 +369,11 @@ func (client *Client) PreflightSchema(ctx context.Context, tablet *topodatapb.Ta // ApplySchema is part of the tmclient.TabletManagerClient interface. func (client *Client) ApplySchema(ctx context.Context, tablet *topodatapb.Tablet, change *tmutils.SchemaChange) (*tabletmanagerdatapb.SchemaChangeResult, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, err } - defer cc.Close() + defer closer.Close() response, err := c.ApplySchema(ctx, &tabletmanagerdatapb.ApplySchemaRequest{ Sql: change.SQL, Force: change.Force, @@ -347,11 +392,11 @@ func (client *Client) ApplySchema(ctx context.Context, tablet *topodatapb.Tablet // LockTables is part of the tmclient.TabletManagerClient interface. func (client *Client) LockTables(ctx context.Context, tablet *topodatapb.Tablet) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.LockTables(ctx, &tabletmanagerdatapb.LockTablesRequest{}) return err @@ -359,11 +404,11 @@ func (client *Client) LockTables(ctx context.Context, tablet *topodatapb.Tablet) // UnlockTables is part of the tmclient.TabletManagerClient interface. func (client *Client) UnlockTables(ctx context.Context, tablet *topodatapb.Tablet) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.UnlockTables(ctx, &tabletmanagerdatapb.UnlockTablesRequest{}) return err @@ -371,11 +416,11 @@ func (client *Client) UnlockTables(ctx context.Context, tablet *topodatapb.Table // ExecuteQuery is part of the tmclient.TabletManagerClient interface. func (client *Client) ExecuteQuery(ctx context.Context, tablet *topodatapb.Tablet, query []byte, maxrows int) (*querypb.QueryResult, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, err } - defer cc.Close() + defer closer.Close() response, err := c.ExecuteQuery(ctx, &tabletmanagerdatapb.ExecuteQueryRequest{ Query: query, @@ -393,17 +438,21 @@ func (client *Client) ExecuteFetchAsDba(ctx context.Context, tablet *topodatapb. var c tabletmanagerservicepb.TabletManagerClient var err error if usePool { - c, err = client.dialPool(tablet) - if err != nil { - return nil, err + if poolDialer, ok := client.dialer.(poolDialer); ok { + c, err = poolDialer.dialPool(ctx, tablet) + if err != nil { + return nil, err + } } - } else { - var cc *grpc.ClientConn - cc, c, err = client.dial(tablet) + } + + if !usePool || c == nil { + var closer io.Closer + c, closer, err = client.dialer.dial(ctx, tablet) if err != nil { return nil, err } - defer cc.Close() + defer closer.Close() } response, err := c.ExecuteFetchAsDba(ctx, &tabletmanagerdatapb.ExecuteFetchAsDbaRequest{ @@ -421,14 +470,11 @@ func (client *Client) ExecuteFetchAsDba(ctx context.Context, tablet *topodatapb. // ExecuteFetchAsAllPrivs is part of the tmclient.TabletManagerClient interface. func (client *Client) ExecuteFetchAsAllPrivs(ctx context.Context, tablet *topodatapb.Tablet, query []byte, maxRows int, reloadSchema bool) (*querypb.QueryResult, error) { - var c tabletmanagerservicepb.TabletManagerClient - var err error - var cc *grpc.ClientConn - cc, c, err = client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, err } - defer cc.Close() + defer closer.Close() response, err := c.ExecuteFetchAsAllPrivs(ctx, &tabletmanagerdatapb.ExecuteFetchAsAllPrivsRequest{ Query: query, @@ -447,17 +493,21 @@ func (client *Client) ExecuteFetchAsApp(ctx context.Context, tablet *topodatapb. var c tabletmanagerservicepb.TabletManagerClient var err error if usePool { - c, err = client.dialPool(tablet) - if err != nil { - return nil, err + if poolDialer, ok := client.dialer.(poolDialer); ok { + c, err = poolDialer.dialPool(ctx, tablet) + if err != nil { + return nil, err + } } - } else { - var cc *grpc.ClientConn - cc, c, err = client.dial(tablet) + } + + if !usePool || c == nil { + var closer io.Closer + c, closer, err = client.dialer.dial(ctx, tablet) if err != nil { return nil, err } - defer cc.Close() + defer closer.Close() } response, err := c.ExecuteFetchAsApp(ctx, &tabletmanagerdatapb.ExecuteFetchAsAppRequest{ @@ -476,11 +526,11 @@ func (client *Client) ExecuteFetchAsApp(ctx context.Context, tablet *topodatapb. // ReplicationStatus is part of the tmclient.TabletManagerClient interface. func (client *Client) ReplicationStatus(ctx context.Context, tablet *topodatapb.Tablet) (*replicationdatapb.Status, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, err } - defer cc.Close() + defer closer.Close() response, err := c.ReplicationStatus(ctx, &tabletmanagerdatapb.ReplicationStatusRequest{}) if err != nil { return nil, err @@ -490,11 +540,11 @@ func (client *Client) ReplicationStatus(ctx context.Context, tablet *topodatapb. // MasterStatus is part of the tmclient.TabletManagerClient interface. func (client *Client) MasterStatus(ctx context.Context, tablet *topodatapb.Tablet) (*replicationdatapb.MasterStatus, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, err } - defer cc.Close() + defer closer.Close() response, err := c.MasterStatus(ctx, &tabletmanagerdatapb.MasterStatusRequest{}) if err != nil { return nil, err @@ -504,11 +554,11 @@ func (client *Client) MasterStatus(ctx context.Context, tablet *topodatapb.Table // MasterPosition is part of the tmclient.TabletManagerClient interface. func (client *Client) MasterPosition(ctx context.Context, tablet *topodatapb.Tablet) (string, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return "", err } - defer cc.Close() + defer closer.Close() response, err := c.MasterPosition(ctx, &tabletmanagerdatapb.MasterPositionRequest{}) if err != nil { return "", err @@ -518,33 +568,34 @@ func (client *Client) MasterPosition(ctx context.Context, tablet *topodatapb.Tab // WaitForPosition is part of the tmclient.TabletManagerClient interface. func (client *Client) WaitForPosition(ctx context.Context, tablet *topodatapb.Tablet, pos string) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.WaitForPosition(ctx, &tabletmanagerdatapb.WaitForPositionRequest{Position: pos}) return err } // StopReplication is part of the tmclient.TabletManagerClient interface. func (client *Client) StopReplication(ctx context.Context, tablet *topodatapb.Tablet) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.StopReplication(ctx, &tabletmanagerdatapb.StopReplicationRequest{}) return err } // StopReplicationMinimum is part of the tmclient.TabletManagerClient interface. func (client *Client) StopReplicationMinimum(ctx context.Context, tablet *topodatapb.Tablet, minPos string, waitTime time.Duration) (string, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return "", err } - defer cc.Close() + defer closer.Close() + response, err := c.StopReplicationMinimum(ctx, &tabletmanagerdatapb.StopReplicationMinimumRequest{ Position: minPos, WaitTimeout: int64(waitTime), @@ -557,22 +608,22 @@ func (client *Client) StopReplicationMinimum(ctx context.Context, tablet *topoda // StartReplication is part of the tmclient.TabletManagerClient interface. func (client *Client) StartReplication(ctx context.Context, tablet *topodatapb.Tablet) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.StartReplication(ctx, &tabletmanagerdatapb.StartReplicationRequest{}) return err } // StartReplicationUntilAfter is part of the tmclient.TabletManagerClient interface. func (client *Client) StartReplicationUntilAfter(ctx context.Context, tablet *topodatapb.Tablet, position string, waitTime time.Duration) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.StartReplicationUntilAfter(ctx, &tabletmanagerdatapb.StartReplicationUntilAfterRequest{ Position: position, WaitTimeout: int64(waitTime), @@ -582,11 +633,11 @@ func (client *Client) StartReplicationUntilAfter(ctx context.Context, tablet *to // GetReplicas is part of the tmclient.TabletManagerClient interface. func (client *Client) GetReplicas(ctx context.Context, tablet *topodatapb.Tablet) ([]string, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, err } - defer cc.Close() + defer closer.Close() response, err := c.GetReplicas(ctx, &tabletmanagerdatapb.GetReplicasRequest{}) if err != nil { return nil, err @@ -596,11 +647,11 @@ func (client *Client) GetReplicas(ctx context.Context, tablet *topodatapb.Tablet // VExec is part of the tmclient.TabletManagerClient interface. func (client *Client) VExec(ctx context.Context, tablet *topodatapb.Tablet, query, workflow, keyspace string) (*querypb.QueryResult, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, err } - defer cc.Close() + defer closer.Close() response, err := c.VExec(ctx, &tabletmanagerdatapb.VExecRequest{Query: query, Workflow: workflow, Keyspace: keyspace}) if err != nil { return nil, err @@ -610,11 +661,11 @@ func (client *Client) VExec(ctx context.Context, tablet *topodatapb.Tablet, quer // VReplicationExec is part of the tmclient.TabletManagerClient interface. func (client *Client) VReplicationExec(ctx context.Context, tablet *topodatapb.Tablet, query string) (*querypb.QueryResult, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, err } - defer cc.Close() + defer closer.Close() response, err := c.VReplicationExec(ctx, &tabletmanagerdatapb.VReplicationExecRequest{Query: query}) if err != nil { return nil, err @@ -624,11 +675,11 @@ func (client *Client) VReplicationExec(ctx context.Context, tablet *topodatapb.T // VReplicationWaitForPos is part of the tmclient.TabletManagerClient interface. func (client *Client) VReplicationWaitForPos(ctx context.Context, tablet *topodatapb.Tablet, id int, pos string) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() if _, err = c.VReplicationWaitForPos(ctx, &tabletmanagerdatapb.VReplicationWaitForPosRequest{Id: int64(id), Position: pos}); err != nil { return err } @@ -641,22 +692,23 @@ func (client *Client) VReplicationWaitForPos(ctx context.Context, tablet *topoda // ResetReplication is part of the tmclient.TabletManagerClient interface. func (client *Client) ResetReplication(ctx context.Context, tablet *topodatapb.Tablet) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.ResetReplication(ctx, &tabletmanagerdatapb.ResetReplicationRequest{}) return err } // InitMaster is part of the tmclient.TabletManagerClient interface. func (client *Client) InitMaster(ctx context.Context, tablet *topodatapb.Tablet) (string, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return "", err } - defer cc.Close() + defer closer.Close() + response, err := c.InitMaster(ctx, &tabletmanagerdatapb.InitMasterRequest{}) if err != nil { return "", err @@ -666,11 +718,11 @@ func (client *Client) InitMaster(ctx context.Context, tablet *topodatapb.Tablet) // PopulateReparentJournal is part of the tmclient.TabletManagerClient interface. func (client *Client) PopulateReparentJournal(ctx context.Context, tablet *topodatapb.Tablet, timeCreatedNS int64, actionName string, masterAlias *topodatapb.TabletAlias, pos string) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.PopulateReparentJournal(ctx, &tabletmanagerdatapb.PopulateReparentJournalRequest{ TimeCreatedNs: timeCreatedNS, ActionName: actionName, @@ -682,11 +734,11 @@ func (client *Client) PopulateReparentJournal(ctx context.Context, tablet *topod // InitReplica is part of the tmclient.TabletManagerClient interface. func (client *Client) InitReplica(ctx context.Context, tablet *topodatapb.Tablet, parent *topodatapb.TabletAlias, replicationPosition string, timeCreatedNS int64) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.InitReplica(ctx, &tabletmanagerdatapb.InitReplicaRequest{ Parent: parent, ReplicationPosition: replicationPosition, @@ -697,11 +749,11 @@ func (client *Client) InitReplica(ctx context.Context, tablet *topodatapb.Tablet // DemoteMaster is part of the tmclient.TabletManagerClient interface. func (client *Client) DemoteMaster(ctx context.Context, tablet *topodatapb.Tablet) (*replicationdatapb.MasterStatus, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, err } - defer cc.Close() + defer closer.Close() response, err := c.DemoteMaster(ctx, &tabletmanagerdatapb.DemoteMasterRequest{}) if err != nil { return nil, err @@ -719,33 +771,33 @@ func (client *Client) DemoteMaster(ctx context.Context, tablet *topodatapb.Table // UndoDemoteMaster is part of the tmclient.TabletManagerClient interface. func (client *Client) UndoDemoteMaster(ctx context.Context, tablet *topodatapb.Tablet) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.UndoDemoteMaster(ctx, &tabletmanagerdatapb.UndoDemoteMasterRequest{}) return err } // ReplicaWasPromoted is part of the tmclient.TabletManagerClient interface. func (client *Client) ReplicaWasPromoted(ctx context.Context, tablet *topodatapb.Tablet) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.ReplicaWasPromoted(ctx, &tabletmanagerdatapb.ReplicaWasPromotedRequest{}) return err } // SetMaster is part of the tmclient.TabletManagerClient interface. func (client *Client) SetMaster(ctx context.Context, tablet *topodatapb.Tablet, parent *topodatapb.TabletAlias, timeCreatedNS int64, waitPosition string, forceStartReplication bool) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.SetMaster(ctx, &tabletmanagerdatapb.SetMasterRequest{ Parent: parent, TimeCreatedNs: timeCreatedNS, @@ -757,11 +809,11 @@ func (client *Client) SetMaster(ctx context.Context, tablet *topodatapb.Tablet, // ReplicaWasRestarted is part of the tmclient.TabletManagerClient interface. func (client *Client) ReplicaWasRestarted(ctx context.Context, tablet *topodatapb.Tablet, parent *topodatapb.TabletAlias) error { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return err } - defer cc.Close() + defer closer.Close() _, err = c.ReplicaWasRestarted(ctx, &tabletmanagerdatapb.ReplicaWasRestartedRequest{ Parent: parent, }) @@ -770,11 +822,11 @@ func (client *Client) ReplicaWasRestarted(ctx context.Context, tablet *topodatap // StopReplicationAndGetStatus is part of the tmclient.TabletManagerClient interface. func (client *Client) StopReplicationAndGetStatus(ctx context.Context, tablet *topodatapb.Tablet, stopReplicationMode replicationdatapb.StopReplicationMode) (hybridStatus *replicationdatapb.Status, status *replicationdatapb.StopReplicationStatus, err error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, nil, err } - defer cc.Close() + defer closer.Close() response, err := c.StopReplicationAndGetStatus(ctx, &tabletmanagerdatapb.StopReplicationAndGetStatusRequest{ StopReplicationMode: stopReplicationMode, }) @@ -789,11 +841,12 @@ func (client *Client) StopReplicationAndGetStatus(ctx context.Context, tablet *t // PromoteReplica is part of the tmclient.TabletManagerClient interface. func (client *Client) PromoteReplica(ctx context.Context, tablet *topodatapb.Tablet) (string, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return "", err } - defer cc.Close() + defer closer.Close() + response, err := c.PromoteReplica(ctx, &tabletmanagerdatapb.PromoteReplicaRequest{}) if err != nil { return "", err @@ -806,13 +859,13 @@ func (client *Client) PromoteReplica(ctx context.Context, tablet *topodatapb.Tab // type backupStreamAdapter struct { stream tabletmanagerservicepb.TabletManager_BackupClient - cc *grpc.ClientConn + closer io.Closer } func (e *backupStreamAdapter) Recv() (*logutilpb.Event, error) { br, err := e.stream.Recv() if err != nil { - e.cc.Close() + e.closer.Close() return nil, err } return br.Event, nil @@ -820,7 +873,7 @@ func (e *backupStreamAdapter) Recv() (*logutilpb.Event, error) { // Backup is part of the tmclient.TabletManagerClient interface. func (client *Client) Backup(ctx context.Context, tablet *topodatapb.Tablet, concurrency int, allowMaster bool) (logutil.EventStream, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, err } @@ -830,24 +883,24 @@ func (client *Client) Backup(ctx context.Context, tablet *topodatapb.Tablet, con AllowMaster: bool(allowMaster), }) if err != nil { - cc.Close() + closer.Close() return nil, err } return &backupStreamAdapter{ stream: stream, - cc: cc, + closer: closer, }, nil } type restoreFromBackupStreamAdapter struct { stream tabletmanagerservicepb.TabletManager_RestoreFromBackupClient - cc *grpc.ClientConn + closer io.Closer } func (e *restoreFromBackupStreamAdapter) Recv() (*logutilpb.Event, error) { br, err := e.stream.Recv() if err != nil { - e.cc.Close() + e.closer.Close() return nil, err } return br.Event, nil @@ -855,31 +908,23 @@ func (e *restoreFromBackupStreamAdapter) Recv() (*logutilpb.Event, error) { // RestoreFromBackup is part of the tmclient.TabletManagerClient interface. func (client *Client) RestoreFromBackup(ctx context.Context, tablet *topodatapb.Tablet) (logutil.EventStream, error) { - cc, c, err := client.dial(tablet) + c, closer, err := client.dialer.dial(ctx, tablet) if err != nil { return nil, err } stream, err := c.RestoreFromBackup(ctx, &tabletmanagerdatapb.RestoreFromBackupRequest{}) if err != nil { - cc.Close() + closer.Close() return nil, err } return &restoreFromBackupStreamAdapter{ stream: stream, - cc: cc, + closer: closer, }, nil } // Close is part of the tmclient.TabletManagerClient interface. func (client *Client) Close() { - client.mu.Lock() - defer client.mu.Unlock() - for _, c := range client.rpcClientMap { - close(c) - for ch := range c { - ch.cc.Close() - } - } - client.rpcClientMap = nil + client.dialer.Close() } diff --git a/go/vt/vttablet/grpctmserver/server.go b/go/vt/vttablet/grpctmserver/server.go index b76cc1e2365..9ba33a20893 100644 --- a/go/vt/vttablet/grpctmserver/server.go +++ b/go/vt/vttablet/grpctmserver/server.go @@ -502,6 +502,6 @@ func init() { } // RegisterForTest will register the RPC, to be used by test instances only -func RegisterForTest(s *grpc.Server, tm *tabletmanager.TabletManager) { +func RegisterForTest(s *grpc.Server, tm tabletmanager.RPCTM) { tabletmanagerservicepb.RegisterTabletManagerServer(s, &server{tm: tm}) } diff --git a/go/vt/vttablet/grpctmserver/server_test.go b/go/vt/vttablet/grpctmserver/server_test.go index e87022764a1..e8be1088604 100644 --- a/go/vt/vttablet/grpctmserver/server_test.go +++ b/go/vt/vttablet/grpctmserver/server_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package grpctmserver +package grpctmserver_test import ( "net" @@ -23,9 +23,9 @@ import ( "google.golang.org/grpc" "vitess.io/vitess/go/vt/vttablet/grpctmclient" + "vitess.io/vitess/go/vt/vttablet/grpctmserver" "vitess.io/vitess/go/vt/vttablet/tmrpctest" - tabletmanagerservicepb "vitess.io/vitess/go/vt/proto/tabletmanagerservice" topodatapb "vitess.io/vitess/go/vt/proto/topodata" ) @@ -43,7 +43,7 @@ func TestGRPCTMServer(t *testing.T) { // Create a gRPC server and listen on the port. s := grpc.NewServer() fakeTM := tmrpctest.NewFakeRPCTM(t) - tabletmanagerservicepb.RegisterTabletManagerServer(s, &server{tm: fakeTM}) + grpctmserver.RegisterForTest(s, fakeTM) go s.Serve(listener) // Create a gRPC client to talk to the fake tablet. diff --git a/go/vt/vttablet/tmrpctest/test_tm_rpc.go b/go/vt/vttablet/tmrpctest/test_tm_rpc.go index 89fb39de3b6..99609d4488e 100644 --- a/go/vt/vttablet/tmrpctest/test_tm_rpc.go +++ b/go/vt/vttablet/tmrpctest/test_tm_rpc.go @@ -45,7 +45,7 @@ import ( // fakeRPCTM implements tabletmanager.RPCTM and fills in all // possible values in all APIs type fakeRPCTM struct { - t *testing.T + t testing.TB panics bool // slow if true will let Ping() sleep and effectively not respond to an RPC. slow bool @@ -68,7 +68,7 @@ func (fra *fakeRPCTM) setSlow(slow bool) { } // NewFakeRPCTM returns a fake tabletmanager.RPCTM that's just a mirror. -func NewFakeRPCTM(t *testing.T) tabletmanager.RPCTM { +func NewFakeRPCTM(t testing.TB) tabletmanager.RPCTM { return &fakeRPCTM{ t: t, } @@ -83,7 +83,7 @@ func NewFakeRPCTM(t *testing.T) tabletmanager.RPCTM { var protoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem() -func compare(t *testing.T, name string, got, want interface{}) { +func compare(t testing.TB, name string, got, want interface{}) { t.Helper() typ := reflect.TypeOf(got) if reflect.TypeOf(got) != reflect.TypeOf(want) { @@ -114,7 +114,7 @@ fail: t.Errorf("Unexpected %v:\ngot %#v\nwant %#v", name, got, want) } -func compareBool(t *testing.T, name string, got bool) { +func compareBool(t testing.TB, name string, got bool) { t.Helper() if !got { t.Errorf("Unexpected %v: got false expected true", name) @@ -200,7 +200,7 @@ func tmRPCTestPingPanic(ctx context.Context, t *testing.T, client tmclient.Table // tmRPCTestDialExpiredContext verifies that // the context returns the right DeadlineExceeded Err() for // RPCs failed due to an expired context before .Dial(). -func tmRPCTestDialExpiredContext(ctx context.Context, t *testing.T, client tmclient.TabletManagerClient, tablet *topodatapb.Tablet) { +func tmRPCTestDialExpiredContext(ctx context.Context, t testing.TB, client tmclient.TabletManagerClient, tablet *topodatapb.Tablet) { // Using a timeout of 0 here such that .Dial() will fail immediately. expiredCtx, cancel := context.WithTimeout(ctx, 0) defer cancel()