diff --git a/pkg/mcs/scheduling/server/apis/v1/api.go b/pkg/mcs/scheduling/server/apis/v1/api.go index f972e36098d..c8d58e5092f 100644 --- a/pkg/mcs/scheduling/server/apis/v1/api.go +++ b/pkg/mcs/scheduling/server/apis/v1/api.go @@ -34,6 +34,7 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/log" + pdcore "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/errs" scheserver "github.com/tikv/pd/pkg/mcs/scheduling/server" "github.com/tikv/pd/pkg/mcs/scheduling/server/config" @@ -1481,13 +1482,17 @@ func checkRegionsReplicated(c *gin.Context) { // @Router /stores/{id} [get] func getStoreByID(c *gin.Context) { svr := c.MustGet(multiservicesapi.ServiceContextKey).(*scheserver.Server) + basicCluster, ok := getBasicCluster(c, svr) + if !ok { + return + } idStr := c.Param("id") storeID, err := strconv.ParseUint(idStr, 10, 64) if err != nil { c.String(http.StatusBadRequest, err.Error()) return } - store := svr.GetBasicCluster().GetStore(storeID) + store := basicCluster.GetStore(storeID) if store == nil { c.String(http.StatusNotFound, errs.ErrStoreNotFound.FastGenByArgs(storeID).Error()) return @@ -1505,14 +1510,18 @@ func getStoreByID(c *gin.Context) { // @Router /stores [get] func getAllStores(c *gin.Context) { svr := c.MustGet(multiservicesapi.ServiceContextKey).(*scheserver.Server) - stores := svr.GetBasicCluster().GetMetaStores() + basicCluster, ok := getBasicCluster(c, svr) + if !ok { + return + } + stores := basicCluster.GetMetaStores() StoresInfo := &response.StoresInfo{ Stores: make([]*response.StoreInfo, 0, len(stores)), } for _, s := range stores { storeID := s.GetId() - store := svr.GetBasicCluster().GetStore(storeID) + store := basicCluster.GetStore(storeID) if store == nil { c.String(http.StatusInternalServerError, errs.ErrStoreNotFound.FastGenByArgs(storeID).Error()) return @@ -1534,7 +1543,11 @@ func getAllStores(c *gin.Context) { // @Router /regions [get] func getAllRegions(c *gin.Context) { svr := c.MustGet(multiservicesapi.ServiceContextKey).(*scheserver.Server) - regions := svr.GetBasicCluster().GetRegions() + basicCluster, ok := getBasicCluster(c, svr) + if !ok { + return + } + regions := basicCluster.GetRegions() b, err := response.MarshalRegionsInfoJSON(c.Request.Context(), regions) if err != nil { c.String(http.StatusInternalServerError, err.Error()) @@ -1550,7 +1563,11 @@ func getAllRegions(c *gin.Context) { // @Router /regions/count [get] func getRegionCount(c *gin.Context) { svr := c.MustGet(multiservicesapi.ServiceContextKey).(*scheserver.Server) - count := svr.GetBasicCluster().GetTotalRegionCount() + basicCluster, ok := getBasicCluster(c, svr) + if !ok { + return + } + count := basicCluster.GetTotalRegionCount() c.IndentedJSON(http.StatusOK, &response.RegionsInfo{Count: count}) } @@ -1563,6 +1580,10 @@ func getRegionCount(c *gin.Context) { // @Router /regions/{id} [get] func getRegionByID(c *gin.Context) { svr := c.MustGet(multiservicesapi.ServiceContextKey).(*scheserver.Server) + basicCluster, ok := getBasicCluster(c, svr) + if !ok { + return + } idStr := c.Param("id") regionID, err := strconv.ParseUint(idStr, 10, 64) if err != nil { @@ -1573,7 +1594,7 @@ func getRegionByID(c *gin.Context) { c.String(http.StatusBadRequest, errs.ErrRegionInvalidID.FastGenByArgs().Error()) return } - regionInfo := svr.GetBasicCluster().GetRegion(regionID) + regionInfo := basicCluster.GetRegion(regionID) if regionInfo == nil { c.String(http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs(regionID).Error()) return @@ -1633,6 +1654,15 @@ func getAffinityManager(c *gin.Context) (*affinity.Manager, bool) { return manager, true } +func getBasicCluster(c *gin.Context, svr *scheserver.Server) (*pdcore.BasicCluster, bool) { + basicCluster := svr.GetBasicCluster() + if basicCluster == nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, errs.ErrNotBootstrapped.GenWithStackByArgs().Error()) + return nil, false + } + return basicCluster, true +} + // @Tags affinity-groups // @Summary List all affinity groups. // @Param ids query []string false "Optional affinity group IDs. Repeat as ids=a&ids=b." diff --git a/pkg/mcs/scheduling/server/apis/v1/api_test.go b/pkg/mcs/scheduling/server/apis/v1/api_test.go new file mode 100644 index 00000000000..c1c917ace17 --- /dev/null +++ b/pkg/mcs/scheduling/server/apis/v1/api_test.go @@ -0,0 +1,48 @@ +// Copyright 2026 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package apis + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + scheserver "github.com/tikv/pd/pkg/mcs/scheduling/server" + "github.com/tikv/pd/pkg/utils/apiutil/multiservicesapi" + "github.com/tikv/pd/pkg/utils/testutil" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m, testutil.LeakOptions...) +} + +func TestGetAllStoresReturnsNotBootstrappedWhenBasicClusterMissing(t *testing.T) { + gin.SetMode(gin.TestMode) + re := require.New(t) + + resp := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(resp) + ctx.Request = httptest.NewRequest(http.MethodGet, "/stores", nil) + ctx.Set(multiservicesapi.ServiceContextKey, &scheserver.Server{}) + + getAllStores(ctx) + + re.Equal(http.StatusInternalServerError, resp.Code) + re.Contains(resp.Body.String(), "not bootstrapped") +} diff --git a/pkg/mcs/scheduling/server/cluster.go b/pkg/mcs/scheduling/server/cluster.go index 9f4238898ce..f1437cc8798 100644 --- a/pkg/mcs/scheduling/server/cluster.go +++ b/pkg/mcs/scheduling/server/cluster.go @@ -38,7 +38,10 @@ import ( "github.com/tikv/pd/pkg/cluster" "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/errs" + mcsaffinity "github.com/tikv/pd/pkg/mcs/scheduling/server/affinity" "github.com/tikv/pd/pkg/mcs/scheduling/server/config" + "github.com/tikv/pd/pkg/mcs/scheduling/server/meta" + "github.com/tikv/pd/pkg/mcs/scheduling/server/rule" "github.com/tikv/pd/pkg/ratelimit" "github.com/tikv/pd/pkg/response" "github.com/tikv/pd/pkg/schedule" @@ -76,7 +79,13 @@ type Cluster struct { regionStats *statistics.RegionStatistics labelStats *statistics.LabelStatistics hotStat *statistics.HotStat + resourceMu sync.RWMutex storage storage.Storage + hbStreams *hbstream.HeartbeatStreams + metaWatcher *meta.Watcher + configWatcher *config.Watcher + ruleWatcher *rule.Watcher + affinityWatcher *mcsaffinity.Watcher coordinator *schedule.Coordinator checkMembershipCh chan struct{} pdLeader atomic.Value @@ -142,6 +151,7 @@ func NewCluster( labelStats: statistics.NewLabelStatistics(), regionStats: statistics.NewRegionStatistics(basicCluster, persistConfig, ruleManager), storage: storage, + hbStreams: hbStreams, checkMembershipCh: checkMembershipCh, httpClient: httpClient, backendAddress: backendAddress, @@ -262,9 +272,75 @@ func (c *Cluster) BucketsStats(degree int, regionIDs ...uint64) map[uint64][]*bu // GetStorage returns the storage. func (c *Cluster) GetStorage() storage.Storage { + if c == nil { + return nil + } + c.resourceMu.RLock() + defer c.resourceMu.RUnlock() return c.storage } +// GetHeartbeatStreams returns the heartbeat streams. +func (c *Cluster) GetHeartbeatStreams() *hbstream.HeartbeatStreams { + if c == nil { + return nil + } + c.resourceMu.RLock() + defer c.resourceMu.RUnlock() + return c.hbStreams +} + +// GetMetaWatcher returns the meta watcher. +func (c *Cluster) GetMetaWatcher() *meta.Watcher { + if c == nil { + return nil + } + c.resourceMu.RLock() + defer c.resourceMu.RUnlock() + return c.metaWatcher +} + +// SetRuntimeResources installs the cluster-scoped runtime resources after they are created. +func (c *Cluster) SetRuntimeResources( + metaWatcher *meta.Watcher, + configWatcher *config.Watcher, + ruleWatcher *rule.Watcher, + affinityWatcher *mcsaffinity.Watcher, +) { + c.resourceMu.Lock() + defer c.resourceMu.Unlock() + c.metaWatcher = metaWatcher + c.configWatcher = configWatcher + c.ruleWatcher = ruleWatcher + c.affinityWatcher = affinityWatcher +} + +func (c *Cluster) cleanupRuntimeResources() { + c.resourceMu.Lock() + defer c.resourceMu.Unlock() + if c.affinityWatcher != nil { + c.affinityWatcher.Close() + c.affinityWatcher = nil + } + if c.ruleWatcher != nil { + c.ruleWatcher.Close() + c.ruleWatcher = nil + } + if c.metaWatcher != nil { + c.metaWatcher.Close() + c.metaWatcher = nil + } + if c.configWatcher != nil { + c.configWatcher.Close() + c.configWatcher = nil + } + if c.hbStreams != nil { + c.hbStreams.Close() + c.hbStreams = nil + } + c.storage = nil +} + // GetCheckerConfig returns the checker config. func (c *Cluster) GetCheckerConfig() sc.CheckerConfigProvider { return c.persistConfig } @@ -653,7 +729,7 @@ func (c *Cluster) StartBackgroundJobs() { c.running.Store(true) } -// StopBackgroundJobs stops background jobs. +// StopBackgroundJobs stops background jobs, these jobs is created by NewCluster. func (c *Cluster) StopBackgroundJobs() { if !c.running.Load() { return diff --git a/pkg/mcs/scheduling/server/grpc_service.go b/pkg/mcs/scheduling/server/grpc_service.go index 90fb9b927d1..89636527eff 100644 --- a/pkg/mcs/scheduling/server/grpc_service.go +++ b/pkg/mcs/scheduling/server/grpc_service.go @@ -34,6 +34,7 @@ import ( "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/mcs/registry" + "github.com/tikv/pd/pkg/mcs/scheduling/server/meta" "github.com/tikv/pd/pkg/schedule/hbstream" "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/pkg/utils/keypath" @@ -138,7 +139,11 @@ func (s *Service) RegionHeartbeat(stream schedulingpb.Scheduling_RegionHeartbeat } c := s.GetCluster() - if c == nil { + var streams *hbstream.HeartbeatStreams + if c != nil { + streams = c.GetHeartbeatStreams() + } + if c == nil || streams == nil { resp := &schedulingpb.RegionHeartbeatResponse{Header: notBootstrappedHeader()} err := server.Send(resp) return errors.WithStack(err) @@ -154,7 +159,7 @@ func (s *Service) RegionHeartbeat(stream schedulingpb.Scheduling_RegionHeartbeat storeLabel := strconv.FormatUint(storeID, 10) if time.Since(lastBind) > time.Minute { - s.hbStreams.BindStream(storeID, server) + streams.BindStream(storeID, server) lastBind = time.Now() } @@ -244,13 +249,17 @@ func (s *Service) RegionBuckets(stream schedulingpb.Scheduling_RegionBucketsServ // StoreHeartbeat implements gRPC SchedulingServer. func (s *Service) StoreHeartbeat(_ context.Context, request *schedulingpb.StoreHeartbeatRequest) (*schedulingpb.StoreHeartbeatResponse, error) { c := s.GetCluster() - if c == nil { + var metaWatcher *meta.Watcher + if c != nil { + metaWatcher = c.GetMetaWatcher() + } + if c == nil || metaWatcher == nil { return &schedulingpb.StoreHeartbeatResponse{Header: notBootstrappedHeader()}, nil } start := time.Now() if c.GetStore(request.GetStats().GetStoreId()) == nil { - s.metaWatcher.GetStoreWatcher().ForceLoad() + metaWatcher.GetStoreWatcher().ForceLoad() } storeID := request.GetStats().GetStoreId() diff --git a/pkg/mcs/scheduling/server/rule/watcher.go b/pkg/mcs/scheduling/server/rule/watcher.go index e770d5ff51d..cf861f084fc 100644 --- a/pkg/mcs/scheduling/server/rule/watcher.go +++ b/pkg/mcs/scheduling/server/rule/watcher.go @@ -281,4 +281,7 @@ func (rw *Watcher) initializeRegionLabelWatcher() error { func (rw *Watcher) Close() { rw.cancel() rw.wg.Wait() + if rw.checkerController != nil { + rw.checkerController.ClearSuspectKeyRanges() + } } diff --git a/pkg/mcs/scheduling/server/server.go b/pkg/mcs/scheduling/server/server.go index 529f49f83b2..a84ffb9d729 100644 --- a/pkg/mcs/scheduling/server/server.go +++ b/pkg/mcs/scheduling/server/server.go @@ -92,7 +92,6 @@ type Server struct { cfg *config.Config persistConfig *config.PersistConfig - basicCluster *core.BasicCluster // for the primary election of scheduling participant *member.Participant @@ -108,15 +107,7 @@ type Server struct { serviceID *discovery.ServiceRegistryEntry serviceRegister *discovery.ServiceRegister - cluster atomic.Value // *Cluster - hbStreams *hbstream.HeartbeatStreams - storage *endpoint.StorageEndpoint - - // for watching the PD meta info updates that are related to the scheduling. - configWatcher *config.Watcher - ruleWatcher *rule.Watcher - metaWatcher *meta.Watcher - affinityWatcher *affinity.Watcher + cluster atomic.Value // *Cluster // Cgroup Monitor cgMonitor cgroup.Monitor @@ -426,7 +417,10 @@ func (s *Server) GetCluster() *Cluster { // GetBasicCluster returns the basic cluster. func (s *Server) GetBasicCluster() *core.BasicCluster { - return s.basicCluster + if cluster := s.GetCluster(); cluster != nil { + return cluster.GetBasicCluster() + } + return nil } // GetCoordinator returns the coordinator. @@ -506,70 +500,95 @@ func (s *Server) startServer() (err error) { return nil } -func (s *Server) startCluster(context.Context) error { - s.basicCluster = core.NewBasicCluster() - s.storage = endpoint.NewStorageEndpoint(kv.NewMemoryKV(), nil) - err := s.startMetaConfWatcher() +func (s *Server) startCluster(ctx context.Context) error { + basicCluster := core.NewBasicCluster() + storage := endpoint.NewStorageEndpoint(kv.NewMemoryKV(), nil) + + var ( + hbStreams *hbstream.HeartbeatStreams + configWatcher *config.Watcher + metaWatcher *meta.Watcher + ruleWatcher *rule.Watcher + affinityWatcher *affinity.Watcher + err error + ) + metaWatcher, configWatcher, err = s.startMetaConfWatcher(ctx, basicCluster, storage) if err != nil { + configWatcher.Close() + metaWatcher.Close() return err } - s.hbStreams = hbstream.NewHeartbeatStreams(s.Context(), constant.SchedulingServiceName, s.basicCluster) - cluster, err := NewCluster(s.Context(), s.persistConfig, s.storage, s.basicCluster, s.hbStreams, s.checkMembershipCh, s.GetHTTPClient(), s.GetBackendEndpoints()) + hbStreams = hbstream.NewHeartbeatStreams(ctx, constant.SchedulingServiceName, basicCluster) + cluster, err := NewCluster(ctx, s.persistConfig, storage, basicCluster, hbStreams, s.checkMembershipCh, s.GetHTTPClient(), s.GetBackendEndpoints()) + defer func() { + // make sure the cluster is stopped if any error occurs + // if StopBackgroundJobs return false, it means the cluster is not running, so we need to close the context make the + // other goroutines exit. + if cluster != nil { + cluster.StopBackgroundJobs() + } + if hbStreams != nil { + hbStreams.Close() + } + if configWatcher != nil { + configWatcher.Close() + } + if metaWatcher != nil { + metaWatcher.Close() + } + if ruleWatcher != nil { + ruleWatcher.Close() + } + if affinityWatcher != nil { + affinityWatcher.Close() + } + + }() if err != nil { return err } - s.cluster.Store(cluster) - // Inject the cluster components into the config watcher after the scheduler controller is created. - s.configWatcher.SetSchedulersController(cluster.GetCoordinator().GetSchedulersController()) - // Start the rule watcher after the cluster is created. - s.ruleWatcher, err = rule.NewWatcher(s.Context(), s.GetClient(), s.storage, + + configWatcher.SetSchedulersController(cluster.GetCoordinator().GetSchedulersController()) + ruleWatcher, err = rule.NewWatcher(ctx, s.GetClient(), storage, cluster.GetCoordinator().GetCheckerController(), cluster.GetRuleManager(), cluster.GetRegionLabeler()) if err != nil { return err } - // Start the affinity watcher after the cluster is created. - s.affinityWatcher, err = affinity.NewWatcher(s.Context(), s.GetClient(), cluster.GetAffinityManager()) + affinityWatcher, err = affinity.NewWatcher(ctx, s.GetClient(), cluster.GetAffinityManager()) if err != nil { return err } + + cluster.SetRuntimeResources(metaWatcher, configWatcher, ruleWatcher, affinityWatcher) + s.cluster.Store(cluster) cluster.StartBackgroundJobs() + cluster = nil // defer cleanup no longer needed return nil } func (s *Server) stopCluster() { - cluster := s.GetCluster() - if cluster != nil { + if cluster := s.GetCluster(); cluster != nil { s.cluster.Store((*Cluster)(nil)) cluster.StopBackgroundJobs() + cluster.cleanupRuntimeResources() } - s.stopWatcher() } -func (s *Server) startMetaConfWatcher() (err error) { - s.metaWatcher, err = meta.NewWatcher(s.Context(), s.GetClient(), s.basicCluster) +func (s *Server) startMetaConfWatcher( + ctx context.Context, + basicCluster *core.BasicCluster, + storage *endpoint.StorageEndpoint, +) (metaWatcher *meta.Watcher, configWatcher *config.Watcher, err error) { + metaWatcher, err = meta.NewWatcher(ctx, s.GetClient(), basicCluster) if err != nil { - return err + return nil, nil, err } - s.configWatcher, err = config.NewWatcher(s.Context(), s.GetClient(), s.persistConfig, s.storage) + configWatcher, err = config.NewWatcher(ctx, s.GetClient(), s.persistConfig, storage) if err != nil { - return err - } - return err -} - -func (s *Server) stopWatcher() { - if s.affinityWatcher != nil { - s.affinityWatcher.Close() - } - if s.ruleWatcher != nil { - s.ruleWatcher.Close() - } - if s.metaWatcher != nil { - s.metaWatcher.Close() - } - if s.configWatcher != nil { - s.configWatcher.Close() + metaWatcher.Close() + return nil, nil, err } + return metaWatcher, configWatcher, nil } // GetPersistConfig returns the persist config. diff --git a/pkg/mcs/scheduling/server/server_test.go b/pkg/mcs/scheduling/server/server_test.go new file mode 100644 index 00000000000..0ddec11f447 --- /dev/null +++ b/pkg/mcs/scheduling/server/server_test.go @@ -0,0 +1,165 @@ +// Copyright 2026 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "context" + "io" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" + + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/kvproto/pkg/schedulingpb" + + "github.com/tikv/pd/pkg/core" + "github.com/tikv/pd/pkg/mcs/utils/constant" + "github.com/tikv/pd/pkg/schedule/hbstream" + "github.com/tikv/pd/pkg/storage/endpoint" + "github.com/tikv/pd/pkg/storage/kv" +) + +func TestStopCluster(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hbStreams := hbstream.NewHeartbeatStreams(ctx, constant.SchedulingServiceName, core.NewBasicCluster()) + storage := endpoint.NewStorageEndpoint(kv.NewMemoryKV(), nil) + cluster := &Cluster{hbStreams: hbStreams, storage: storage} + + s := &Server{} + s.cluster.Store(cluster) + + s.stopCluster() + + re.Nil(s.GetCluster()) + re.Nil(cluster.GetHeartbeatStreams()) + re.Nil(cluster.GetStorage()) +} + +func TestRegionHeartbeatReturnsNotBootstrappedWhenHeartbeatStreamsMissing(t *testing.T) { + re := require.New(t) + cluster := &Cluster{BasicCluster: core.NewBasicCluster()} + cluster.PutStore(core.NewStoreInfo(&metapb.Store{Id: 1, Address: "store-1"})) + service := &Service{Server: &Server{service: &Service{}}} + service.service = service + service.cluster.Store(cluster) + + stream := &mockRegionHeartbeatStream{ + recvs: []*schedulingpb.RegionHeartbeatRequest{{ + Leader: &metapb.Peer{StoreId: 1}, + Region: &metapb.Region{Id: 1}, + }}, + } + + err := service.RegionHeartbeat(stream) + re.NoError(err) + re.Len(stream.sent, 1) + re.Equal(schedulingpb.ErrorType_NOT_BOOTSTRAPPED, stream.sent[0].GetHeader().GetError().GetType()) +} + +func TestStoreHeartbeatReturnsNotBootstrappedWhenMetaWatcherMissing(t *testing.T) { + re := require.New(t) + service := &Service{Server: &Server{service: &Service{}}} + service.service = service + service.cluster.Store(&Cluster{BasicCluster: core.NewBasicCluster()}) + + resp, err := service.StoreHeartbeat(context.Background(), &schedulingpb.StoreHeartbeatRequest{ + Stats: &pdpb.StoreStats{StoreId: 1}, + }) + re.NoError(err) + re.Equal(schedulingpb.ErrorType_NOT_BOOTSTRAPPED, resp.GetHeader().GetError().GetType()) +} + +func TestRegionHeartbeatDuringCleanupDoesNotPanic(t *testing.T) { + re := require.New(t) + + basicCluster := core.NewBasicCluster() + basicCluster.PutStore(core.NewStoreInfo(&metapb.Store{Id: 1, Address: "store-1"})) + + s := &Server{} + s.cluster.Store(&Cluster{BasicCluster: basicCluster}) + + stream := &mockRegionHeartbeatStream{ + recvs: []*schedulingpb.RegionHeartbeatRequest{ + { + Region: &metapb.Region{ + Id: 1, + Peers: []*metapb.Peer{{Id: 1, StoreId: 1}}, + }, + Leader: &metapb.Peer{Id: 1, StoreId: 1}, + }, + }, + } + + re.NotPanics(func() { + _ = (&Service{Server: s}).RegionHeartbeat(stream) + }) + re.Nil(s.GetCluster().GetHeartbeatStreams()) +} + +func TestStoreHeartbeatDuringCleanupDoesNotPanic(t *testing.T) { + re := require.New(t) + + basicCluster := core.NewBasicCluster() + s := &Server{} + s.cluster.Store(&Cluster{BasicCluster: basicCluster}) + + re.NotPanics(func() { + _, _ = (&Service{Server: s}).StoreHeartbeat(context.Background(), &schedulingpb.StoreHeartbeatRequest{ + Stats: &pdpb.StoreStats{StoreId: 1}, + }) + }) + re.Nil(s.GetCluster().GetMetaWatcher()) +} + +func TestClusterResourceGettersHandleNilReceiver(t *testing.T) { + re := require.New(t) + + var cluster *Cluster + re.Nil(cluster.GetHeartbeatStreams()) + re.Nil(cluster.GetMetaWatcher()) + re.Nil(cluster.GetStorage()) +} + +type mockRegionHeartbeatStream struct { + schedulingpb.Scheduling_RegionHeartbeatServer + recvs []*schedulingpb.RegionHeartbeatRequest + sent []*schedulingpb.RegionHeartbeatResponse +} + +func (m *mockRegionHeartbeatStream) Send(resp *schedulingpb.RegionHeartbeatResponse) error { + m.sent = append(m.sent, resp) + return nil +} + +func (m *mockRegionHeartbeatStream) Recv() (*schedulingpb.RegionHeartbeatRequest, error) { + if len(m.recvs) == 0 { + return nil, io.EOF + } + req := m.recvs[0] + m.recvs = m.recvs[1:] + return req, nil +} + +func (*mockRegionHeartbeatStream) SetHeader(metadata.MD) error { return nil } +func (*mockRegionHeartbeatStream) SendHeader(metadata.MD) error { return nil } +func (*mockRegionHeartbeatStream) SetTrailer(metadata.MD) {} +func (*mockRegionHeartbeatStream) Context() context.Context { return context.Background() } +func (*mockRegionHeartbeatStream) SendMsg(any) error { return nil } +func (*mockRegionHeartbeatStream) RecvMsg(any) error { return nil }