diff --git a/server/api/middleware.go b/server/api/middleware.go index 7e82671d920..155876b95c6 100644 --- a/server/api/middleware.go +++ b/server/api/middleware.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/failpoint" "github.com/tikv/pd/pkg/audit" "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/pkg/utils/requestutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/cluster" @@ -77,20 +78,36 @@ func (rm *requestInfoMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reques } type clusterMiddleware struct { - s *server.Server - rd *render.Render + s *server.Server + rd *render.Render + allowFollowerSyncedRegion bool } -func newClusterMiddleware(s *server.Server) clusterMiddleware { - return clusterMiddleware{ +type clusterMiddlewareOption func(*clusterMiddleware) + +func withFollowerSyncedRegion() clusterMiddlewareOption { + return func(m *clusterMiddleware) { + m.allowFollowerSyncedRegion = true + } +} + +func newClusterMiddleware(s *server.Server, opts ...clusterMiddlewareOption) clusterMiddleware { + m := clusterMiddleware{ s: s, rd: render.New(render.Options{IndentJSON: true}), } + for _, opt := range opts { + opt(&m) + } + return m } func (m clusterMiddleware) middleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rc := m.s.GetRaftCluster() + if rc == nil { + rc = m.getFollowerSyncedCluster(r) + } if rc == nil { m.rd.JSON(w, http.StatusInternalServerError, errs.ErrNotBootstrapped.FastGenByArgs().Error()) return @@ -100,6 +117,20 @@ func (m clusterMiddleware) middleware(h http.Handler) http.Handler { }) } +func (m clusterMiddleware) getFollowerSyncedCluster(r *http.Request) *cluster.RaftCluster { + if r.Method != http.MethodGet || + !m.allowFollowerSyncedRegion || + m.s.IsServing() || + r.Header.Get(apiutil.PDAllowFollowerHandleHeader) == "" { + return nil + } + rc := m.s.DirectlyGetRaftCluster() + if rc == nil || !rc.GetRegionSyncer().IsRunning() { + return nil + } + return rc +} + type clusterCtxKey struct{} func getCluster(r *http.Request) *cluster.RaftCluster { diff --git a/server/api/router.go b/server/api/router.go index 0e129706b43..02a956657e5 100644 --- a/server/api/router.go +++ b/server/api/router.go @@ -127,8 +127,11 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { clusterRouter := apiRouter.NewRoute().Subrouter() clusterRouter.Use(newClusterMiddleware(svr).middleware) + regionReadRouter := apiRouter.NewRoute().Subrouter() + regionReadRouter.Use(newClusterMiddleware(svr, withFollowerSyncedRegion()).middleware) escapeRouter := clusterRouter.NewRoute().Subrouter().UseEncodedPath() + regionReadEscapeRouter := regionReadRouter.NewRoute().Subrouter().UseEncodedPath() operatorHandler := newOperatorHandler(handler, rd) registerFunc(apiRouter, "/operators", operatorHandler.GetOperators, setMethods(http.MethodGet), setAuditBackend(prometheus)) @@ -244,27 +247,27 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { registerFunc(apiRouter, "/hotspot/buckets", hotStatusHandler.GetHotBuckets, setMethods(http.MethodGet), setAuditBackend(prometheus)) regionHandler := newRegionHandler(svr, rd) - registerFunc(clusterRouter, "/region/id/{id}", regionHandler.GetRegionByID, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter.UseEncodedPath(), "/region/key/{key}", regionHandler.GetRegion, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/region/id/{id}", regionHandler.GetRegionByID, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadEscapeRouter, "/region/key/{key}", regionHandler.GetRegion, setMethods(http.MethodGet), setAuditBackend(prometheus)) srd := createStreamingRender() regionsAllHandler := newRegionsHandler(svr, srd) - registerFunc(clusterRouter, "/regions", regionsAllHandler.GetRegions, setMethods(http.MethodGet), setAuditBackend(localLog, prometheus)) + registerFunc(regionReadRouter, "/regions", regionsAllHandler.GetRegions, setMethods(http.MethodGet), setAuditBackend(localLog, prometheus)) regionsHandler := newRegionsHandler(svr, rd) - registerFunc(clusterRouter, "/regions/key", regionsHandler.ScanRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/regions/count", regionsHandler.GetRegionCount, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/regions/store/{id}", regionsHandler.GetStoreRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/regions/keyspace/id/{id}", regionsHandler.GetKeyspaceRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/regions/writeflow", regionsHandler.GetTopWriteFlowRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/regions/writequery", regionsHandler.GetTopWriteQueryRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/regions/readflow", regionsHandler.GetTopReadFlowRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/regions/readquery", regionsHandler.GetTopReadQueryRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/regions/confver", regionsHandler.GetTopConfVerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/regions/version", regionsHandler.GetTopVersionRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/regions/size", regionsHandler.GetTopSizeRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/regions/keys", regionsHandler.GetTopKeysRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/regions/cpu", regionsHandler.GetTopCPURegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/key", regionsHandler.ScanRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/count", regionsHandler.GetRegionCount, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/store/{id}", regionsHandler.GetStoreRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/keyspace/id/{id}", regionsHandler.GetKeyspaceRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/writeflow", regionsHandler.GetTopWriteFlowRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/writequery", regionsHandler.GetTopWriteQueryRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/readflow", regionsHandler.GetTopReadFlowRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/readquery", regionsHandler.GetTopReadQueryRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/confver", regionsHandler.GetTopConfVerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/version", regionsHandler.GetTopVersionRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/size", regionsHandler.GetTopSizeRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/keys", regionsHandler.GetTopKeysRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/cpu", regionsHandler.GetTopCPURegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) registerFunc(clusterRouter, "/regions/check/miss-peer", regionsHandler.GetMissPeerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) registerFunc(clusterRouter, "/regions/check/extra-peer", regionsHandler.GetExtraPeerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) registerFunc(clusterRouter, "/regions/check/pending-peer", regionsHandler.GetPendingPeerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) @@ -275,14 +278,14 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { registerFunc(clusterRouter, "/regions/check/oversized-region", regionsHandler.GetOverSizedRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) registerFunc(clusterRouter, "/regions/check/undersized-region", regionsHandler.GetUndersizedRegions, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/regions/check/hist-size", regionsHandler.GetSizeHistogram, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/regions/check/hist-keys", regionsHandler.GetKeysHistogram, setMethods(http.MethodGet), setAuditBackend(prometheus)) - registerFunc(clusterRouter, "/regions/sibling/{id}", regionsHandler.GetRegionSiblings, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/check/hist-size", regionsHandler.GetSizeHistogram, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/check/hist-keys", regionsHandler.GetKeysHistogram, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/sibling/{id}", regionsHandler.GetRegionSiblings, setMethods(http.MethodGet), setAuditBackend(prometheus)) registerFunc(clusterRouter, "/regions/accelerate-schedule", regionsHandler.AccelerateRegionsScheduleInRange, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) registerFunc(clusterRouter, "/regions/accelerate-schedule/batch", regionsHandler.AccelerateRegionsScheduleInRanges, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) registerFunc(clusterRouter, "/regions/scatter", regionsHandler.ScatterRegions, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) registerFunc(clusterRouter, "/regions/split", regionsHandler.SplitRegions, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) - registerFunc(clusterRouter, "/regions/range-holes", regionsHandler.GetRangeHoles, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(regionReadRouter, "/regions/range-holes", regionsHandler.GetRangeHoles, setMethods(http.MethodGet), setAuditBackend(prometheus)) registerFunc(clusterRouter, "/regions/replicated", regionsHandler.CheckRegionsReplicated, setMethods(http.MethodGet), setQueries("startKey", "{startKey}", "endKey", "{endKey}"), setAuditBackend(prometheus)) registerFunc(apiRouter, "/version", newVersionHandler(rd).GetVersion, setMethods(http.MethodGet), setAuditBackend(prometheus)) diff --git a/server/server.go b/server/server.go index e5e0b96da0a..1f1708d5d6c 100644 --- a/server/server.go +++ b/server/server.go @@ -1426,7 +1426,7 @@ func (s *Server) IsServiceIndependent(name string) bool { } // DirectlyGetRaftCluster returns raft cluster directly. -// Only used for test. +// It bypasses the leader-running check for follower-local paths and tests. func (s *Server) DirectlyGetRaftCluster() *cluster.RaftCluster { return s.cluster } diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 891ce40cb57..672c7ffa93e 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -33,6 +33,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/core" + "github.com/tikv/pd/pkg/response" "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/pkg/utils/typeutil" @@ -725,6 +726,74 @@ func (suite *redirectorTestSuite) TestXForwardedFor() { re.NotContains(l, suite.cluster.GetConfig().GetClientURLs()) } +func TestFollowerRegionAPIWithNoForward(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 3, func(conf *config.Config, _ string) { + conf.PDServerCfg.UseRegionStorage = true + conf.TickInterval = typeutil.Duration{Duration: 50 * time.Millisecond} + conf.ElectionInterval = typeutil.Duration{Duration: 250 * time.Millisecond} + }) + re.NoError(err) + defer cluster.Destroy() + + re.NoError(cluster.RunInitialServers()) + re.NotEmpty(cluster.WaitLeader()) + leader := cluster.GetLeaderServer() + re.NoError(leader.BootstrapCluster()) + re.True(cluster.WaitRegionSyncerClientsReady(2)) + + follower := cluster.GetServer(cluster.GetFollower()) + re.NotNil(follower) + testutil.Eventually(re, func() bool { + return follower.GetServer().DirectlyGetRaftCluster().GetRegionSyncer().IsRunning() + }) + + regions := tests.InitRegions(3) + for _, region := range regions { + re.NoError(leader.GetRaftCluster().HandleRegionHeartbeat(region)) + } + testutil.Eventually(re, func() bool { + return len(follower.GetServer().GetBasicCluster().GetRegions()) == len(regions) + }) + + req, err := http.NewRequest(http.MethodGet, follower.GetAddr()+"/pd/api/v1/regions", http.NoBody) + re.NoError(err) + req.Header.Set(apiutil.PDAllowFollowerHandleHeader, "true") + resp, err := tests.TestDialClient.Do(req) + re.NoError(err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + re.NoError(err) + re.Equal(http.StatusOK, resp.StatusCode, string(body)) + var regionsInfo response.RegionsInfo + re.NoError(json.Unmarshal(body, ®ionsInfo)) + re.Equal(len(regions), regionsInfo.Count) + + req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/api/v1/region/id/%d", follower.GetAddr(), regions[0].GetID()), http.NoBody) + re.NoError(err) + req.Header.Set(apiutil.PDAllowFollowerHandleHeader, "true") + resp, err = tests.TestDialClient.Do(req) + re.NoError(err) + defer resp.Body.Close() + body, err = io.ReadAll(resp.Body) + re.NoError(err) + re.Equal(http.StatusOK, resp.StatusCode, string(body)) + re.Contains(string(body), fmt.Sprintf(`"id":%d`, regions[0].GetID())) + + req, err = http.NewRequest(http.MethodGet, follower.GetAddr()+"/pd/api/v1/regions/check/miss-peer", http.NoBody) + re.NoError(err) + req.Header.Set(apiutil.PDAllowFollowerHandleHeader, "true") + resp, err = tests.TestDialClient.Do(req) + re.NoError(err) + defer resp.Body.Close() + body, err = io.ReadAll(resp.Body) + re.NoError(err) + re.Equal(http.StatusInternalServerError, resp.StatusCode, string(body)) + re.Contains(string(body), "TiKV cluster not bootstrapped") +} + func mustRequestSuccess(re *require.Assertions, s *server.Server) http.Header { resp, err := tests.TestDialClient.Get(s.GetAddr() + "/pd/api/v1/version") re.NoError(err)