Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions server/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/pingcap/failpoint"
"github.com/tikv/pd/pkg/audit"
"github.com/tikv/pd/pkg/errs"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/pkg/utils/requestutil"
"github.com/tikv/pd/server"
"github.com/tikv/pd/server/cluster"
Expand Down Expand Up @@ -77,20 +78,36 @@ func (rm *requestInfoMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reques
}

type clusterMiddleware struct {
s *server.Server
rd *render.Render
s *server.Server
rd *render.Render
allowFollowerSyncedRegion bool
}

func newClusterMiddleware(s *server.Server) clusterMiddleware {
return clusterMiddleware{
type clusterMiddlewareOption func(*clusterMiddleware)

func withFollowerSyncedRegion() clusterMiddlewareOption {
return func(m *clusterMiddleware) {
m.allowFollowerSyncedRegion = true
}
}

func newClusterMiddleware(s *server.Server, opts ...clusterMiddlewareOption) clusterMiddleware {
m := clusterMiddleware{
s: s,
rd: render.New(render.Options{IndentJSON: true}),
}
for _, opt := range opts {
opt(&m)
}
return m
}

func (m clusterMiddleware) middleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rc := m.s.GetRaftCluster()
if rc == nil {
rc = m.getFollowerSyncedCluster(r)
}
if rc == nil {
m.rd.JSON(w, http.StatusInternalServerError, errs.ErrNotBootstrapped.FastGenByArgs().Error())
return
Expand All @@ -100,6 +117,20 @@ func (m clusterMiddleware) middleware(h http.Handler) http.Handler {
})
}

func (m clusterMiddleware) getFollowerSyncedCluster(r *http.Request) *cluster.RaftCluster {
if r.Method != http.MethodGet ||
!m.allowFollowerSyncedRegion ||
m.s.IsServing() ||
r.Header.Get(apiutil.PDAllowFollowerHandleHeader) == "" {
return nil
}
rc := m.s.DirectlyGetRaftCluster()
if rc == nil || !rc.GetRegionSyncer().IsRunning() {
return nil
}
return rc
}

type clusterCtxKey struct{}

func getCluster(r *http.Request) *cluster.RaftCluster {
Expand Down
43 changes: 23 additions & 20 deletions server/api/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,11 @@ func createRouter(prefix string, svr *server.Server) *mux.Router {

clusterRouter := apiRouter.NewRoute().Subrouter()
clusterRouter.Use(newClusterMiddleware(svr).middleware)
regionReadRouter := apiRouter.NewRoute().Subrouter()
regionReadRouter.Use(newClusterMiddleware(svr, withFollowerSyncedRegion()).middleware)

escapeRouter := clusterRouter.NewRoute().Subrouter().UseEncodedPath()
regionReadEscapeRouter := regionReadRouter.NewRoute().Subrouter().UseEncodedPath()

operatorHandler := newOperatorHandler(handler, rd)
registerFunc(apiRouter, "/operators", operatorHandler.GetOperators, setMethods(http.MethodGet), setAuditBackend(prometheus))
Expand Down Expand Up @@ -244,27 +247,27 @@ func createRouter(prefix string, svr *server.Server) *mux.Router {
registerFunc(apiRouter, "/hotspot/buckets", hotStatusHandler.GetHotBuckets, setMethods(http.MethodGet), setAuditBackend(prometheus))

regionHandler := newRegionHandler(svr, rd)
registerFunc(clusterRouter, "/region/id/{id}", regionHandler.GetRegionByID, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter.UseEncodedPath(), "/region/key/{key}", regionHandler.GetRegion, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/region/id/{id}", regionHandler.GetRegionByID, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadEscapeRouter, "/region/key/{key}", regionHandler.GetRegion, setMethods(http.MethodGet), setAuditBackend(prometheus))

srd := createStreamingRender()
regionsAllHandler := newRegionsHandler(svr, srd)
registerFunc(clusterRouter, "/regions", regionsAllHandler.GetRegions, setMethods(http.MethodGet), setAuditBackend(localLog, prometheus))
registerFunc(regionReadRouter, "/regions", regionsAllHandler.GetRegions, setMethods(http.MethodGet), setAuditBackend(localLog, prometheus))

regionsHandler := newRegionsHandler(svr, rd)
registerFunc(clusterRouter, "/regions/key", regionsHandler.ScanRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/count", regionsHandler.GetRegionCount, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/store/{id}", regionsHandler.GetStoreRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/keyspace/id/{id}", regionsHandler.GetKeyspaceRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/writeflow", regionsHandler.GetTopWriteFlowRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/writequery", regionsHandler.GetTopWriteQueryRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/readflow", regionsHandler.GetTopReadFlowRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/readquery", regionsHandler.GetTopReadQueryRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/confver", regionsHandler.GetTopConfVerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/version", regionsHandler.GetTopVersionRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/size", regionsHandler.GetTopSizeRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/keys", regionsHandler.GetTopKeysRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/cpu", regionsHandler.GetTopCPURegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/key", regionsHandler.ScanRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/count", regionsHandler.GetRegionCount, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/store/{id}", regionsHandler.GetStoreRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/keyspace/id/{id}", regionsHandler.GetKeyspaceRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/writeflow", regionsHandler.GetTopWriteFlowRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/writequery", regionsHandler.GetTopWriteQueryRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/readflow", regionsHandler.GetTopReadFlowRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/readquery", regionsHandler.GetTopReadQueryRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/confver", regionsHandler.GetTopConfVerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/version", regionsHandler.GetTopVersionRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/size", regionsHandler.GetTopSizeRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/keys", regionsHandler.GetTopKeysRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/cpu", regionsHandler.GetTopCPURegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/check/miss-peer", regionsHandler.GetMissPeerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/check/extra-peer", regionsHandler.GetExtraPeerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/check/pending-peer", regionsHandler.GetPendingPeerRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
Expand All @@ -275,14 +278,14 @@ func createRouter(prefix string, svr *server.Server) *mux.Router {
registerFunc(clusterRouter, "/regions/check/oversized-region", regionsHandler.GetOverSizedRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/check/undersized-region", regionsHandler.GetUndersizedRegions, setMethods(http.MethodGet), setAuditBackend(prometheus))

registerFunc(clusterRouter, "/regions/check/hist-size", regionsHandler.GetSizeHistogram, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/check/hist-keys", regionsHandler.GetKeysHistogram, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/sibling/{id}", regionsHandler.GetRegionSiblings, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/check/hist-size", regionsHandler.GetSizeHistogram, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/check/hist-keys", regionsHandler.GetKeysHistogram, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/sibling/{id}", regionsHandler.GetRegionSiblings, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/accelerate-schedule", regionsHandler.AccelerateRegionsScheduleInRange, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus))
registerFunc(clusterRouter, "/regions/accelerate-schedule/batch", regionsHandler.AccelerateRegionsScheduleInRanges, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus))
registerFunc(clusterRouter, "/regions/scatter", regionsHandler.ScatterRegions, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus))
registerFunc(clusterRouter, "/regions/split", regionsHandler.SplitRegions, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus))
registerFunc(clusterRouter, "/regions/range-holes", regionsHandler.GetRangeHoles, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(regionReadRouter, "/regions/range-holes", regionsHandler.GetRangeHoles, setMethods(http.MethodGet), setAuditBackend(prometheus))
registerFunc(clusterRouter, "/regions/replicated", regionsHandler.CheckRegionsReplicated, setMethods(http.MethodGet), setQueries("startKey", "{startKey}", "endKey", "{endKey}"), setAuditBackend(prometheus))

registerFunc(apiRouter, "/version", newVersionHandler(rd).GetVersion, setMethods(http.MethodGet), setAuditBackend(prometheus))
Expand Down
2 changes: 1 addition & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,7 @@ func (s *Server) IsServiceIndependent(name string) bool {
}

// DirectlyGetRaftCluster returns raft cluster directly.
// Only used for test.
// It bypasses the leader-running check for follower-local paths and tests.
func (s *Server) DirectlyGetRaftCluster() *cluster.RaftCluster {
return s.cluster
}
Expand Down
69 changes: 69 additions & 0 deletions tests/server/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/tikv/pd/pkg/core"
"github.com/tikv/pd/pkg/response"
"github.com/tikv/pd/pkg/utils/apiutil"
"github.com/tikv/pd/pkg/utils/testutil"
"github.com/tikv/pd/pkg/utils/typeutil"
Expand Down Expand Up @@ -725,6 +726,74 @@ func (suite *redirectorTestSuite) TestXForwardedFor() {
re.NotContains(l, suite.cluster.GetConfig().GetClientURLs())
}

func TestFollowerRegionAPIWithNoForward(t *testing.T) {
re := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cluster, err := tests.NewTestCluster(ctx, 3, func(conf *config.Config, _ string) {
conf.PDServerCfg.UseRegionStorage = true
conf.TickInterval = typeutil.Duration{Duration: 50 * time.Millisecond}
conf.ElectionInterval = typeutil.Duration{Duration: 250 * time.Millisecond}
})
re.NoError(err)
defer cluster.Destroy()

re.NoError(cluster.RunInitialServers())
re.NotEmpty(cluster.WaitLeader())
leader := cluster.GetLeaderServer()
re.NoError(leader.BootstrapCluster())
re.True(cluster.WaitRegionSyncerClientsReady(2))

follower := cluster.GetServer(cluster.GetFollower())
re.NotNil(follower)
testutil.Eventually(re, func() bool {
return follower.GetServer().DirectlyGetRaftCluster().GetRegionSyncer().IsRunning()
})

regions := tests.InitRegions(3)
for _, region := range regions {
re.NoError(leader.GetRaftCluster().HandleRegionHeartbeat(region))
}
testutil.Eventually(re, func() bool {
return len(follower.GetServer().GetBasicCluster().GetRegions()) == len(regions)
})

req, err := http.NewRequest(http.MethodGet, follower.GetAddr()+"/pd/api/v1/regions", http.NoBody)
re.NoError(err)
req.Header.Set(apiutil.PDAllowFollowerHandleHeader, "true")
resp, err := tests.TestDialClient.Do(req)
re.NoError(err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
re.NoError(err)
re.Equal(http.StatusOK, resp.StatusCode, string(body))
var regionsInfo response.RegionsInfo
re.NoError(json.Unmarshal(body, &regionsInfo))
re.Equal(len(regions), regionsInfo.Count)

req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/api/v1/region/id/%d", follower.GetAddr(), regions[0].GetID()), http.NoBody)
re.NoError(err)
req.Header.Set(apiutil.PDAllowFollowerHandleHeader, "true")
resp, err = tests.TestDialClient.Do(req)
re.NoError(err)
defer resp.Body.Close()
body, err = io.ReadAll(resp.Body)
re.NoError(err)
re.Equal(http.StatusOK, resp.StatusCode, string(body))
re.Contains(string(body), fmt.Sprintf(`"id":%d`, regions[0].GetID()))

req, err = http.NewRequest(http.MethodGet, follower.GetAddr()+"/pd/api/v1/regions/check/miss-peer", http.NoBody)
re.NoError(err)
req.Header.Set(apiutil.PDAllowFollowerHandleHeader, "true")
resp, err = tests.TestDialClient.Do(req)
re.NoError(err)
defer resp.Body.Close()
body, err = io.ReadAll(resp.Body)
re.NoError(err)
re.Equal(http.StatusInternalServerError, resp.StatusCode, string(body))
re.Contains(string(body), "TiKV cluster not bootstrapped")
}

func mustRequestSuccess(re *require.Assertions, s *server.Server) http.Header {
resp, err := tests.TestDialClient.Get(s.GetAddr() + "/pd/api/v1/version")
re.NoError(err)
Expand Down
Loading