diff --git a/domain/domain.go b/domain/domain.go index b170dc0b063af..dddd5384292b8 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -993,6 +993,7 @@ func (do *Domain) Close() { logutil.BgLogger().Warn("fail to wait until the ttl job manager stop", zap.Error(err)) } } + do.releaseServerID(context.Background()) close(do.exit) if do.etcdClient != nil { terror.Log(errors.Trace(do.etcdClient.Close())) @@ -3007,6 +3008,25 @@ func (do *Domain) acquireServerID(ctx context.Context) error { } } +func (do *Domain) releaseServerID(ctx context.Context) { + serverID := do.ServerID() + if serverID == 0 { + return + } + atomic.StoreUint64(&do.serverID, 0) + + if do.etcdClient == nil { + return + } + key := fmt.Sprintf("%s/%v", serverIDEtcdPath, serverID) + err := ddlutil.DeleteKeyFromEtcd(key, do.etcdClient, refreshServerIDRetryCnt, acquireServerIDTimeout) + if err != nil { + logutil.BgLogger().Error("releaseServerID fail", zap.Uint64("serverID", serverID), zap.Error(err)) + } else { + logutil.BgLogger().Info("releaseServerID succeed", zap.Uint64("serverID", serverID)) + } +} + // propose server ID by random. func (do *Domain) proposeServerID(ctx context.Context, conflictCnt int) (uint64, error) { // get a random server ID in range [min, max] @@ -3020,7 +3040,8 @@ func (do *Domain) proposeServerID(ctx context.Context, conflictCnt int) (uint64, if err != nil { return 0, errors.Trace(err) } - if float32(len(allServerInfo)) < 0.9*globalconn.MaxServerID32 { + // `allServerInfo` contains current TiDB. 
+ if float32(len(allServerInfo)) <= 0.9*float32(globalconn.MaxServerID32) { serverIDs := make(map[uint64]struct{}, len(allServerInfo)) for _, info := range allServerInfo { serverID := info.ServerIDGetter() @@ -3036,6 +3057,9 @@ func (do *Domain) proposeServerID(ctx context.Context, conflictCnt int) (uint64, } } } + logutil.BgLogger().Info("upgrade to 64 bits server ID due to used up", zap.Int("len(allServerInfo)", len(allServerInfo))) + } else { + logutil.BgLogger().Info("upgrade to 64 bits server ID due to conflict", zap.Int("conflictCnt", conflictCnt)) } // upgrade to 64 bits. diff --git a/server/conn.go b/server/conn.go index 45ecd4a18e7a6..7a1a43268aaa1 100644 --- a/server/conn.go +++ b/server/conn.go @@ -335,9 +335,14 @@ func (cc *clientConn) Close() error { return closeConn(cc, connections) } +// closeConn should be idempotent. +// It will be called on the same `clientConn` more than once to avoid connection leak. func closeConn(cc *clientConn, connections int) error { metrics.ConnGauge.Set(float64(connections)) - cc.server.dom.ReleaseConnID(cc.connectionID) + if cc.connectionID > 0 { + cc.server.dom.ReleaseConnID(cc.connectionID) + cc.connectionID = 0 + } if cc.bufReadConn != nil { err := cc.bufReadConn.Close() if err != nil { diff --git a/tests/globalkilltest/BUILD.bazel b/tests/globalkilltest/BUILD.bazel index 4240b8a771149..ed31eb10ef0f8 100644 --- a/tests/globalkilltest/BUILD.bazel +++ b/tests/globalkilltest/BUILD.bazel @@ -17,6 +17,7 @@ go_test( "@com_github_stretchr_testify//require", "@io_etcd_go_etcd_client_v3//:client", "@org_golang_google_grpc//:grpc", + "@org_golang_google_grpc//backoff", "@org_uber_go_zap//:zap", ], ) diff --git a/tests/globalkilltest/Makefile b/tests/globalkilltest/Makefile index ed6013c44bb2e..446681bc8cfa3 100644 --- a/tests/globalkilltest/Makefile +++ b/tests/globalkilltest/Makefile @@ -24,6 +24,10 @@ GLOBAL_KILL_TEST_SERVER_LDFLAGS += -X "github.com/pingcap/tidb/domain.ldflagServ GLOBAL_KILL_TEST_SERVER_LDFLAGS += -X 
"github.com/pingcap/tidb/domain.ldflagLostConnectionToPDTimeout=5" GLOBAL_KILL_TEST_SERVER_LDFLAGS += -X "github.com/pingcap/tidb/store.ldflagGetEtcdAddrsFromConfig=1" +GLOBAL_KILL_TEST_SERVER_LDFLAGS += -X "github.com/pingcap/tidb/util/globalconn.ldflagIsGlobalKillTest=1" +GLOBAL_KILL_TEST_SERVER_LDFLAGS += -X "github.com/pingcap/tidb/util/globalconn.ldflagServerIDBits32=2" +GLOBAL_KILL_TEST_SERVER_LDFLAGS += -X "github.com/pingcap/tidb/util/globalconn.ldflagLocalConnIDBits32=4" + .PHONY: server buildsucc default: server buildsucc diff --git a/tests/globalkilltest/global_kill_test.go b/tests/globalkilltest/global_kill_test.go index 263b34c3254b3..eaadbdedf85ed 100644 --- a/tests/globalkilltest/global_kill_test.go +++ b/tests/globalkilltest/global_kill_test.go @@ -34,6 +34,7 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" "google.golang.org/grpc" + "google.golang.org/grpc/backoff" ) var ( @@ -95,10 +96,32 @@ func createGlobalKillSuite(t *testing.T, enable32bits bool) *GlobalKillSuite { return s } +// Conn is wrapper of DB connection. 
+type Conn struct { + db *sql.DB + conn *sql.Conn + connID uint64 +} + +func (c *Conn) Close() { + c.conn.Close() + c.db.Close() +} + +func (c *Conn) mustBe32(t *testing.T) { + require.Lessf(t, c.connID, uint64(1<<32), "connID %x", c.connID) +} + +func (c *Conn) mustBe64(t *testing.T) { + require.Greaterf(t, c.connID, uint64(1<<32), "connID %x", c.connID) +} + func (s *GlobalKillSuite) connectPD() (cli *clientv3.Client, err error) { etcdLogCfg := zap.NewProductionConfig() etcdLogCfg.Level = zap.NewAtomicLevelAt(zap.ErrorLevel) wait := 250 * time.Millisecond + backoffConfig := backoff.DefaultConfig + backoffConfig.MaxDelay = 3 * time.Second for i := 0; i < 5; i++ { log.Info(fmt.Sprintf("trying to connect pd, attempt %d", i)) cli, err = clientv3.New(clientv3.Config{ @@ -107,7 +130,9 @@ func (s *GlobalKillSuite) connectPD() (cli *clientv3.Client, err error) { AutoSyncInterval: 30 * time.Second, DialTimeout: 5 * time.Second, DialOptions: []grpc.DialOption{ - grpc.WithBackoffMaxDelay(time.Second * 3), + grpc.WithConnectParams(grpc.ConnectParams{ + Backoff: backoffConfig, + }), }, }) if err == nil { @@ -179,22 +204,32 @@ func (s *GlobalKillSuite) startCluster() (err error) { } func (s *GlobalKillSuite) stopPD() (err error) { + if s.pdProc == nil { + log.Info("PD already killed") + return nil + } if err = s.pdProc.Process.Kill(); err != nil { return errors.Trace(err) } if err = s.pdProc.Wait(); err != nil && err.Error() != "signal: killed" { return errors.Trace(err) } + s.pdProc = nil return nil } func (s *GlobalKillSuite) stopTiKV() (err error) { + if s.tikvProc == nil { + log.Info("TiKV already killed") + return nil + } if err = s.tikvProc.Process.Kill(); err != nil { return errors.Trace(err) } if err = s.tikvProc.Wait(); err != nil && err.Error() != "signal: killed" { return errors.Trace(err) } + s.tikvProc = nil return nil } @@ -254,6 +289,12 @@ func (s *GlobalKillSuite) startTiDBWithPD(port int, statusPort int, pdPath strin return cmd, nil } +func (s 
*GlobalKillSuite) mustStartTiDBWithPD(t *testing.T, port int, statusPort int, pdPath string) *exec.Cmd { + cmd, err := s.startTiDBWithPD(port, statusPort, pdPath) + require.Nil(t, err) + return cmd +} + func (s *GlobalKillSuite) stopService(name string, cmd *exec.Cmd, graceful bool) (err error) { log.Info("stopping: " + cmd.String()) defer func() { @@ -336,6 +377,23 @@ func (s *GlobalKillSuite) connectTiDB(port int) (db *sql.DB, err error) { return db, nil } +func (s *GlobalKillSuite) mustConnectTiDB(t *testing.T, port int) Conn { + ctx := context.TODO() + + db, err := s.connectTiDB(port) + require.Nil(t, err) + + conn, err := db.Conn(ctx) + require.NoError(t, err) + + var connID uint64 + err = conn.QueryRowContext(ctx, "SELECT CONNECTION_ID()").Scan(&connID) + require.NoError(t, err) + + log.Info("connect to server ok", zap.Int("port", port), zap.Uint64("connID", connID)) + return Conn{db, conn, connID} +} + type sleepResult struct { elapsed time.Duration err error @@ -676,4 +734,114 @@ func doTestLostConnection(t *testing.T, enable32Bits bool) { } } -// TODO: test for upgrade 32 -> 64 & downgrade 64 -> 32 +func TestServerIDUpgradeAndDowngrade(t *testing.T) { + s := createGlobalKillSuite(t, true) + require.NoErrorf(t, s.pdErr, msgErrConnectPD, s.pdErr) + + connect := func(idx int) Conn { + return s.mustConnectTiDB(t, *tidbStartPort+idx) + } + + // MaxTiDB32 is determined by `github.com/pingcap/tidb/util/globalconn.ldflagServerIDBits32` + // See the ldflags in `Makefile`. + // Also see `Domain.proposeServerID`. + const MaxTiDB32 = 2 // floor((2^2 - 1) x 0.9) + const MaxTiDB64 = 2 + + // Start up MaxTiDB32 TiDB instances.
+ tidbs := make([]*exec.Cmd, MaxTiDB32*2) + defer func() { + for i := range tidbs { + if tidbs[i] != nil { + s.stopService(fmt.Sprintf("tidb%v", i), tidbs[i], true) + } + } + }() + { + for i := 0; i < MaxTiDB32; i++ { + tidbs[i] = s.mustStartTiDBWithPD(t, *tidbStartPort+i, *tidbStatusPort+i, *pdClientPath) + } + for i := 0; i < MaxTiDB32; i++ { + conn := connect(i) + conn.mustBe32(t) + conn.Close() + } + } + + // Upgrade to 64 bits due to ServerID used up. + { + for i := MaxTiDB32; i < MaxTiDB32+MaxTiDB64; i++ { + tidbs[i] = s.mustStartTiDBWithPD(t, *tidbStartPort+i, *tidbStatusPort+i, *pdClientPath) + } + for i := MaxTiDB32; i < MaxTiDB32+MaxTiDB64; i++ { + conn := connect(i) + conn.mustBe64(t) + conn.Close() + } + } + + // Close TiDBs to downgrade to 32 bits. + { + for i := MaxTiDB32 / 2; i < MaxTiDB32+MaxTiDB64; i++ { + s.stopService(fmt.Sprintf("tidb%v", i), tidbs[i], true) + tidbs[i] = nil + } + + dbIdx := MaxTiDB32 + MaxTiDB64 + tidb := s.mustStartTiDBWithPD(t, *tidbStartPort+dbIdx, *tidbStatusPort+dbIdx, *pdClientPath) + defer s.stopService(fmt.Sprintf("tidb%v", dbIdx), tidb, true) + conn := connect(dbIdx) + conn.mustBe32(t) + conn.Close() + } +} + +func TestConnIDUpgradeAndDowngrade(t *testing.T) { + s := createGlobalKillSuite(t, true) + require.NoErrorf(t, s.pdErr, msgErrConnectPD, s.pdErr) + + connect := func() Conn { + return s.mustConnectTiDB(t, *tidbStartPort) + } + + tidb := s.mustStartTiDBWithPD(t, *tidbStartPort, *tidbStatusPort, *pdClientPath) + defer s.stopService("tidb0", tidb, true) + + // MaxConn32 is determined by `github.com/pingcap/tidb/util/globalconn.ldflagLocalConnIDBits32` + // See the ldflags in `Makefile`. + // Also see `LockFreeCircularPool.Cap`. 
+ const MaxConn32 = 1<<4 - 1 + + conns32 := make(map[uint64]Conn) + defer func() { + for _, conn := range conns32 { + conn.Close() + } + }() + // 32 bits connection ID + for i := 0; i < MaxConn32; i++ { + conn := connect() + require.Lessf(t, conn.connID, uint64(1<<32), "connID %x", conn.connID) + conns32[conn.connID] = conn + } + // 32bits pool is full, should upgrade to 64 bits + for i := MaxConn32; i < MaxConn32*2; i++ { + conn := connect() + conn.mustBe64(t) + conn.Close() + } + + // Release more than half of 32 bits connections, should downgrade to 32 bits + count := MaxConn32/2 + 1 + for connID, conn := range conns32 { + conn.Close() + delete(conns32, connID) + count-- + if count == 0 { + break + } + } + conn := connect() + conn.mustBe32(t) + conn.Close() +} diff --git a/tests/globalkilltest/run-tests.sh b/tests/globalkilltest/run-tests.sh index 32216631d9624..bb8b1cbefaf98 100755 --- a/tests/globalkilltest/run-tests.sh +++ b/tests/globalkilltest/run-tests.sh @@ -81,6 +81,8 @@ done clean_cluster +# Run specified test case(s) by `-test.run` argument. +# E.g.: go_tests -test.run UpgradeAndDowngrade$ go_tests clean_cluster diff --git a/tests/globalkilltest/up.sh b/tests/globalkilltest/up.sh index 61747f6d91786..2518defe84710 100755 --- a/tests/globalkilltest/up.sh +++ b/tests/globalkilltest/up.sh @@ -26,5 +26,10 @@ cd ../.. TIDB_PATH=$(pwd) docker build -t globalkilltest -f tests/globalkilltest/Dockerfile . -docker run --name globalkilltest -it --rm -v $TIDB_PATH:/tidb globalkilltest /bin/bash -c \ + +# To see the logs, mount an additional volume to /tmp. E.g. 
-v $TIDB_PATH/tmp:/tmp +docker run --name globalkilltest -it --rm \ + -v $TIDB_PATH:/tidb \ + globalkilltest \ + /bin/bash -c \ 'git config --global --add safe.directory /tidb && cd /tidb/tests/globalkilltest && make && ./run-tests.sh' diff --git a/util/globalconn/globalconn.go b/util/globalconn/globalconn.go index 40e1a6cf1ce53..3089ac96127ec 100644 --- a/util/globalconn/globalconn.go +++ b/util/globalconn/globalconn.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "math" + "strconv" "github.com/ngaut/sync2" "github.com/pingcap/tidb/util/logutil" @@ -52,14 +53,18 @@ type GCID struct { Is64bits bool } -const ( +var ( + // ServerIDBits32 is the number of bits of serverID for 32bits global connection ID. + ServerIDBits32 uint = 11 // MaxServerID32 is maximum serverID for 32bits global connection ID. - MaxServerID32 = 1<<11 - 1 + MaxServerID32 uint64 = 1<