diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go
index 52840d1b9cd..ebdcab6968b 100644
--- a/server/cluster/cluster.go
+++ b/server/cluster/cluster.go
@@ -290,7 +290,7 @@ func (c *RaftCluster) LoadClusterInfo() (*RaftCluster, error) {
 	start = time.Now()
 
 	// used to load region from kv storage to cache storage.
-	if err := c.storage.LoadRegionsOnce(c.core.CheckAndPutRegion); err != nil {
+	if err := c.storage.LoadRegionsOnce(c.ctx, c.core.CheckAndPutRegion); err != nil {
 		return nil, err
 	}
 	log.Info("load regions",
diff --git a/server/core/region_storage.go b/server/core/region_storage.go
index d5ee5d546bb..8566e92d295 100644
--- a/server/core/region_storage.go
+++ b/server/core/region_storage.go
@@ -19,6 +19,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/pingcap/failpoint"
 	"github.com/pingcap/kvproto/pkg/metapb"
 	"github.com/pingcap/log"
 	"github.com/tikv/pd/pkg/encryption"
@@ -131,6 +132,7 @@ func deleteRegion(kv kv.Base, region *metapb.Region) error {
 }
 
 func loadRegions(
+	ctx context.Context,
 	kv kv.Base,
 	encryptionKeyManager *encryptionkm.KeyManager,
 	f func(region *RegionInfo) []*RegionInfo,
@@ -143,6 +145,10 @@
 	// a variable rangeLimit to work around.
 	rangeLimit := maxKVRangeLimit
 	for {
+		failpoint.Inject("slowLoadRegion", func() {
+			rangeLimit = 1
+			time.Sleep(time.Second)
+		})
 		startKey := regionPath(nextID)
 		_, res, err := kv.LoadRange(startKey, endKey, rangeLimit)
 		if err != nil {
@@ -151,6 +157,11 @@
 			}
 			return err
 		}
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+		}
 
 		for _, s := range res {
 			region := &metapb.Region{}
diff --git a/server/core/storage.go b/server/core/storage.go
index 1fc5f6fc212..d65c6610a69 100644
--- a/server/core/storage.go
+++ b/server/core/storage.go
@@ -14,6 +14,7 @@
 package core
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"math"
@@ -193,22 +194,22 @@ func (s *Storage) LoadRegion(regionID uint64, region *metapb.Region) (ok bool, e
 }
 
 // LoadRegions loads all regions from storage to RegionsInfo.
-func (s *Storage) LoadRegions(f func(region *RegionInfo) []*RegionInfo) error {
+func (s *Storage) LoadRegions(ctx context.Context, f func(region *RegionInfo) []*RegionInfo) error {
 	if atomic.LoadInt32(&s.useRegionStorage) > 0 {
-		return loadRegions(s.regionStorage, s.encryptionKeyManager, f)
+		return loadRegions(ctx, s.regionStorage, s.encryptionKeyManager, f)
 	}
-	return loadRegions(s.Base, s.encryptionKeyManager, f)
+	return loadRegions(ctx, s.Base, s.encryptionKeyManager, f)
 }
 
 // LoadRegionsOnce loads all regions from storage to RegionsInfo.Only load one time from regionStorage.
-func (s *Storage) LoadRegionsOnce(f func(region *RegionInfo) []*RegionInfo) error {
+func (s *Storage) LoadRegionsOnce(ctx context.Context, f func(region *RegionInfo) []*RegionInfo) error {
 	if atomic.LoadInt32(&s.useRegionStorage) == 0 {
-		return loadRegions(s.Base, s.encryptionKeyManager, f)
+		return loadRegions(ctx, s.Base, s.encryptionKeyManager, f)
 	}
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	if s.regionLoaded == 0 {
-		if err := loadRegions(s.regionStorage, s.encryptionKeyManager, f); err != nil {
+		if err := loadRegions(ctx, s.regionStorage, s.encryptionKeyManager, f); err != nil {
 			return err
 		}
 		s.regionLoaded = 1
diff --git a/server/core/storage_test.go b/server/core/storage_test.go
index d513da0d037..bc6855c423c 100644
--- a/server/core/storage_test.go
+++ b/server/core/storage_test.go
@@ -14,6 +14,7 @@
 package core
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"math"
@@ -143,7 +144,7 @@ func (s *testKVSuite) TestLoadRegions(c *C) {
 
 	n := 10
 	regions := mustSaveRegions(c, storage, n)
-	c.Assert(storage.LoadRegions(cache.SetRegion), IsNil)
+	c.Assert(storage.LoadRegions(context.Background(), cache.SetRegion), IsNil)
 	c.Assert(cache.GetRegionCount(), Equals, n)
 
 	for _, region := range cache.GetMetaRegions() {
@@ -157,7 +158,7 @@ func (s *testKVSuite) TestLoadRegionsToCache(c *C) {
 
 	n := 10
 	regions := mustSaveRegions(c, storage, n)
-	c.Assert(storage.LoadRegionsOnce(cache.SetRegion), IsNil)
+	c.Assert(storage.LoadRegionsOnce(context.Background(), cache.SetRegion), IsNil)
 	c.Assert(cache.GetRegionCount(), Equals, n)
 
 	for _, region := range cache.GetMetaRegions() {
@@ -166,7 +167,7 @@ func (s *testKVSuite) TestLoadRegionsToCache(c *C) {
 
 	n = 20
 	mustSaveRegions(c, storage, n)
-	c.Assert(storage.LoadRegionsOnce(cache.SetRegion), IsNil)
+	c.Assert(storage.LoadRegionsOnce(context.Background(), cache.SetRegion), IsNil)
 	c.Assert(cache.GetRegionCount(), Equals, n)
 }
 
@@ -176,7 +177,7 @@ func (s *testKVSuite) TestLoadRegionsExceedRangeLimit(c *C) {
 
 	n := 1000
 	regions := mustSaveRegions(c, storage, n)
-	c.Assert(storage.LoadRegions(cache.SetRegion), IsNil)
+	c.Assert(storage.LoadRegions(context.Background(), cache.SetRegion), IsNil)
 	c.Assert(cache.GetRegionCount(), Equals, n)
 	for _, region := range cache.GetMetaRegions() {
 		c.Assert(region, DeepEquals, regions[region.GetId()])
diff --git a/server/region_syncer/client.go b/server/region_syncer/client.go
index 2f425cf8f6b..58f516fd6aa 100644
--- a/server/region_syncer/client.go
+++ b/server/region_syncer/client.go
@@ -40,10 +40,6 @@ const (
 // StopSyncWithLeader stop to sync the region with leader.
 func (s *RegionSyncer) StopSyncWithLeader() {
 	s.reset()
-	s.mu.Lock()
-	close(s.mu.closed)
-	s.mu.closed = make(chan struct{})
-	s.mu.Unlock()
 	s.wg.Wait()
 }
 
@@ -51,19 +47,15 @@ func (s *RegionSyncer) reset() {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	if s.mu.regionSyncerCancel == nil {
-		return
+	if s.mu.clientCancel != nil {
+		s.mu.clientCancel()
 	}
-	s.mu.regionSyncerCancel()
-	s.mu.regionSyncerCancel, s.mu.regionSyncerCtx = nil, nil
+	s.mu.clientCancel, s.mu.clientCtx = nil, nil
 }
 
-func (s *RegionSyncer) establish(addr string) (*grpc.ClientConn, error) {
-	s.reset()
-	ctx, cancel := context.WithCancel(s.server.LoopContext())
+func (s *RegionSyncer) establish(ctx context.Context, addr string) (*grpc.ClientConn, error) {
 	tlsCfg, err := s.tlsConfig.ToTLSConfig()
 	if err != nil {
-		cancel()
 		return nil, err
 	}
 	cc, err := grpcutil.GetClientConn(
@@ -88,28 +80,16 @@ func (s *RegionSyncer) establish(addr string) (*grpc.ClientConn, error) {
 		grpc.WithBlock(),
 	)
 	if err != nil {
-		cancel()
 		return nil, errors.WithStack(err)
 	}
-
-	s.mu.Lock()
-	s.mu.regionSyncerCtx, s.mu.regionSyncerCancel = ctx, cancel
-	s.mu.Unlock()
 	return cc, nil
 }
 
-func (s *RegionSyncer) syncRegion(conn *grpc.ClientConn) (ClientStream, error) {
+func (s *RegionSyncer) syncRegion(ctx context.Context, conn *grpc.ClientConn) (ClientStream, error) {
 	cli := pdpb.NewPDClient(conn)
-	var ctx context.Context
-	s.mu.RLock()
-	ctx = s.mu.regionSyncerCtx
-	s.mu.RUnlock()
-	if ctx == nil {
-		return nil, errors.New("syncRegion failed due to regionSyncerCtx is nil")
-	}
 	syncStream, err := cli.SyncRegions(ctx)
 	if err != nil {
-		return syncStream, errs.ErrGRPCCreateStream.Wrap(err).FastGenWithCause()
+		return nil, errs.ErrGRPCCreateStream.Wrap(err).FastGenWithCause()
 	}
 	err = syncStream.Send(&pdpb.SyncRegionRequest{
 		Header:     &pdpb.RequestHeader{ClusterId: s.server.ClusterID()},
@@ -117,7 +97,7 @@ func (s *RegionSyncer) syncRegion(conn *grpc.ClientConn) (ClientStream, error) {
 		StartIndex: s.history.GetNextIndex(),
 	})
 	if err != nil {
-		return syncStream, errs.ErrGRPCSend.Wrap(err).FastGenWithCause()
+		return nil, errs.ErrGRPCSend.Wrap(err).FastGenWithCause()
 	}
 	return syncStream, nil
 }
@@ -128,15 +108,21 @@ var regionGuide = core.GenerateRegionGuideFunc(false)
 // StartSyncWithLeader starts to sync with leader.
 func (s *RegionSyncer) StartSyncWithLeader(addr string) {
 	s.wg.Add(1)
-	s.mu.RLock()
-	closed := s.mu.closed
-	s.mu.RUnlock()
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.mu.clientCtx, s.mu.clientCancel = context.WithCancel(s.server.LoopContext())
+	ctx := s.mu.clientCtx
+
 	go func() {
 		defer s.wg.Done()
 		// used to load region from kv storage to cache storage.
 		bc := s.server.GetBasicCluster()
 		storage := s.server.GetStorage()
-		err := storage.LoadRegionsOnce(bc.CheckAndPutRegion)
+		log.Info("region syncer start load region")
+		start := time.Now()
+		err := storage.LoadRegionsOnce(ctx, bc.CheckAndPutRegion)
+		log.Info("region syncer finished load region", zap.Duration("time-cost", time.Since(start)))
 		if err != nil {
 			log.Warn("failed to load regions.", errs.ZapError(err))
 		}
@@ -144,11 +130,11 @@ func (s *RegionSyncer) StartSyncWithLeader(addr string) {
 		var conn *grpc.ClientConn
 		for {
 			select {
-			case <-closed:
+			case <-ctx.Done():
 				return
 			default:
 			}
-			conn, err = s.establish(addr)
+			conn, err = s.establish(ctx, addr)
 			if err != nil {
 				log.Error("cannot establish connection with leader", zap.String("server", s.server.Name()), zap.String("leader", s.server.GetLeader().GetName()), errs.ZapError(err))
 				continue
@@ -160,12 +146,12 @@ func (s *RegionSyncer) StartSyncWithLeader(addr string) {
 			// Start syncing data.
 			for {
 				select {
-				case <-closed:
+				case <-ctx.Done():
 					return
 				default:
 				}
 
-				stream, err := s.syncRegion(conn)
+				stream, err := s.syncRegion(ctx, conn)
 				if err != nil {
 					if ev, ok := status.FromError(err); ok {
 						if ev.Code() == codes.Canceled {
diff --git a/server/region_syncer/client_test.go b/server/region_syncer/client_test.go
new file mode 100644
index 00000000000..1084aa45bbe
--- /dev/null
+++ b/server/region_syncer/client_test.go
@@ -0,0 +1,105 @@
+// Copyright 2021 TiKV Project Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncer
+
+import (
+	"context"
+	"io/ioutil"
+	"os"
+	"time"
+
+	. "github.com/pingcap/check"
"github.com/pingcap/check" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/tikv/pd/pkg/grpcutil" + "github.com/tikv/pd/server/core" + "github.com/tikv/pd/server/kv" +) + +var _ = Suite(&testClientSuite{}) + +type testClientSuite struct{} + +// For issue https://github.com/tikv/pd/issues/3936 +func (t *testClientSuite) TestLoadRegion(c *C) { + tempDir, err := ioutil.TempDir(os.TempDir(), "region_syncer_load_region") + c.Assert(err, IsNil) + defer os.RemoveAll(tempDir) + rs, err := core.NewRegionStorage(context.Background(), tempDir, nil) + c.Assert(err, IsNil) + + server := &mockServer{ + ctx: context.Background(), + storage: core.NewStorage(kv.NewMemoryKV(), core.WithRegionStorage(rs)), + bc: core.NewBasicCluster(), + } + for i := 0; i < 30; i++ { + rs.SaveRegion(&metapb.Region{Id: uint64(i) + 1}) + } + c.Assert(failpoint.Enable("github.com/tikv/pd/server/core/slowLoadRegion", "return(true)"), IsNil) + defer func() { c.Assert(failpoint.Disable("github.com/tikv/pd/server/core/slowLoadRegion"), IsNil) }() + + rc := NewRegionSyncer(server) + start := time.Now() + rc.StartSyncWithLeader("") + time.Sleep(time.Second) + rc.StopSyncWithLeader() + c.Assert(time.Since(start), Greater, time.Second) // make sure failpoint is injected + c.Assert(time.Since(start), Less, time.Second*2) +} + +type mockServer struct { + ctx context.Context + member, leader *pdpb.Member + storage *core.Storage + bc *core.BasicCluster +} + +func (s *mockServer) LoopContext() context.Context { + return s.ctx +} + +func (s *mockServer) ClusterID() uint64 { + return 1 +} + +func (s *mockServer) GetMemberInfo() *pdpb.Member { + return s.member +} + +func (s *mockServer) GetLeader() *pdpb.Member { + return s.leader +} + +func (s *mockServer) GetStorage() *core.Storage { + return s.storage +} + +func (s *mockServer) Name() string { + return "mock-server" +} + +func (s *mockServer) GetRegions() []*core.RegionInfo { + return s.bc.GetRegions() +} + +func (s *mockServer) GetTLSConfig() *grpcutil.TLSConfig { + return &grpcutil.TLSConfig{} +} + +func (s *mockServer) GetBasicCluster() *core.BasicCluster { + return s.bc +} diff --git a/server/region_syncer/server.go b/server/region_syncer/server.go index d0ab8f7f221..6569b8863de 100644 --- a/server/region_syncer/server.go +++ b/server/region_syncer/server.go @@ -69,10 +69,9 @@ type Server interface { type RegionSyncer struct { mu struct { sync.RWMutex - streams map[string]ServerStream - regionSyncerCtx context.Context - regionSyncerCancel context.CancelFunc - closed chan struct{} + streams map[string]ServerStream + clientCtx context.Context + clientCancel context.CancelFunc } server Server wg sync.WaitGroup @@ -94,7 +93,6 @@ func NewRegionSyncer(s Server) *RegionSyncer { tlsConfig: s.GetTLSConfig(), } syncer.mu.streams = make(map[string]ServerStream) - syncer.mu.closed = make(chan struct{}) return syncer } diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index c3c15c280b2..ce0ddcfdcbd 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -718,7 +718,7 @@ func (s *clusterTestSuite) TestLoadClusterInfo(c *C) { for _, region := range regions { c.Assert(storage.SaveRegion(region), IsNil) } - raftCluster.GetStorage().LoadRegionsOnce(raftCluster.GetCacheCluster().PutRegion) + raftCluster.GetStorage().LoadRegionsOnce(s.ctx, 
raftCluster.GetCacheCluster().PutRegion) c.Assert(raftCluster.GetRegionCount(), Equals, n) }