Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions internal/server/http_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -730,3 +731,137 @@ func TestLogRuntimeError(t *testing.T) {
func stringPtr(s string) *string {
return &s
}

// TestWriteErrorResponse verifies that writeErrorResponse writes a JSON error
// body with "error" and "message" fields and the correct status code and Content-Type.
func TestWriteErrorResponse(t *testing.T) {
tests := []struct {
name string
statusCode int
code string
message string
}{
{
name: "400 bad request",
statusCode: http.StatusBadRequest,
code: "bad_request",
message: "the request body is malformed",
},
{
name: "401 unauthorized",
statusCode: http.StatusUnauthorized,
code: "unauthorized",
message: "invalid API key",
},
{
name: "403 forbidden",
statusCode: http.StatusForbidden,
code: "forbidden",
message: "access denied",
},
{
name: "404 not found",
statusCode: http.StatusNotFound,
code: "not_found",
message: "the requested resource does not exist",
},
{
name: "500 internal server error",
statusCode: http.StatusInternalServerError,
code: "internal_error",
message: "an unexpected error occurred",
},
{
name: "503 service unavailable",
statusCode: http.StatusServiceUnavailable,
code: "service_unavailable",
message: "the gateway is shutting down",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
writeErrorResponse(w, tt.statusCode, tt.code, tt.message)

assert.Equal(t, tt.statusCode, w.Code, "status code should match")
assert.Equal(t, "application/json", w.Header().Get("Content-Type"), "Content-Type should be application/json")

var body map[string]string
err := json.NewDecoder(w.Body).Decode(&body)
require.NoError(t, err, "response body should be valid JSON")
assert.Equal(t, tt.code, body["error"], "response body 'error' field should match code")
assert.Equal(t, tt.message, body["message"], "response body 'message' field should match message")
})
}
}

// TestRejectRequest verifies that rejectRequest writes the correct HTTP error
// response and does not panic for a variety of status codes and messages.
func TestRejectRequest(t *testing.T) {
tests := []struct {
name string
status int
code string
msg string
logCategory string
runtimeErrType string
runtimeDetail string
path string
method string
}{
{
name: "401 missing auth header",
status: http.StatusUnauthorized,
code: "unauthorized",
msg: "missing Authorization header",
logCategory: "auth",
runtimeErrType: "authentication_failed",
runtimeDetail: "missing_auth_header",
path: "/mcp",
method: "GET",
},
{
name: "401 invalid api key",
status: http.StatusUnauthorized,
code: "unauthorized",
msg: "invalid API key",
logCategory: "auth",
runtimeErrType: "authentication_failed",
runtimeDetail: "invalid_api_key",
path: "/mcp/github",
method: "POST",
},
{
name: "503 shutdown",
status: http.StatusServiceUnavailable,
code: "service_unavailable",
msg: "Gateway is shutting down",
logCategory: "gateway",
runtimeErrType: "shutdown",
runtimeDetail: "shutdown_in_progress",
path: "/mcp",
method: "POST",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(tt.method, tt.path, nil)

assert.NotPanics(t, func() {
rejectRequest(w, r, tt.status, tt.code, tt.msg, tt.logCategory, tt.runtimeErrType, tt.runtimeDetail)
})

assert.Equal(t, tt.status, w.Code, "status code should match")
assert.Equal(t, "application/json", w.Header().Get("Content-Type"), "Content-Type should be application/json")

var body map[string]string
err := json.NewDecoder(w.Body).Decode(&body)
require.NoError(t, err, "response body should be valid JSON")
assert.Equal(t, tt.code, body["error"], "response body 'error' field should match code")
assert.Equal(t, tt.msg, body["message"], "response body 'message' field should match msg")
})
}
}
Loading