diff --git a/server/embed/serve.go b/server/embed/serve.go index e50529dfe438..6227b188332e 100644 --- a/server/embed/serve.go +++ b/server/embed/serve.go @@ -476,6 +476,20 @@ func (ac *accessController) ServeHTTP(rw http.ResponseWriter, req *http.Request) rw.WriteHeader(http.StatusOK) return } + // Limit the request body size to prevent memory-exhaustion DoS via the + // gRPC-gateway JSON decoder (which base64-decodes fields before any + // server-side size check). + maxBytes := int64(ac.s.Cfg.MaxRequestBytesWithOverhead()) + // Fast-reject: if Content-Length is known and oversized, return 413 + // immediately without reading any body bytes. + if req.ContentLength > maxBytes { + http.Error(rw, "request body too large", http.StatusRequestEntityTooLarge) + return + } + // Safety net: cap the readable bytes for cases where Content-Length is + // absent, zero, or lies. The JSON decoder cannot allocate unbounded + // memory once MaxBytesReader is in place. + req.Body = http.MaxBytesReader(rw, req.Body, maxBytes) ac.mux.ServeHTTP(rw, req) } diff --git a/tests/e2e/v3_curl_kv_test.go b/tests/e2e/v3_curl_kv_test.go index f98e61453410..9003c12df0be 100644 --- a/tests/e2e/v3_curl_kv_test.go +++ b/tests/e2e/v3_curl_kv_test.go @@ -15,8 +15,15 @@ package e2e import ( + "bytes" + "context" + "encoding/base64" "encoding/json" + "fmt" + "io" + "net/http" "testing" + "time" protov1 "github.com/golang/protobuf/proto" //nolint:staticcheck // TODO: remove for a supported version gw "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" @@ -212,3 +219,61 @@ func testCurlV3KVCompact(cx ctlCtx) { }) require.NoErrorf(cx.t, err, "testCurlV3KVCompact failed") } + +func TestCurlV3KVOversizedRequestRejected(t *testing.T) { + e2e.BeforeTest(t) + + ctx, cancel := context.WithTimeout(t.Context(), 20*time.Second) + defer cancel() + + clus, err := e2e.NewEtcdProcessCluster(ctx, t, + e2e.WithClusterSize(1), + ) + require.NoError(t, err) + defer clus.Close() + + endpoint := clus.Procs[0].EndpointsHTTP()[0] + "/v3/kv/range" + client := &http.Client{} + + t.Run("normal request succeeds", func(t *testing.T) { + smallKey := base64.StdEncoding.EncodeToString([]byte("foo")) + reqData := []byte(fmt.Sprintf(`{"key":"%s"}`, smallKey)) + + req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqData)) + require.NoError(t, reqErr) + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("oversized request rejected", func(t *testing.T) { + largeKey := base64.StdEncoding.EncodeToString(make([]byte, 2*1024*1024)) + reqData := []byte(fmt.Sprintf(`{"key":"%s"}`, largeKey)) + + req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqData)) + require.NoError(t, reqErr) + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + t.Logf("oversized request response: status=%d body=%s", resp.StatusCode, string(body)) + + require.NotEqual(t, http.StatusOK, resp.StatusCode, + "oversized request must not succeed") + require.NotEqual(t, http.StatusTooManyRequests, resp.StatusCode, + "oversized request must be rejected before full body decode (was 429 before fix)") + require.True(t, + resp.StatusCode == http.StatusRequestEntityTooLarge || + resp.StatusCode == http.StatusBadRequest, + "expected 413 or 400, got %d", resp.StatusCode) + }) +}