Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions pkg/mcs/scheduling/server/apis/v1/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/pingcap/kvproto/pkg/metapb"
"github.com/pingcap/log"

pdcore "github.com/tikv/pd/pkg/core"
"github.com/tikv/pd/pkg/errs"
scheserver "github.com/tikv/pd/pkg/mcs/scheduling/server"
"github.com/tikv/pd/pkg/mcs/scheduling/server/config"
Expand Down Expand Up @@ -1481,13 +1482,17 @@ func checkRegionsReplicated(c *gin.Context) {
// @Router /stores/{id} [get]
func getStoreByID(c *gin.Context) {
svr := c.MustGet(multiservicesapi.ServiceContextKey).(*scheserver.Server)
basicCluster, ok := getBasicCluster(c, svr)
if !ok {
return
}
idStr := c.Param("id")
storeID, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
c.String(http.StatusBadRequest, err.Error())
return
}
store := svr.GetBasicCluster().GetStore(storeID)
store := basicCluster.GetStore(storeID)
if store == nil {
c.String(http.StatusNotFound, errs.ErrStoreNotFound.FastGenByArgs(storeID).Error())
return
Expand All @@ -1505,14 +1510,18 @@ func getStoreByID(c *gin.Context) {
// @Router /stores [get]
func getAllStores(c *gin.Context) {
svr := c.MustGet(multiservicesapi.ServiceContextKey).(*scheserver.Server)
stores := svr.GetBasicCluster().GetMetaStores()
basicCluster, ok := getBasicCluster(c, svr)
if !ok {
return
}
stores := basicCluster.GetMetaStores()
StoresInfo := &response.StoresInfo{
Stores: make([]*response.StoreInfo, 0, len(stores)),
}

for _, s := range stores {
storeID := s.GetId()
store := svr.GetBasicCluster().GetStore(storeID)
store := basicCluster.GetStore(storeID)
if store == nil {
c.String(http.StatusInternalServerError, errs.ErrStoreNotFound.FastGenByArgs(storeID).Error())
return
Expand All @@ -1534,7 +1543,11 @@ func getAllStores(c *gin.Context) {
// @Router /regions [get]
func getAllRegions(c *gin.Context) {
svr := c.MustGet(multiservicesapi.ServiceContextKey).(*scheserver.Server)
regions := svr.GetBasicCluster().GetRegions()
basicCluster, ok := getBasicCluster(c, svr)
if !ok {
return
}
regions := basicCluster.GetRegions()
b, err := response.MarshalRegionsInfoJSON(c.Request.Context(), regions)
if err != nil {
c.String(http.StatusInternalServerError, err.Error())
Expand All @@ -1550,7 +1563,11 @@ func getAllRegions(c *gin.Context) {
// @Router /regions/count [get]
func getRegionCount(c *gin.Context) {
svr := c.MustGet(multiservicesapi.ServiceContextKey).(*scheserver.Server)
count := svr.GetBasicCluster().GetTotalRegionCount()
basicCluster, ok := getBasicCluster(c, svr)
if !ok {
return
}
count := basicCluster.GetTotalRegionCount()
c.IndentedJSON(http.StatusOK, &response.RegionsInfo{Count: count})
}

Expand All @@ -1563,6 +1580,10 @@ func getRegionCount(c *gin.Context) {
// @Router /regions/{id} [get]
func getRegionByID(c *gin.Context) {
svr := c.MustGet(multiservicesapi.ServiceContextKey).(*scheserver.Server)
basicCluster, ok := getBasicCluster(c, svr)
if !ok {
return
}
idStr := c.Param("id")
regionID, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
Expand All @@ -1573,7 +1594,7 @@ func getRegionByID(c *gin.Context) {
c.String(http.StatusBadRequest, errs.ErrRegionInvalidID.FastGenByArgs().Error())
return
}
regionInfo := svr.GetBasicCluster().GetRegion(regionID)
regionInfo := basicCluster.GetRegion(regionID)
if regionInfo == nil {
c.String(http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs(regionID).Error())
return
Expand Down Expand Up @@ -1633,6 +1654,15 @@ func getAffinityManager(c *gin.Context) (*affinity.Manager, bool) {
return manager, true
}

func getBasicCluster(c *gin.Context, svr *scheserver.Server) (*pdcore.BasicCluster, bool) {
basicCluster := svr.GetBasicCluster()
Comment thread
lhy1024 marked this conversation as resolved.
if basicCluster == nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, errs.ErrNotBootstrapped.GenWithStackByArgs().Error())
return nil, false
}
return basicCluster, true
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// @Tags affinity-groups
// @Summary List all affinity groups.
// @Param ids query []string false "Optional affinity group IDs. Repeat as ids=a&ids=b."
Expand Down
48 changes: 48 additions & 0 deletions pkg/mcs/scheduling/server/apis/v1/api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2026 TiKV Project Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package apis

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"

scheserver "github.com/tikv/pd/pkg/mcs/scheduling/server"
"github.com/tikv/pd/pkg/utils/apiutil/multiservicesapi"
"github.com/tikv/pd/pkg/utils/testutil"
)

func TestMain(m *testing.M) {
goleak.VerifyTestMain(m, testutil.LeakOptions...)
}

func TestGetAllStoresReturnsNotBootstrappedWhenBasicClusterMissing(t *testing.T) {
gin.SetMode(gin.TestMode)
re := require.New(t)

resp := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(resp)
ctx.Request = httptest.NewRequest(http.MethodGet, "/stores", nil)
ctx.Set(multiservicesapi.ServiceContextKey, &scheserver.Server{})

getAllStores(ctx)

re.Equal(http.StatusInternalServerError, resp.Code)
re.Contains(resp.Body.String(), "not bootstrapped")
}
Comment thread
bufferflies marked this conversation as resolved.
78 changes: 77 additions & 1 deletion pkg/mcs/scheduling/server/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ import (
"github.com/tikv/pd/pkg/cluster"
"github.com/tikv/pd/pkg/core"
"github.com/tikv/pd/pkg/errs"
mcsaffinity "github.com/tikv/pd/pkg/mcs/scheduling/server/affinity"
"github.com/tikv/pd/pkg/mcs/scheduling/server/config"
"github.com/tikv/pd/pkg/mcs/scheduling/server/meta"
"github.com/tikv/pd/pkg/mcs/scheduling/server/rule"
"github.com/tikv/pd/pkg/ratelimit"
"github.com/tikv/pd/pkg/response"
"github.com/tikv/pd/pkg/schedule"
Expand Down Expand Up @@ -76,7 +79,13 @@ type Cluster struct {
regionStats *statistics.RegionStatistics
labelStats *statistics.LabelStatistics
hotStat *statistics.HotStat
resourceMu sync.RWMutex
storage storage.Storage
hbStreams *hbstream.HeartbeatStreams
metaWatcher *meta.Watcher
configWatcher *config.Watcher
ruleWatcher *rule.Watcher
affinityWatcher *mcsaffinity.Watcher
coordinator *schedule.Coordinator
checkMembershipCh chan struct{}
pdLeader atomic.Value
Expand Down Expand Up @@ -142,6 +151,7 @@ func NewCluster(
labelStats: statistics.NewLabelStatistics(),
regionStats: statistics.NewRegionStatistics(basicCluster, persistConfig, ruleManager),
storage: storage,
hbStreams: hbStreams,
checkMembershipCh: checkMembershipCh,
httpClient: httpClient,
backendAddress: backendAddress,
Expand Down Expand Up @@ -262,9 +272,75 @@ func (c *Cluster) BucketsStats(degree int, regionIDs ...uint64) map[uint64][]*bu

// GetStorage returns the storage.
func (c *Cluster) GetStorage() storage.Storage {
if c == nil {
return nil
}
c.resourceMu.RLock()
defer c.resourceMu.RUnlock()
return c.storage
}

// GetHeartbeatStreams returns the heartbeat streams.
func (c *Cluster) GetHeartbeatStreams() *hbstream.HeartbeatStreams {
if c == nil {
return nil
}
c.resourceMu.RLock()
defer c.resourceMu.RUnlock()
return c.hbStreams
}

// GetMetaWatcher returns the meta watcher.
func (c *Cluster) GetMetaWatcher() *meta.Watcher {
if c == nil {
return nil
}
c.resourceMu.RLock()
defer c.resourceMu.RUnlock()
return c.metaWatcher
}

// SetRuntimeResources installs the cluster-scoped runtime resources after they are created.
func (c *Cluster) SetRuntimeResources(
metaWatcher *meta.Watcher,
configWatcher *config.Watcher,
ruleWatcher *rule.Watcher,
affinityWatcher *mcsaffinity.Watcher,
) {
c.resourceMu.Lock()
defer c.resourceMu.Unlock()
c.metaWatcher = metaWatcher
c.configWatcher = configWatcher
c.ruleWatcher = ruleWatcher
c.affinityWatcher = affinityWatcher
}

func (c *Cluster) cleanupRuntimeResources() {
c.resourceMu.Lock()
defer c.resourceMu.Unlock()
if c.affinityWatcher != nil {
c.affinityWatcher.Close()
c.affinityWatcher = nil
}
if c.ruleWatcher != nil {
c.ruleWatcher.Close()
c.ruleWatcher = nil
}
if c.metaWatcher != nil {
c.metaWatcher.Close()
c.metaWatcher = nil
}
if c.configWatcher != nil {
c.configWatcher.Close()
c.configWatcher = nil
}
if c.hbStreams != nil {
c.hbStreams.Close()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Closing hbStreams synchronously here does wait for an in-flight stream.Send to return. Close only cancels the context and then waits on wg; if hbstream.run has already entered the send/keepalive branch, it will not observe ctx.Done until Send returns. In production heartbeatServer.Send has a 5s timeout, so this is usually not a permanent block, but it can delay primary stepdown/transfer on slow streams.

Could we check context cancellation inside the send/keepalive path, or make the close path bounded, so primary exit is not amplified by slow streams?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type blockingHeartbeatStream struct {
        once    sync.Once
        started chan struct{}
        release chan struct{}
  }

  func (s *blockingHeartbeatStream) Send(core.RegionHeartbeatResponse) error {
        s.once.Do(func() { close(s.started) })
        <-s.release
        return nil
  }

  func TestCloseDoesNotBlockOnInFlightSend(t *testing.T) {
        re := require.New(t)
        ctx, cancel := context.WithCancel(context.Background())
        defer cancel()

        basicCluster := core.NewBasicCluster()
        basicCluster.PutStore(core.NewStoreInfo(&metapb.Store{Id: 1, Address: "store-1"}))
        streams := NewHeartbeatStreams(ctx, "", basicCluster)
        stream := &blockingHeartbeatStream{started: make(chan struct{}), release: make(chan struct{})}
        streams.BindStream(1, stream)

        leader := &metapb.Peer{Id: 1, StoreId: 1}
        region := core.NewRegionInfo(&metapb.Region{Id: 1, Peers: []*metapb.Peer{leader}}, leader)

        for {
                streams.SendMsg(region, &Operation{})
                select {
                case <-stream.started:
                        goto closeStream
                case <-time.After(time.Second):
                        re.FailNow("expected heartbeat stream send to start")
                default:
                        time.Sleep(time.Millisecond)
                }
        }

  closeStream:
        closeDone := make(chan struct{})
        go func() {
                streams.Close()
                close(closeDone)
        }()
select {
        case <-closeDone:
        case <-time.After(50 * time.Millisecond):
                close(stream.release)
                <-closeDone
                re.FailNow("HeartbeatStreams.Close should not wait for an in-flight stream Send after cancellation")
        }
  }

c.hbStreams = nil
}
c.storage = nil
}

// GetCheckerConfig returns the checker config.
func (c *Cluster) GetCheckerConfig() sc.CheckerConfigProvider { return c.persistConfig }

Expand Down Expand Up @@ -653,7 +729,7 @@ func (c *Cluster) StartBackgroundJobs() {
c.running.Store(true)
}

// StopBackgroundJobs stops background jobs.
// StopBackgroundJobs stops background jobs, these jobs is created by NewCluster.
func (c *Cluster) StopBackgroundJobs() {
if !c.running.Load() {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On startup rollback, NewCluster may have already started goroutines while StartBackgroundJobs was never called, so running is false and this returns without canceling c.ctx. Please add a cleanup path that cancels the cluster context even before background jobs start.

return
Expand Down
17 changes: 13 additions & 4 deletions pkg/mcs/scheduling/server/grpc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/tikv/pd/pkg/core"
"github.com/tikv/pd/pkg/errs"
"github.com/tikv/pd/pkg/mcs/registry"
"github.com/tikv/pd/pkg/mcs/scheduling/server/meta"
"github.com/tikv/pd/pkg/schedule/hbstream"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/pkg/utils/keypath"
Expand Down Expand Up @@ -138,7 +139,11 @@ func (s *Service) RegionHeartbeat(stream schedulingpb.Scheduling_RegionHeartbeat
}

c := s.GetCluster()
if c == nil {
var streams *hbstream.HeartbeatStreams
if c != nil {
streams = c.GetHeartbeatStreams()
}
if c == nil || streams == nil {
resp := &schedulingpb.RegionHeartbeatResponse{Header: notBootstrappedHeader()}
err := server.Send(resp)
return errors.WithStack(err)
Expand All @@ -154,7 +159,7 @@ func (s *Service) RegionHeartbeat(stream schedulingpb.Scheduling_RegionHeartbeat
storeLabel := strconv.FormatUint(storeID, 10)

if time.Since(lastBind) > time.Minute {
s.hbStreams.BindStream(storeID, server)
streams.BindStream(storeID, server)
lastBind = time.Now()
}

Expand Down Expand Up @@ -244,13 +249,17 @@ func (s *Service) RegionBuckets(stream schedulingpb.Scheduling_RegionBucketsServ
// StoreHeartbeat implements gRPC SchedulingServer.
func (s *Service) StoreHeartbeat(_ context.Context, request *schedulingpb.StoreHeartbeatRequest) (*schedulingpb.StoreHeartbeatResponse, error) {
c := s.GetCluster()
if c == nil {
var metaWatcher *meta.Watcher
if c != nil {
metaWatcher = c.GetMetaWatcher()
}
if c == nil || metaWatcher == nil {
return &schedulingpb.StoreHeartbeatResponse{Header: notBootstrappedHeader()}, nil
}

start := time.Now()
if c.GetStore(request.GetStats().GetStoreId()) == nil {
s.metaWatcher.GetStoreWatcher().ForceLoad()
metaWatcher.GetStoreWatcher().ForceLoad()
}

storeID := request.GetStats().GetStoreId()
Expand Down
3 changes: 3 additions & 0 deletions pkg/mcs/scheduling/server/rule/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,4 +281,7 @@ func (rw *Watcher) initializeRegionLabelWatcher() error {
func (rw *Watcher) Close() {
rw.cancel()
rw.wg.Wait()
if rw.checkerController != nil {
rw.checkerController.ClearSuspectKeyRanges()
}
Comment on lines +283 to +285
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This cleanup order can still leave stale suspect key ranges. Close cancels the watcher and immediately calls ClearSuspectKeyRanges, but an already-running watcher callback is not preempted by cancellation and may still run postEventsFn/AddSuspectKeyRange afterward, re-adding suspect ranges after they were cleared.

Could we move ClearSuspectKeyRanges after rw.wg.Wait(), so cleanup happens after all watcher callbacks have exited?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

func TestCloseClearsSuspectKeyRangesAfterInFlightCallback(t *testing.T) {
        re := require.New(t)

        checkerCtx, checkerCancel := context.WithCancel(context.Background())
        defer checkerCancel()
        cluster := mockcluster.NewCluster(checkerCtx, mockconfig.NewTestOptions())
        checkerController := checker.NewController(checkerCtx, cluster, cluster.GetCheckerConfig(), nil)

        watcherCtx, watcherCancel := context.WithCancel(context.Background())
        rw := &Watcher{
                ctx:               watcherCtx,
                cancel:            watcherCancel,
                checkerController: checkerController,
        }

        callbackStarted := make(chan struct{})
        releaseCallback := make(chan struct{})
        rw.wg.Add(1)
        go func() {
                defer rw.wg.Done()
                <-watcherCtx.Done()
                close(callbackStarted)
                <-releaseCallback
                checkerController.AddSuspectKeyRange([]byte("a"), []byte("z"))
        }()

        closeDone := make(chan struct{})
        go func() {
                rw.Close()
                close(closeDone)
        }()

        select {
        case <-callbackStarted:
        case <-time.After(time.Second):
                re.FailNow("expected in-flight watcher callback to observe cancellation")
        }
        select {
        case <-closeDone:
                re.FailNow("Close returned before the in-flight callback finished")
        case <-time.After(50 * time.Millisecond):
        }

        close(releaseCallback)
        select {
        case <-closeDone:
        case <-time.After(time.Second):
                re.FailNow("Close did not return after the in-flight callback finished")
        }

        _, ok := checkerController.PopOneSuspectKeyRange()
        re.False(ok, "Close should not leave suspect key ranges added by in-flight callbacks")
  }

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about this

 func (rw *Watcher) Close() {
      rw.cancel()
      rw.wg.Wait()
      if rw.checkerController != nil {
          rw.checkerController.ClearSuspectKeyRanges()
      }
  }

}
Loading
Loading