diff --git a/pkg/syncer/client.go b/pkg/syncer/client.go index 6f164b777ba..5d4603de92d 100644 --- a/pkg/syncer/client.go +++ b/pkg/syncer/client.go @@ -65,6 +65,11 @@ func (s *RegionSyncer) reset() { s.mu.clientCancel, s.mu.clientCtx = nil, nil } +// ResetHistoryIndex resets and persists the next region sync history index. +func (s *RegionSyncer) ResetHistoryIndex(index uint64) { + s.history.resetWithIndexAndPersist(index) +} + func (s *RegionSyncer) syncRegion(ctx context.Context, conn *grpc.ClientConn) (ClientStream, error) { cli := pdpb.NewPDClient(conn) syncStream, err := cli.SyncRegions(ctx) @@ -85,7 +90,7 @@ func (s *RegionSyncer) syncRegion(ctx context.Context, conn *grpc.ClientConn) (C var regionGuide = core.GenerateRegionGuideFunc(false) -func (s *RegionSyncer) handleRegionSyncResponse(ctx context.Context, resp *pdpb.SyncRegionResponse, bc *core.BasicCluster, regionStorage storage.Storage) { +func (s *RegionSyncer) handleRegionSyncResponse(ctx context.Context, resp *pdpb.SyncRegionResponse, bc *core.BasicCluster, regionStorage storage.Storage, fullSyncing *bool) { if s.history.getNextIndex() != resp.GetStartIndex() { log.Warn("server sync index not match the leader", zap.String("server", s.server.Name()), @@ -99,6 +104,9 @@ func (s *RegionSyncer) handleRegionSyncResponse(ctx context.Context, resp *pdpb. regions := resp.GetRegions() buckets := resp.GetBuckets() regionLeaders := resp.GetRegionLeaders() + if !s.IsRunning() && resp.GetStartIndex() == 0 && len(regions) > 0 { + *fullSyncing = true + } hasStats := len(stats) == len(regions) hasBuckets := len(buckets) == len(regions) for i, r := range regions { @@ -146,6 +154,13 @@ func (s *RegionSyncer) handleRegionSyncResponse(ctx context.Context, resp *pdpb. _ = regionStorage.DeleteRegion(old.GetMeta()) } } + if *fullSyncing { + if len(regions) == 0 { + *fullSyncing = false + s.streamingRunning.Store(true) + } + return + } // mark the client as running status when it finished the first history region sync. s.streamingRunning.Store(true) } @@ -233,6 +248,7 @@ func (s *RegionSyncer) StartSyncWithLeader(addr string) { continue } log.Info("server starts to synchronize with leader", zap.String("server", s.server.Name()), zap.String("leader", s.server.GetLeader().GetName()), zap.Uint64("request-index", s.history.getNextIndex())) + fullSyncing := false for { resp, err := stream.Recv() if err == io.EOF { @@ -260,7 +276,7 @@ func (s *RegionSyncer) StartSyncWithLeader(addr string) { } break } - s.handleRegionSyncResponse(ctx, resp, bc, regionStorage) + s.handleRegionSyncResponse(ctx, resp, bc, regionStorage, &fullSyncing) } } }() diff --git a/pkg/syncer/history_buffer.go b/pkg/syncer/history_buffer.go index a066762405f..ea4f11e0d24 100644 --- a/pkg/syncer/history_buffer.go +++ b/pkg/syncer/history_buffer.go @@ -122,18 +122,35 @@ func (h *historyBuffer) recordsFrom(index uint64) []*core.RegionInfo { func (h *historyBuffer) resetWithIndex(index uint64) { h.Lock() defer h.Unlock() + h.resetWithIndexLocked(index) +} + +func (h *historyBuffer) resetWithIndexLocked(index uint64) { h.index = index h.head = 0 h.tail = 0 h.flushCount = defaultFlushCount } +func (h *historyBuffer) resetWithIndexAndPersist(index uint64) { + h.Lock() + defer h.Unlock() + h.resetWithIndexLocked(index) + h.persist() +} + func (h *historyBuffer) getNextIndex() uint64 { h.RLock() defer h.RUnlock() return h.index } +func (h *historyBuffer) getFirstIndex() uint64 { + h.RLock() + defer h.RUnlock() + return h.firstIndex() +} + func (h *historyBuffer) get(index uint64) *core.RegionInfo { h.RLock() defer h.RUnlock() diff --git a/pkg/syncer/server.go b/pkg/syncer/server.go index 415e0d62985..bbb5d8e85dd 100644 --- a/pkg/syncer/server.go +++ b/pkg/syncer/server.go @@ -242,12 +242,14 @@ func (s *RegionSyncer) Sync(ctx context.Context, stream pdpb.PD_SyncRegionsServe zap.String("requested-server", request.GetMember().GetName()), zap.String("url", request.GetMember().GetClientUrls()[0])) - err = s.syncHistoryRegion(ctx, request, stream) + syncStream, err := s.syncHistoryRegion(ctx, request, stream) if err != nil { return err } name := request.GetMember().GetName() - syncStream := s.bindStream(name, stream) + if syncStream == nil { + syncStream = s.bindStream(name, stream) + } select { case <-ctx.Done(): s.unbindStream(name, syncStream) @@ -284,9 +286,12 @@ func recvSyncRegionRequest(ctx context.Context, stream pdpb.PD_SyncRegionsServer } } -func (s *RegionSyncer) syncHistoryRegion(ctx context.Context, request *pdpb.SyncRegionRequest, stream pdpb.PD_SyncRegionsServer) error { +func (s *RegionSyncer) syncHistoryRegion(ctx context.Context, request *pdpb.SyncRegionRequest, stream pdpb.PD_SyncRegionsServer) (*regionSyncStream, error) { startIndex := request.GetStartIndex() name := request.GetMember().GetName() + if startIndex == 0 { + return s.syncFullRegions(ctx, name, stream) + } records := s.history.recordsFrom(startIndex) if len(records) == 0 { if s.history.getNextIndex() == startIndex { @@ -301,21 +306,17 @@ func (s *RegionSyncer) syncHistoryRegion(ctx context.Context, request *pdpb.Sync RegionLeaders: nil, Buckets: nil, } - return stream.Send(resp) + return nil, stream.Send(resp) } - // do full synchronization - if startIndex == 0 { - return s.syncFullRegions(ctx, name, stream) - } - log.Warn("no history regions from index, the leader may be restarted", zap.Uint64("index", startIndex)) - return nil + log.Warn("no history regions from index, fall back to full sync", zap.Uint64("index", startIndex)) + return s.syncFullRegions(ctx, name, stream) } log.Info("sync the history regions with server", zap.String("server", name), zap.Uint64("from-index", startIndex), zap.Uint64("last-index", s.history.getNextIndex()), zap.Int("records-length", len(records))) - return s.syncHistoryRecords(startIndex, records, stream) + return nil, s.syncHistoryRecords(startIndex, records, stream) } func (*RegionSyncer) syncHistoryRecords(startIndex uint64, records []*core.RegionInfo, stream pdpb.PD_SyncRegionsServer) error { @@ -348,10 +349,41 @@ func (*RegionSyncer) syncHistoryRecords(startIndex uint64, records []*core.Regio return stream.Send(resp) } -func (s *RegionSyncer) syncFullRegions(ctx context.Context, name string, stream pdpb.PD_SyncRegionsServer) error { +func (s *RegionSyncer) syncFullRegions(ctx context.Context, name string, stream pdpb.PD_SyncRegionsServer) (*regionSyncStream, error) { + for { + start := time.Now() + catchUpIndex, canceled, err := s.sendFullRegionSnapshot(ctx, stream) + if err != nil { + return nil, err + } + if canceled { + return nil, nil + } + catchUpIndex, restart, err := s.catchUpFullSyncHistory(name, catchUpIndex, stream) + if err != nil { + return nil, err + } + if restart { + continue + } + return s.completeFullSyncAndBindStream(name, stream, catchUpIndex, start) + } +} + +func (s *RegionSyncer) sendFullRegionSnapshot(ctx context.Context, stream pdpb.PD_SyncRegionsServer) (uint64, bool, error) { + catchUpIndex := s.history.getNextIndex() regions := s.server.GetRegions() + if len(regions) == 0 { + resp := &pdpb.SyncRegionResponse{ + Header: &pdpb.ResponseHeader{ClusterId: keypath.ClusterID()}, + StartIndex: 0, + } + if err := stream.Send(resp); err != nil { + log.Warn("failed to send sync region response", errs.ZapError(errs.ErrGRPCSend, err)) + return 0, false, err + } + } lastIndex := 0 - start := time.Now() metas := make([]*metapb.Region, 0, maxSyncRegionBatchSize) stats := make([]*pdpb.RegionStat, 0, maxSyncRegionBatchSize) leaders := make([]*metapb.Peer, 0, maxSyncRegionBatchSize) @@ -363,7 +395,7 @@ func (s *RegionSyncer) syncFullRegions(ctx context.Context, name string, stream failpoint.Inject("noFastExitSync", func() { failpoint.Goto("doSync") }) - return nil + return 0, true, nil default: } failpoint.Label("doSync") @@ -392,21 +424,84 @@ func (s *RegionSyncer) syncFullRegions(ctx context.Context, name string, stream } if err := s.limit.WaitN(ctx, resp.Size()); err != nil { log.Error("failed to wait rate limit", errs.ZapError(err)) - return err + return 0, false, err } lastIndex += len(metas) if err := stream.Send(resp); err != nil { log.Error("failed to send sync region response", errs.ZapError(errs.ErrGRPCSend, err)) - return err + return 0, false, err } metas = metas[:0] stats = stats[:0] leaders = leaders[:0] buckets = buckets[:0] } + return catchUpIndex, false, nil +} + +func (s *RegionSyncer) catchUpFullSyncHistory( + name string, + catchUpIndex uint64, + stream pdpb.PD_SyncRegionsServer, +) (uint64, bool, error) { + for { + records := s.history.recordsFrom(catchUpIndex) + if len(records) == 0 { + if catchUpIndex < s.history.getFirstIndex() { + log.Warn("region history buffer overflow during full synchronization, restart full synchronization", + zap.String("requested-server", name), + zap.String("server", s.server.Name()), + zap.Uint64("catch-up-index", catchUpIndex), + zap.Uint64("first-index", s.history.getFirstIndex())) + return catchUpIndex, true, nil + } + if catchUpIndex == s.history.getNextIndex() { + break + } + continue + } + if err := s.syncHistoryRecords(catchUpIndex, records, stream); err != nil { + return 0, false, err + } + catchUpIndex += uint64(len(records)) + } + return catchUpIndex, false, nil +} + +func (s *RegionSyncer) completeFullSyncAndBindStream( + name string, + stream pdpb.PD_SyncRegionsServer, + catchUpIndex uint64, + start time.Time, +) (*regionSyncStream, error) { + syncStream := newRegionSyncStream(stream) + s.mu.Lock() + defer s.mu.Unlock() + for { + records := s.history.recordsFrom(catchUpIndex) + if len(records) == 0 { + if catchUpIndex < s.history.getFirstIndex() { + return nil, errors.Errorf("region history buffer overflow during full sync catch-up, catch-up-index %d, first-index %d", catchUpIndex, s.history.getFirstIndex()) + } + break + } + if err := s.syncHistoryRecords(catchUpIndex, records, stream); err != nil { + return nil, err + } + catchUpIndex += uint64(len(records)) + } log.Info("requested server has completed full synchronization with server", zap.String("requested-server", name), zap.String("server", s.server.Name()), zap.Duration("cost", time.Since(start))) - return nil + resp := &pdpb.SyncRegionResponse{ + Header: &pdpb.ResponseHeader{ClusterId: keypath.ClusterID()}, + StartIndex: catchUpIndex, + } + if err := stream.Send(resp); err != nil { + log.Warn("failed to send sync region completion response", errs.ZapError(errs.ErrGRPCSend, err)) + return nil, err + } + s.bindStreamLocked(name, syncStream) + return syncStream, nil } // bindStream binds the established server stream. @@ -414,11 +509,15 @@ func (s *RegionSyncer) bindStream(name string, stream ServerStream) *regionSyncS syncStream := newRegionSyncStream(stream) s.mu.Lock() defer s.mu.Unlock() + s.bindStreamLocked(name, syncStream) + return syncStream +} + +func (s *RegionSyncer) bindStreamLocked(name string, syncStream *regionSyncStream) { if oldStream := s.mu.streams[name]; oldStream != nil { oldStream.close() } s.mu.streams[name] = syncStream - return syncStream } func (s *RegionSyncer) unbindStream(name string, stream *regionSyncStream) { diff --git a/pkg/syncer/server_test.go b/pkg/syncer/server_test.go index 2b5398e841f..ea6c1ad66bc 100644 --- a/pkg/syncer/server_test.go +++ b/pkg/syncer/server_test.go @@ -27,6 +27,7 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/tikv/pd/pkg/core" @@ -68,6 +69,7 @@ func TestSyncExitsWhenRegionSyncerStops(t *testing.T) { }, } re.NotNil(<-stream.sendCh) + re.NotNil(<-stream.sendCh) testutil.Eventually(re, func() bool { names := syncer.GetAllDownstreamNames() return len(names) == 1 && names[0] == "pd-follower" @@ -122,6 +124,7 @@ func TestSyncExitsWhenBroadcastSendFails(t *testing.T) { }, } re.NotNil(<-stream.sendCh) + re.NotNil(<-stream.sendCh) testutil.Eventually(re, func() bool { names := syncer.GetAllDownstreamNames() return len(names) == 1 && names[0] == "pd-follower" @@ -181,6 +184,7 @@ func TestCloseAllClientClosesStreamsBeforeSend(t *testing.T) { }, } re.NotNil(<-stream.sendCh) + re.NotNil(<-stream.sendCh) testutil.Eventually(re, func() bool { names := syncer.GetAllDownstreamNames() return len(names) == 1 && names[0] == "pd-follower" @@ -253,6 +257,7 @@ func TestBroadcastClosesStreamWhenSendBlocks(t *testing.T) { }, } re.NotNil(<-stream.sendCh) + re.NotNil(<-stream.sendCh) testutil.Eventually(re, func() bool { names := syncer.GetAllDownstreamNames() return len(names) == 1 && names[0] == "pd-follower" @@ -333,6 +338,208 @@ func TestSyncExitsWhenContextCanceledBeforeRequest(t *testing.T) { }) } +func TestSyncFallsBackToFullSyncWhenHistoryMissing(t *testing.T) { + re := require.New(t) + syncer, _ := newTestRegionSyncer(t, newTestSyncRegion(1, 11)) + syncer.history.resetWithIndex(100) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + stream := newMockSyncRegionsServer() + blockCh := stream.blockSend() + done := startTestRegionSync(ctx, syncer, stream) + + sendTestSyncRegionRequest(stream, 1) + testutil.Eventually(re, stream.isSendBlocked) + syncer.history.record(newTestSyncRegion(2, 12)) + close(blockCh) + + resp := mustRecvSyncRegionResponse(t, stream, "expected full sync response") + re.Equal(uint64(0), resp.GetStartIndex()) + re.Len(resp.GetRegions(), 1) + re.Equal(uint64(1), resp.GetRegions()[0].GetId()) + + resp = mustRecvSyncRegionResponse(t, stream, "expected full sync catch-up response") + re.Equal(uint64(100), resp.GetStartIndex()) + re.Len(resp.GetRegions(), 1) + re.Equal(uint64(2), resp.GetRegions()[0].GetId()) + + resp = mustRecvSyncRegionResponse(t, stream, "expected full sync completion response") + re.Equal(uint64(101), resp.GetStartIndex()) + re.Empty(resp.GetRegions()) + waitTestRegionSyncerBound(re, syncer) + + cancel() + waitTestRegionSyncerUnavailable(re, done) +} + +func TestFullSyncRestartsWhenHistoryBufferOverflowsDuringCatchUp(t *testing.T) { + re := require.New(t) + syncer, bc := newTestRegionSyncer(t, newTestSyncRegion(1, 11)) + syncer.history = newHistoryBuffer(1, syncer.history.kv) + syncer.history.resetWithIndex(100) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + stream := newMockSyncRegionsServer() + blockCh := stream.blockSend() + done := startTestRegionSync(ctx, syncer, stream) + + sendTestSyncRegionRequest(stream, 1) + testutil.Eventually(re, stream.isSendBlocked) + for _, region := range []*core.RegionInfo{ + newTestSyncRegion(2, 12), + newTestSyncRegion(3, 13), + } { + bc.PutRegion(region) + syncer.history.record(region) + } + close(blockCh) + + resp := mustRecvSyncRegionResponse(t, stream, "expected original full sync response") + re.Equal(uint64(0), resp.GetStartIndex()) + re.Len(resp.GetRegions(), 1) + re.Equal(uint64(1), resp.GetRegions()[0].GetId()) + + resp = mustRecvSyncRegionResponse(t, stream, "expected restarted full sync response") + re.Equal(uint64(0), resp.GetStartIndex()) + re.Len(resp.GetRegions(), 3) + regionIDs := make([]uint64, 0, len(resp.GetRegions())) + for _, region := range resp.GetRegions() { + regionIDs = append(regionIDs, region.GetId()) + } + re.ElementsMatch([]uint64{1, 2, 3}, regionIDs) + + resp = mustRecvSyncRegionResponse(t, stream, "expected full sync completion response") + re.Equal(uint64(102), resp.GetStartIndex()) + re.Empty(resp.GetRegions()) + waitTestRegionSyncerBound(re, syncer) + + cancel() + waitTestRegionSyncerUnavailable(re, done) +} + +func TestClientWaitsForFullSyncCompletionBeforeRunning(t *testing.T) { + re := require.New(t) + regionStorage := storage.NewStorageWithMemoryBackend() + server := mockserver.NewMockServer( + context.Background(), + nil, + nil, + regionStorage, + core.NewBasicCluster(), + ) + syncer := NewRegionSyncer(server) + bc := core.NewBasicCluster() + fullSyncing := false + region := &metapb.Region{ + Id: 1, + StartKey: []byte{1}, + EndKey: []byte{2}, + RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}, + Peers: []*metapb.Peer{{Id: 11, StoreId: 1}}, + } + + syncer.handleRegionSyncResponse(context.Background(), &pdpb.SyncRegionResponse{ + Header: &pdpb.ResponseHeader{ClusterId: keypath.ClusterID()}, + Regions: []*metapb.Region{region}, + StartIndex: 0, + }, bc, regionStorage, &fullSyncing) + re.True(fullSyncing) + re.False(syncer.IsRunning()) + + syncer.handleRegionSyncResponse(context.Background(), &pdpb.SyncRegionResponse{ + Header: &pdpb.ResponseHeader{ClusterId: keypath.ClusterID()}, + StartIndex: 1, + }, bc, regionStorage, &fullSyncing) + re.False(fullSyncing) + re.True(syncer.IsRunning()) +} + +func newTestRegionSyncer(t *testing.T, regions ...*core.RegionInfo) (*RegionSyncer, *core.BasicCluster) { + t.Helper() + re := require.New(t) + tempDir := t.TempDir() + regionStorage, err := storage.NewRegionStorageWithLevelDBBackend(context.Background(), tempDir, nil) + re.NoError(err) + t.Cleanup(func() { + re.NoError(regionStorage.Close()) + }) + + bc := core.NewBasicCluster() + for _, region := range regions { + bc.PutRegion(region) + } + server := mockserver.NewMockServer( + context.Background(), + nil, + nil, + storage.NewCoreStorage(storage.NewStorageWithMemoryBackend(), regionStorage), + bc, + ) + return NewRegionSyncer(server), bc +} + +func newTestSyncRegion(regionID, peerID uint64) *core.RegionInfo { + return core.NewRegionInfo(&metapb.Region{ + Id: regionID, + StartKey: []byte{byte(regionID)}, + EndKey: []byte{byte(regionID + 1)}, + RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}, + Peers: []*metapb.Peer{{Id: peerID, StoreId: 1}}, + }, &metapb.Peer{Id: peerID, StoreId: 1}) +} + +func startTestRegionSync(ctx context.Context, syncer *RegionSyncer, stream *mockSyncRegionsServer) chan error { + done := make(chan error, 1) + go func() { + done <- syncer.Sync(ctx, stream) + }() + return done +} + +func sendTestSyncRegionRequest(stream *mockSyncRegionsServer, startIndex uint64) { + stream.recvCh <- &pdpb.SyncRegionRequest{ + Header: &pdpb.RequestHeader{ClusterId: keypath.ClusterID()}, + StartIndex: startIndex, + Member: &pdpb.Member{ + Name: "pd-follower", + ClientUrls: []string{"http://127.0.0.1:2379"}, + }, + } +} + +func mustRecvSyncRegionResponse(t *testing.T, stream *mockSyncRegionsServer, message string) *pdpb.SyncRegionResponse { + t.Helper() + select { + case resp := <-stream.sendCh: + return resp + case <-time.After(3 * time.Second): + require.FailNow(t, message) + return nil + } +} + +func waitTestRegionSyncerBound(re *require.Assertions, syncer *RegionSyncer) { + testutil.Eventually(re, func() bool { + names := syncer.GetAllDownstreamNames() + return len(names) == 1 && names[0] == "pd-follower" + }) +} + +func waitTestRegionSyncerUnavailable(re *require.Assertions, done <-chan error) { + var syncErr error + testutil.Eventually(re, func() bool { + if syncErr == nil { + select { + case syncErr = <-done: + default: + return false + } + } + st, ok := status.FromError(syncErr) + return ok && st.Code() == codes.Unavailable + }) +} + type mockSyncRegionsServer struct { mu sync.Mutex ctx context.Context diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index 8aaa79485aa..69c947a5358 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -55,7 +55,7 @@ const ( // PDRedirectorHeader is used to mark which PD redirected this request. PDRedirectorHeader = "PD-Redirector" - // PDAllowFollowerHandleHeader is used to mark whether this request is allowed to be handled by the follower PD. + // PDAllowFollowerHandleHeader is used to mark whether this request is allowed to be handled by the follower PD locally. PDAllowFollowerHandleHeader = "PD-Allow-follower-handle" // #nosec G101 // XForwardedForHeader is used to mark the client IP. XForwardedForHeader = "X-Forwarded-For" diff --git a/server/api/admin.go b/server/api/admin.go index 73eafd9730e..f298ee4cb0c 100644 --- a/server/api/admin.go +++ b/server/api/admin.go @@ -57,8 +57,9 @@ func newAdminHandler(svr *server.Server, rd *render.Render) *adminHandler { // @Summary Drop a specific region from cache. // @Param id path integer true "Region Id" // @Produce json -// @Success 200 {string} string "The region is removed from server cache." +// @Success 200 {string} string "The region is removed from server/follower cache." // @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "The follower failed to reset region cache." // @Router /admin/cache/region/{id} [delete] func (h *adminHandler) DeleteRegionCache(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) @@ -69,6 +70,14 @@ func (h *adminHandler) DeleteRegionCache(w http.ResponseWriter, r *http.Request) h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } + if isFollowerSyncedClusterRequest(r) { + if err = h.svr.ResetFollowerRegionCache(regionID); err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + h.rd.JSON(w, http.StatusOK, "The region is removed from follower cache and the follower starts to resync regions from leader.") + return + } rc.RemoveRegionIfExist(regionID) msg := "The region is removed from server cache." if rc.IsServiceIndependent(constant.SchedulingServiceName) { @@ -131,11 +140,20 @@ func (h *adminHandler) DeleteRegionStorage(w http.ResponseWriter, r *http.Reques // @Tags admin // @Summary Drop all regions from cache. // @Produce json -// @Success 200 {string} string "All regions are removed from server cache." +// @Success 200 {string} string "All regions are removed from server/follower cache." +// @Failure 500 {string} string "The follower failed to reset region cache." // @Router /admin/cache/regions [delete] func (h *adminHandler) DeleteAllRegionCache(w http.ResponseWriter, r *http.Request) { var err error rc := getCluster(r) + if isFollowerSyncedClusterRequest(r) { + if err = h.svr.ResetFollowerRegionCache(); err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + h.rd.JSON(w, http.StatusOK, "All regions are removed from follower cache and the follower starts to resync regions from leader.") + return + } rc.ResetRegionCache() msg := "All regions are removed from server cache." if rc.IsServiceIndependent(constant.SchedulingServiceName) { diff --git a/server/api/middleware.go b/server/api/middleware.go index acf41c00124..203751e8a07 100644 --- a/server/api/middleware.go +++ b/server/api/middleware.go @@ -83,6 +83,7 @@ type clusterMiddleware struct { s *server.Server rd *render.Render allowFollowerSyncedRegion bool + allowFollowerRegionReset bool } type clusterMiddlewareOption func(*clusterMiddleware) @@ -93,6 +94,12 @@ func withFollowerSyncedRegion() clusterMiddlewareOption { } } +func withFollowerRegionReset() clusterMiddlewareOption { + return func(m *clusterMiddleware) { + m.allowFollowerRegionReset = true + } +} + func newClusterMiddleware(s *server.Server, opts ...clusterMiddlewareOption) clusterMiddleware { m := clusterMiddleware{ s: s, @@ -107,38 +114,60 @@ func newClusterMiddleware(s *server.Server, opts ...clusterMiddlewareOption) clu func (m clusterMiddleware) middleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rc := m.s.GetRaftCluster() + isFollowerSyncedCluster := false if rc == nil { rc = m.getFollowerSyncedCluster(r) + isFollowerSyncedCluster = rc != nil } if rc == nil { m.rd.JSON(w, http.StatusInternalServerError, errs.ErrNotBootstrapped.FastGenByArgs().Error()) return } ctx := context.WithValue(r.Context(), clusterCtxKey{}, rc) + ctx = context.WithValue(ctx, followerSyncedClusterCtxKey{}, isFollowerSyncedCluster) h.ServeHTTP(w, r.WithContext(ctx)) }) } func (m clusterMiddleware) getFollowerSyncedCluster(r *http.Request) *cluster.RaftCluster { - if r.Method != http.MethodGet || - !m.allowFollowerSyncedRegion || - m.s.GetMember().IsServing() || - r.Header.Get(apiutil.PDAllowFollowerHandleHeader) == "" { + if r.Header.Get(apiutil.PDAllowFollowerHandleHeader) == "" || + m.s.GetMember().IsServing() { + return nil + } + switch r.Method { + case http.MethodGet: + if !m.allowFollowerSyncedRegion { + return nil + } + case http.MethodDelete: + if !m.allowFollowerRegionReset { + return nil + } + default: return nil } rc := m.s.DirectlyGetRaftCluster() - if rc == nil || !rc.GetRegionSyncer().IsRunning() { + if rc == nil { + return nil + } + if r.Method == http.MethodGet && !rc.GetRegionSyncer().IsRunning() { return nil } return rc } type clusterCtxKey struct{} +type followerSyncedClusterCtxKey struct{} func getCluster(r *http.Request) *cluster.RaftCluster { return r.Context().Value(clusterCtxKey{}).(*cluster.RaftCluster) } +func isFollowerSyncedClusterRequest(r *http.Request) bool { + v, _ := r.Context().Value(followerSyncedClusterCtxKey{}).(bool) + return v +} + type auditMiddleware struct { svr *server.Server } diff --git a/server/api/router.go b/server/api/router.go index f1593f5e388..477b396e936 100644 --- a/server/api/router.go +++ b/server/api/router.go @@ -132,6 +132,8 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { clusterRouter.Use(newClusterMiddleware(svr).middleware) regionReadRouter := apiRouter.NewRoute().Subrouter() regionReadRouter.Use(newClusterMiddleware(svr, withFollowerSyncedRegion()).middleware) + regionResetRouter := apiRouter.NewRoute().Subrouter() + regionResetRouter.Use(newClusterMiddleware(svr, withFollowerRegionReset()).middleware) escapeRouter := clusterRouter.NewRoute().Subrouter().UseEncodedPath() regionReadEscapeRouter := regionReadRouter.NewRoute().Subrouter().UseEncodedPath() @@ -307,9 +309,9 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { registerFunc(clusterRouter, "/stats/region", statsHandler.GetRegionStatus, setMethods(http.MethodGet), setAuditBackend(prometheus)) adminHandler := newAdminHandler(svr, rd) - registerFunc(clusterRouter, "/admin/cache/region/{id}", adminHandler.DeleteRegionCache, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(regionResetRouter, "/admin/cache/region/{id}", adminHandler.DeleteRegionCache, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) registerFunc(clusterRouter, "/admin/storage/region/{id}", adminHandler.DeleteRegionStorage, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) - registerFunc(clusterRouter, "/admin/cache/regions", adminHandler.DeleteAllRegionCache, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) + registerFunc(regionResetRouter, "/admin/cache/regions", adminHandler.DeleteAllRegionCache, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) registerFunc(apiRouter, "/admin/persist-file/{file_name}", adminHandler.SavePersistFile, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) registerFunc(apiRouter, "/admin/cluster/markers/snapshot-recovering", adminHandler.isSnapshotRecovering, setMethods(http.MethodGet), setAuditBackend(localLog, prometheus)) registerFunc(apiRouter, "/admin/cluster/markers/snapshot-recovering", adminHandler.markSnapshotRecovering, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) diff --git a/server/server.go b/server/server.go index 040ae8792dd..0136e3d8664 100644 --- a/server/server.go +++ b/server/server.go @@ -203,6 +203,8 @@ type Server struct { // Store as map[string]*grpc.ClientConn clientConns sync.Map + followerRegionResetMu sync.Mutex + tsoClientPool struct { syncutil.RWMutex clients map[string]*streamWrapper @@ -931,6 +933,155 @@ func (s *Server) GetBasicCluster() *core.BasicCluster { return s.basicCluster } +// ResetFollowerRegionCache resets follower local region cache and restarts +// region sync from the leader. +func (s *Server) ResetFollowerRegionCache(regionIDs ...uint64) error { + if !s.persistOptions.IsUseRegionStorage() { + return errors.New("region storage is disabled") + } + s.followerRegionResetMu.Lock() + defer s.followerRegionResetMu.Unlock() + + leader := s.GetLeader() + if leader == nil { + return errs.ErrLeaderNil.FastGenByArgs() + } + leaderURLs := leader.GetClientUrls() + if len(leaderURLs) == 0 { + return errors.New("pd leader has no client url") + } + + syncer := s.cluster.GetRegionSyncer() + syncer.StopSyncWithLeader() + // Keep the follower connected even when the reset returns an error. + defer syncer.StartSyncWithLeader(leaderURLs[0]) + + var resetErr error + if err := s.storage.Flush(); err != nil { + resetErr = errors.Wrap(err, "flush follower region storage") + } + if len(regionIDs) == 0 { + if err := s.deleteFollowerRegionStorage(); resetErr == nil && err != nil { + resetErr = err + } + s.basicCluster.ResetRegionCache() + } else { + for _, regionID := range regionIDs { + if err := s.deleteFollowerRegion(regionID); resetErr == nil && err != nil { + resetErr = err + } + } + } + if err := s.storage.Flush(); err != nil && resetErr == nil { + resetErr = errors.Wrap(err, "flush follower region storage") + } + // Force a full sync after the local reset attempt so the follower can + // rebuild any cache entries that were removed before an error happened. + syncer.ResetHistoryIndex(0) + + log.Info("reset follower region cache and restart region syncer", + zap.String("server", s.Name()), + zap.String("leader", leader.GetName()), + zap.Int("region-count", len(regionIDs))) + return resetErr +} + +func (s *Server) deleteFollowerRegion(regionID uint64) error { + region := s.basicCluster.GetRegion(regionID) + if region == nil { + meta := &metapb.Region{} + ok, err := s.storage.LoadRegion(regionID, meta) + if err != nil { + return errors.Wrap(err, "load follower region from local storage") + } + if ok { + region = core.NewRegionInfo(meta, nil, core.SetSource(core.Storage)) + } + } + if region != nil { + if err := s.deleteFollowerRegionMeta(region.GetMeta()); err != nil { + return err + } + } + s.basicCluster.RemoveRegionIfExist(regionID) + return nil +} + +func (s *Server) deleteFollowerRegionStorage() error { + regionStorage := storage.RetrieveRegionStorage(s.storage) + regionKV, ok := regionStorage.(kv.Base) + if !ok { + return errors.New("region storage does not support range scan") + } + + startID := uint64(0) + endKey := keypath.RegionPath(math.MaxUint64) + for { + select { + case <-s.ctx.Done(): + return s.ctx.Err() + default: + } + + keys, _, err := regionKV.LoadRange(keypath.RegionPath(startID), endKey, endpoint.MaxKVRangeLimit) + if err != nil { + return errors.Wrap(err, "load follower regions from local storage") + } + var lastRegionID uint64 + for _, key := range keys { + regionID, err := parseRegionIDFromStorageKey(key) + if err != nil { + return err + } + lastRegionID = regionID + } + if err := deleteFollowerRegionStorageKeys(s.ctx, regionKV, keys); err != nil { + return errors.Wrap(err, "delete follower regions from local storage") + } + if len(keys) < endpoint.MaxKVRangeLimit { + return nil + } + if lastRegionID == math.MaxUint64 { + return nil + } + startID = lastRegionID + 1 + } +} + +func deleteFollowerRegionStorageKeys(ctx context.Context, regionKV kv.Base, keys []string) error { + return regionKV.RunInTxn(ctx, func(txn kv.Txn) error { + for _, key := range keys { + if err := txn.Remove(key); err != nil { + return err + } + } + return nil + }) +} + +func (s *Server) deleteFollowerRegionMeta(region *metapb.Region) error { + if err := s.storage.DeleteRegion(region); err != nil { + log.Warn("failed to delete follower region from local storage", + zap.String("server", s.Name()), + zap.Uint64("region-id", region.GetId()), + errs.ZapError(err)) + return errors.Wrap(err, "delete follower region from local storage") + } + return nil +} + +func parseRegionIDFromStorageKey(key string) (uint64, error) { + idx := strings.LastIndexByte(key, '/') + if idx < 0 || idx == len(key)-1 { + return 0, errors.Errorf("invalid region storage key %q", key) + } + regionID, err := strconv.ParseUint(key[idx+1:], 10, 64) + if err != nil { + return 0, errors.Wrap(err, "parse region storage key") + } + return regionID, nil +} + // GetPersistOptions returns the schedule option. func (s *Server) GetPersistOptions() *config.PersistOptions { return s.persistOptions @@ -1691,12 +1842,16 @@ func (s *Server) leaderLoop() { } syncer := s.cluster.GetRegionSyncer() if s.persistOptions.IsUseRegionStorage() { + s.followerRegionResetMu.Lock() syncer.StartSyncWithLeader(leader.GetListenUrls()[0]) + s.followerRegionResetMu.Unlock() } log.Info("start to watch pd leader", zap.Stringer("pd-leader", leader)) // WatchLeader will keep looping and never return unless the PD leader has changed. leader.Watch(s.serverLoopCtx) + s.followerRegionResetMu.Lock() syncer.StopSyncWithLeader() + s.followerRegionResetMu.Unlock() log.Info("pd leader has changed, try to re-campaign a pd leader") } diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 00000000000..848104355f2 --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,351 @@ +// Copyright 2026 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "context" + stderrors "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/pingcap/kvproto/pkg/metapb" + + "github.com/tikv/pd/pkg/core" + "github.com/tikv/pd/pkg/member" + "github.com/tikv/pd/pkg/storage" + "github.com/tikv/pd/pkg/storage/kv" + "github.com/tikv/pd/pkg/utils/keypath" + "github.com/tikv/pd/server/config" +) + +var errTestFollowerRegionStorage = stderrors.New("test follower region storage error") + +func TestResetFollowerRegionCacheRequiresRegionStorage(t *testing.T) { + re := require.New(t) + cfg := config.NewConfig() + cfg.PDServerCfg.UseRegionStorage = false + s := &Server{persistOptions: config.NewPersistOptions(cfg)} + + re.ErrorContains(s.ResetFollowerRegionCache(), "region storage is disabled") + + cfg.PDServerCfg.UseRegionStorage = true + s = newTestFollowerRegionResetServer(context.Background()) + s.persistOptions = config.NewPersistOptions(cfg) + s.member = member.NewMember(nil, nil, 1) + re.Error(s.ResetFollowerRegionCache()) +} + +func TestDeleteFollowerRegion(t *testing.T) { + tests := []struct { + name string + setup func(*require.Assertions, *Server) uint64 + errContains string + check func(*require.Assertions, *Server, uint64) + }{ + { + name: "cached region", + setup: func(re *require.Assertions, s *Server) uint64 { + region := newTestFollowerRegionMeta(1) + re.NoError(s.storage.SaveRegion(region)) + s.basicCluster.PutRegion(core.NewRegionInfo(region, nil, core.SetSource(core.Storage))) + return region.GetId() + }, + check: assertTestFollowerRegionDeleted, + }, + { + name: "storage-only region", + setup: func(re *require.Assertions, s *Server) uint64 { + region := newTestFollowerRegionMeta(2) + re.NoError(s.storage.SaveRegion(region)) + return region.GetId() + }, + check: assertTestFollowerRegionDeleted, + }, + { + name: "missing region", + setup: func(*require.Assertions, *Server) uint64 { + return 3 + }, + }, + { + name: "load storage error", + setup: func(_ *require.Assertions, s *Server) uint64 { + s.storage = &testFollowerRegionStorage{ + Storage: s.storage, + loadRegionErr: errTestFollowerRegionStorage, + } + return 4 + }, + errContains: "load follower region from local storage", + }, + { + name: "delete storage error", + setup: func(_ *require.Assertions, s *Server) uint64 { + region := newTestFollowerRegionMeta(5) + s.basicCluster.PutRegion(core.NewRegionInfo(region, nil, core.SetSource(core.Storage))) + s.storage = &testFollowerRegionStorage{ + Storage: s.storage, + deleteRegionErr: errTestFollowerRegionStorage, + } + return region.GetId() + }, + errContains: "delete follower region from local storage", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + re := require.New(t) + s := newTestFollowerRegionResetServer(context.Background()) + regionID := test.setup(re, s) + + err := s.deleteFollowerRegion(regionID) + if test.errContains != "" { + re.ErrorContains(err, test.errContains) + return + } + re.NoError(err) + if test.check != nil { + test.check(re, s, regionID) + } + }) + } +} + +func TestDeleteFollowerRegionStorage(t *testing.T) { + tests := []struct { + name string + ctx context.Context + setup func(*require.Assertions, *Server) func(*require.Assertions, *Server) + errIs error + errContains string + }{ + { + name: "deletes all region storage keys", + setup: func(re *require.Assertions, s *Server) func(*require.Assertions, *Server) { + regions := []*metapb.Region{ + newTestFollowerRegionMeta(10), + newTestFollowerRegionMeta(11), + } + for _, region := range regions { + re.NoError(s.storage.SaveRegion(region)) + } + return func(re *require.Assertions, s *Server) { + for _, region := range regions { + assertTestFollowerRegionDeleted(re, s, region.GetId()) + } + } + }, + }, + { + name: "deletes by key without loading region meta", + setup: func(re *require.Assertions, s *Server) func(*require.Assertions, *Server) { + region := newTestFollowerRegionMeta(12) + re.NoError(s.storage.SaveRegion(region)) + regionStorage := s.storage + s.storage = &testFollowerRegionStorage{ + Storage: regionStorage, + loadRegionErr: errTestFollowerRegionStorage, + } + return func(re *require.Assertions, s *Server) { + s.storage = regionStorage + assertTestFollowerRegionDeleted(re, s, region.GetId()) + } + }, + }, + { + name: "context canceled", + ctx: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx + }(), + errIs: context.Canceled, + }, + { + name: "load range error", + setup: func(_ *require.Assertions, s *Server) func(*require.Assertions, *Server) { + s.storage = &testFollowerRegionStorage{ + Storage: s.storage, + loadRangeErr: errTestFollowerRegionStorage, + } + return nil + }, + errContains: "load follower regions from local storage", + }, + { + name: "delete transaction error", + setup: func(re *require.Assertions, s *Server) func(*require.Assertions, *Server) { + region := newTestFollowerRegionMeta(21) + re.NoError(s.storage.SaveRegion(region)) + s.storage = &testFollowerRegionStorage{ + Storage: s.storage, + runInTxnErr: errTestFollowerRegionStorage, + } + return nil + }, + errContains: "delete follower regions from local storage", + }, + { + name: "invalid region storage key", + setup: func(_ *require.Assertions, s *Server) func(*require.Assertions, *Server) { + s.storage = &testFollowerRegionStorage{ + Storage: s.storage, + loadRangeKeys: []string{"invalid-region-key"}, + } + return nil + }, + errContains: "invalid region storage key", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + re := require.New(t) + ctx := test.ctx + if ctx == nil { + ctx = context.Background() + } + s := newTestFollowerRegionResetServer(ctx) + var check func(*require.Assertions, *Server) + if test.setup != nil { + check = test.setup(re, s) + } + + err := s.deleteFollowerRegionStorage() + switch { + case test.errIs != nil: + re.ErrorIs(err, test.errIs) + return + case test.errContains != "": + re.ErrorContains(err, test.errContains) + return + default: + re.NoError(err) + } + if check != nil { + check(re, s) + } + }) + } +} + +func TestParseRegionIDFromStorageKey(t *testing.T) { + tests := []struct { + name string + key string + regionID uint64 + errContains string + }{ + { + name: "valid region key", + key: keypath.RegionPath(123), + regionID: 123, + }, + { + name: "missing region id", + key: "invalid-region-key", + errContains: "invalid region storage key", + }, + { + name: "invalid region id", + key: "/pd/0/raft/r/not-a-number", + errContains: "parse region storage key", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + re := require.New(t) + regionID, err := parseRegionIDFromStorageKey(test.key) + if test.errContains != "" { + re.ErrorContains(err, test.errContains) + return + } + re.NoError(err) + re.Equal(test.regionID, regionID) + }) + } +} + +func newTestFollowerRegionResetServer(ctx context.Context) *Server { + cfg := config.NewConfig() + return &Server{ + ctx: ctx, + cfg: cfg, + storage: storage.NewStorageWithMemoryBackend(), + basicCluster: core.NewBasicCluster(), + } +} + +func newTestFollowerRegionMeta(regionID uint64) *metapb.Region { + return &metapb.Region{ + Id: regionID, + StartKey: []byte{byte(regionID)}, + EndKey: []byte{byte(regionID + 1)}, + RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}, + Peers: []*metapb.Peer{ + {Id: regionID*10 + 1, StoreId: 1}, + }, + } +} + +func assertTestFollowerRegionDeleted(re *require.Assertions, s *Server, regionID uint64) { + region := &metapb.Region{} + ok, err := s.storage.LoadRegion(regionID, region) + re.NoError(err) + re.False(ok) + re.Nil(s.basicCluster.GetRegion(regionID)) +} + +type testFollowerRegionStorage struct { + storage.Storage + loadRegionErr error + deleteRegionErr error + loadRangeErr error + runInTxnErr error + loadRangeKeys []string +} + +func (s *testFollowerRegionStorage) LoadRange(key, endKey string, limit int) (keys []string, values []string, err error) { + if s.loadRangeErr != nil { + return nil, nil, s.loadRangeErr + } + if s.loadRangeKeys != nil { + return s.loadRangeKeys, nil, nil + } + return s.Storage.LoadRange(key, endKey, limit) +} + +func (s *testFollowerRegionStorage) LoadRegion(regionID uint64, region *metapb.Region) (bool, error) { + if s.loadRegionErr != nil { + return false, s.loadRegionErr + } + return s.Storage.LoadRegion(regionID, region) +} + +func (s *testFollowerRegionStorage) DeleteRegion(region *metapb.Region) error { + if s.deleteRegionErr != nil { + return s.deleteRegionErr + } + return s.Storage.DeleteRegion(region) +} + +func (s *testFollowerRegionStorage) RunInTxn(ctx context.Context, f func(txn kv.Txn) error) error { + if s.runInTxnErr != nil { + return s.runInTxnErr + } + return s.Storage.RunInTxn(ctx, f) +} diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 33e9f65366c..d13f442d736 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -24,6 +24,7 @@ import ( "net/http" "os" "strings" + "sync" "testing" "time" @@ -877,6 +878,215 @@ func TestFollowerRegionAPIWithNoForward(t *testing.T) { re.Contains(string(body), "TiKV cluster not bootstrapped") } +func TestFollowerRegionResetCacheWithNoForward(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 3, func(conf *config.Config, _ string) { + conf.PDServerCfg.UseRegionStorage = true + conf.TickInterval = typeutil.Duration{Duration: 50 * time.Millisecond} + conf.ElectionInterval = typeutil.Duration{Duration: 250 * time.Millisecond} + }) + re.NoError(err) + defer cluster.Destroy() + + re.NoError(cluster.RunInitialServers()) + re.NotEmpty(cluster.WaitLeader()) + leader := cluster.GetLeaderServer() + re.NoError(leader.BootstrapCluster()) + re.True(cluster.WaitRegionSyncerClientsReady(2)) + + follower := cluster.GetServer(cluster.GetFollower()) + re.NotNil(follower) + regions := tests.InitRegions(3) + for _, region := range regions { + re.NoError(leader.GetRaftCluster().HandleRegionHeartbeat(region)) + } + waitFollowerRegions(re, follower, regions) + + t.Run("requires opt-in header for follower local reset", func(t *testing.T) { + re := require.New(t) + staleRegion := newFollowerStaleRegion(999) + injectFollowerStaleRegion(re, follower, staleRegion) + syncer := follower.GetServer().DirectlyGetRaftCluster().GetRegionSyncer() + syncer.StopSyncWithLeader() + re.False(syncer.IsRunning()) + + path := fmt.Sprintf("/pd/api/v1/admin/cache/region/%d", staleRegion.GetID()) + statusCode, body, err := requestDeleteFollowerRegionCacheWithHeaders(follower, path, http.Header{}) + re.NoError(err) + re.Equal(http.StatusOK, statusCode, body) + re.Contains(body, "server cache") + re.False(syncer.IsRunning()) + assertFollowerRegionStored(re, follower, staleRegion.GetID()) + + deleteFollowerRegionCache(re, follower, path) + testutil.Eventually(re, syncer.IsRunning) + waitFollowerRegions(re, follower, regions) + assertFollowerRegionNotStored(re, follower, staleRegion.GetID()) + }) + + t.Run("deletes one stale follower region", func(t *testing.T) { + re := require.New(t) + staleRegion := newFollowerStaleRegion(1001) + injectFollowerStaleRegion(re, follower, staleRegion) + deleteFollowerRegionCache(re, follower, fmt.Sprintf("/pd/api/v1/admin/cache/region/%d", staleRegion.GetID())) + waitFollowerRegions(re, follower, regions) + assertFollowerRegionNotStored(re, follower, staleRegion.GetID()) + }) + + t.Run("deletes one storage-only stale follower region", func(t *testing.T) { + re := require.New(t) + storageOnlyRegion := newFollowerStaleRegion(1009) + injectFollowerStaleRegionStorageOnly(re, follower, storageOnlyRegion) + deleteFollowerRegionCache(re, follower, fmt.Sprintf("/pd/api/v1/admin/cache/region/%d", storageOnlyRegion.GetID())) + waitFollowerRegions(re, follower, regions) + assertFollowerRegionNotStored(re, follower, storageOnlyRegion.GetID()) + }) + + t.Run("deletes all stale follower regions", func(t *testing.T) { + re := require.New(t) + staleRegion := newFollowerStaleRegion(1019) + injectFollowerStaleRegion(re, follower, staleRegion) + storageOnlyRegion := newFollowerStaleRegion(1029) + injectFollowerStaleRegionStorageOnly(re, follower, storageOnlyRegion) + deleteFollowerRegionCache(re, follower, "/pd/api/v1/admin/cache/regions") + waitFollowerRegions(re, follower, regions) + assertFollowerRegionNotStored(re, follower, staleRegion.GetID()) + assertFollowerRegionNotStored(re, follower, storageOnlyRegion.GetID()) + assertFollowerRegionStored(re, follower, regions[len(regions)-1].GetID()) + }) + + t.Run("serializes concurrent all-region reset requests", func(t *testing.T) { + re := require.New(t) + staleRegion := newFollowerStaleRegion(2009) + injectFollowerStaleRegion(re, follower, staleRegion) + deleteFollowerRegionCacheConcurrently(re, follower, "/pd/api/v1/admin/cache/regions") + waitFollowerRegions(re, follower, regions) + assertFollowerRegionNotStored(re, follower, staleRegion.GetID()) + }) +} + +func newFollowerStaleRegion(regionID uint64) *core.RegionInfo { + region := &metapb.Region{ + Id: regionID, + StartKey: []byte{2}, + EndKey: []byte{}, + RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}, + Peers: []*metapb.Peer{ + {Id: regionID*10 + 1, StoreId: 1}, + {Id: regionID*10 + 2, StoreId: 2}, + {Id: regionID*10 + 3, StoreId: 3}, + }, + } + return core.NewRegionInfo(region, region.GetPeers()[0], core.SetSource(core.Storage)) +} + +func injectFollowerStaleRegion(re *require.Assertions, follower *tests.TestServer, region *core.RegionInfo) { + overlaps := follower.GetServer().GetBasicCluster().PutRegion(region) + re.NotEmpty(overlaps) + injectFollowerStaleRegionStorageOnly(re, follower, region) + re.NotNil(follower.GetServer().GetBasicCluster().GetRegion(region.GetID())) +} + +func injectFollowerStaleRegionStorageOnly(re *require.Assertions, follower *tests.TestServer, region *core.RegionInfo) { + re.NoError(follower.GetServer().GetStorage().SaveRegion(region.GetMeta())) + re.NoError(follower.GetServer().GetStorage().Flush()) +} + +func deleteFollowerRegionCache(re *require.Assertions, follower *tests.TestServer, path string) { + statusCode, body, err := requestDeleteFollowerRegionCache(follower, path) + re.NoError(err) + re.Equal(http.StatusOK, statusCode, body) +} + +func deleteFollowerRegionCacheConcurrently(re *require.Assertions, follower *tests.TestServer, path string) { + const requestCount = 2 + var wg sync.WaitGroup + startCh := make(chan struct{}) + errCh := make(chan error, requestCount) + for range requestCount { + wg.Add(1) + go func() { + defer wg.Done() + <-startCh + statusCode, body, err := requestDeleteFollowerRegionCache(follower, path) + if err != nil { + errCh <- err + return + } + if statusCode != http.StatusOK { + errCh <- fmt.Errorf("unexpected status %d: %s", statusCode, body) + } + }() + } + close(startCh) + wg.Wait() + close(errCh) + for err := range errCh { + re.NoError(err) + } +} + +func requestDeleteFollowerRegionCache(follower *tests.TestServer, path string) (int, string, error) { + return requestDeleteFollowerRegionCacheWithHeaders(follower, path, http.Header{ + apiutil.PDAllowFollowerHandleHeader: {"true"}, + }) +} + +func requestDeleteFollowerRegionCacheWithHeaders(follower *tests.TestServer, path string, headers http.Header) (int, string, error) { + req, err := http.NewRequest(http.MethodDelete, follower.GetAddr()+path, http.NoBody) + if err != nil { + return 0, "", err + } + for key, values := range headers { + for _, value := range values { + req.Header.Add(key, value) + } + } + resp, err := tests.TestDialClient.Do(req) + if err != nil { + return 0, "", err + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return 0, "", err + } + return resp.StatusCode, string(body), nil +} + +func waitFollowerRegions(re *require.Assertions, follower *tests.TestServer, regions []*core.RegionInfo) { + testutil.Eventually(re, func() bool { + bc := follower.GetServer().GetBasicCluster() + if len(bc.GetRegions()) != len(regions) { + return false + } + for _, region := range regions { + if bc.GetRegion(region.GetID()) == nil { + return false + } + } + return true + }, testutil.WithWaitFor(10*time.Second), testutil.WithTickInterval(50*time.Millisecond)) +} + +func assertFollowerRegionNotStored(re *require.Assertions, follower *tests.TestServer, regionID uint64) { + re.NoError(follower.GetServer().GetStorage().Flush()) + region := &metapb.Region{} + ok, err := follower.GetServer().GetStorage().LoadRegion(regionID, region) + re.NoError(err) + re.False(ok) +} + +func assertFollowerRegionStored(re *require.Assertions, follower *tests.TestServer, regionID uint64) { + re.NoError(follower.GetServer().GetStorage().Flush()) + region := &metapb.Region{} + ok, err := follower.GetServer().GetStorage().LoadRegion(regionID, region) + re.NoError(err) + re.True(ok) +} + func mustRequestSuccess(re *require.Assertions, s *server.Server) http.Header { resp, err := tests.TestDialClient.Get(s.GetAddr() + "/pd/api/v1/version") re.NoError(err)