diff --git a/.gitignore b/.gitignore index 9725b60..12c8503 100644 --- a/.gitignore +++ b/.gitignore @@ -38,4 +38,6 @@ coverage/ # Liquibase Specific files/folders liquibase.properties -liquibase_libs/ \ No newline at end of file +liquibase_libs/ + +.junie \ No newline at end of file diff --git a/internal/error_envelope.go b/internal/error_envelope.go new file mode 100644 index 0000000..0f2b521 --- /dev/null +++ b/internal/error_envelope.go @@ -0,0 +1,39 @@ +package internal + +import ( + "encoding/json" + "log/slog" + "net/http" +) + +type ErrorEnvelope struct { + Type string `json:"type"` + Title string `json:"title"` + Status int `json:"status"` + Detail string `json:"detail,omitempty"` + Instance string `json:"instance,omitempty"` +} + +// HandleHttpError handles an error and writes it to the response writer +// in the format specified by RFC 7807. If an invalid status code is provided, +// it will default to 500 Internal Server Error. +func HandleHttpError(w http.ResponseWriter, err ErrorEnvelope, statusCode int) { + if err.Type == "" { + err.Type = "about:blank" + } + if http.StatusText(statusCode) == "" { + slog.Error("Unknown HTTP status code", "status_code", statusCode) + statusCode = http.StatusInternalServerError + } + if err.Title == "" { + err.Title = http.StatusText(statusCode) + } + err.Status = statusCode + + w.Header().Set("Content-Type", "application/problem+json") + w.WriteHeader(statusCode) + + if encodeErr := json.NewEncoder(w).Encode(err); encodeErr != nil { + slog.Error("Failed to encode error response", "error", encodeErr) + } +} diff --git a/internal/error_envelope_test.go b/internal/error_envelope_test.go new file mode 100644 index 0000000..83494ad --- /dev/null +++ b/internal/error_envelope_test.go @@ -0,0 +1,99 @@ +package internal + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestHandleHttpError(t *testing.T) { + tests := []struct { + name string + err ErrorEnvelope + statusCode int + expectedStatus int + expectedBody string + expectedTitle string + expectedType string + }{ + { + name: "Empty Envelope", + err: ErrorEnvelope{}, + statusCode: http.StatusBadRequest, + expectedStatus: http.StatusBadRequest, + expectedTitle: "Bad Request", + expectedType: "about:blank", + }, + { + name: "With Detail", + err: ErrorEnvelope{ + Detail: "Validation failed", + }, + statusCode: http.StatusUnprocessableEntity, + expectedStatus: http.StatusUnprocessableEntity, + expectedTitle: "Unprocessable Entity", + expectedType: "about:blank", + }, + { + name: "Custom Type and Title", + err: ErrorEnvelope{ + Type: "https://example.com/probs/out-of-credit", + Title: "You do not have enough credit.", + }, + statusCode: http.StatusForbidden, + expectedStatus: http.StatusForbidden, + expectedTitle: "You do not have enough credit.", + expectedType: "https://example.com/probs/out-of-credit", + }, + { + name: "Zero Status Code Defaults to 500", + err: ErrorEnvelope{Detail: "Something went wrong"}, + statusCode: 0, + expectedStatus: http.StatusInternalServerError, + expectedTitle: "Internal Server Error", + expectedType: "about:blank", + }, + { + name: "Negative Status Code Defaults to 500", + err: ErrorEnvelope{Detail: "Something went wrong"}, + statusCode: -1, + expectedStatus: http.StatusInternalServerError, + expectedTitle: "Internal Server Error", + expectedType: "about:blank", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rr := httptest.NewRecorder() + HandleHttpError(rr, tt.err, tt.statusCode) + + if rr.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) + } + + contentType := rr.Header().Get("Content-Type") + if contentType != "application/problem+json" { + t.Errorf("expected Content-Type application/problem+json, got %q", contentType) + } + + var got ErrorEnvelope + if err := json.NewDecoder(rr.Body).Decode(&got); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if got.Status != tt.expectedStatus { + t.Errorf("expected envelope status %d, got %d", tt.expectedStatus, got.Status) + } + + if got.Title != tt.expectedTitle { + t.Errorf("expected envelope title %q, got %q", tt.expectedTitle, got.Title) + } + + if got.Type != tt.expectedType { + t.Errorf("expected envelope type %q, got %q", tt.expectedType, got.Type) + } + }) + } +} diff --git a/internal/platform/handler.go b/internal/platform/handler.go index b84e35d..0b860b9 100644 --- a/internal/platform/handler.go +++ b/internal/platform/handler.go @@ -1,6 +1,7 @@ package platform import ( + "bytes" "context" "encoding/json" "errors" @@ -27,18 +28,18 @@ type platformHandler struct { func (h *platformHandler) CreatePlatform(w http.ResponseWriter, r *http.Request) { var req createPlatformRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Invalid request body", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Invalid request body"}, http.StatusBadRequest) return } if req.Name == "" { - http.Error(w, "Name is required", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Name is required"}, http.StatusBadRequest) return } contextWithTimeOut, cancel := context.WithTimeout(r.Context(), 5*time.Second) defer cancel() if err := h.queries.CreatePlatform(contextWithTimeOut, req.ToParams()); err != nil { - http.Error(w, "Failed to create platform", http.StatusInternalServerError) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Failed to create platform"}, http.StatusInternalServerError) return } @@ -48,22 +49,22 @@ func (h *platformHandler) CreatePlatform(w http.ResponseWriter, r *http.Request) func (h *platformHandler) UpdatePlatform(w http.ResponseWriter, r *http.Request) { var req updatePlatformRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Invalid request body", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Invalid request body"}, http.StatusBadRequest) return } id, ok := internal.GetIntFromRequestPath("id", r) if !ok { - http.Error(w, "Invalid platform ID", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Invalid platform ID"}, http.StatusBadRequest) return } if req.Name == "" { - http.Error(w, "Name is required", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Name is required"}, http.StatusBadRequest) return } if req.ID != id { - http.Error(w, "Platform ID does not match path", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Platform ID does not match path"}, http.StatusBadRequest) return } contextWithTimeOut, cancel := context.WithTimeout(r.Context(), 5*time.Second) @@ -71,10 +72,10 @@ func (h *platformHandler) UpdatePlatform(w http.ResponseWriter, r *http.Request) _, err := h.queries.UpdatePlatform(contextWithTimeOut, req.ToParams(id)) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - http.Error(w, "Platform not found", http.StatusNotFound) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Platform not found"}, http.StatusNotFound) return } - http.Error(w, "Failed to update platform", http.StatusInternalServerError) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Failed to update platform"}, http.StatusInternalServerError) return } @@ -84,7 +85,7 @@ func (h *platformHandler) UpdatePlatform(w http.ResponseWriter, r *http.Request) func (h *platformHandler) DeletePlatform(w http.ResponseWriter, r *http.Request) { id, ok := internal.GetIntFromRequestPath("id", r) if !ok { - http.Error(w, "Invalid platform ID", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Invalid platform ID"}, http.StatusBadRequest) return } contextWithTimeOut, cancel := context.WithTimeout(r.Context(), 5*time.Second) @@ -92,12 +93,12 @@ func (h *platformHandler) DeletePlatform(w http.ResponseWriter, r *http.Request) _, err := h.queries.DeletePlatform(contextWithTimeOut, id) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - http.Error(w, "Platform not found", http.StatusNotFound) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Platform not found"}, http.StatusNotFound) return } logger := internal.LoggerFromContext(r.Context()) logger.Error("Failed to delete platform", "error", err, "platform_id", id) - http.Error(w, "Internal server error", http.StatusInternalServerError) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Internal server error"}, http.StatusInternalServerError) return } @@ -109,25 +110,27 @@ func (h *platformHandler) GetPlatforms(w http.ResponseWriter, r *http.Request) { defer cancel() platforms, err := h.queries.GetPlatforms(contextWithTimeOut) if err != nil { - http.Error(w, "Failed to fetch platforms", http.StatusInternalServerError) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Failed to fetch platforms"}, http.StatusInternalServerError) return } if platforms == nil { platforms = []Platform{} } - w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(platforms) + var buf bytes.Buffer + err = json.NewEncoder(&buf).Encode(platforms) if err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Failed to encode response"}, http.StatusInternalServerError) return } + w.Header().Set("Content-Type", "application/json") + w.Write(buf.Bytes()) } func (h *platformHandler) GetPlatform(w http.ResponseWriter, r *http.Request) { id, ok := internal.GetIntFromRequestPath("id", r) if !ok { - http.Error(w, "Invalid platform ID", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Invalid platform ID"}, http.StatusBadRequest) return } contextWithTimeOut, cancel := context.WithTimeout(r.Context(), 5*time.Second) @@ -135,17 +138,20 @@ func (h *platformHandler) GetPlatform(w http.ResponseWriter, r *http.Request) { platform, err := h.queries.GetPlatform(contextWithTimeOut, id) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - http.Error(w, "Platform not found", http.StatusNotFound) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Platform not found"}, http.StatusNotFound) return } - http.Error(w, "Failed to fetch platform", http.StatusInternalServerError) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Failed to fetch platform"}, http.StatusInternalServerError) return } - w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(platform) + var buf bytes.Buffer + err = json.NewEncoder(&buf).Encode(platform) if err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Failed to encode response"}, http.StatusInternalServerError) return } + + w.Header().Set("Content-Type", "application/json") + w.Write(buf.Bytes()) } diff --git a/internal/platform/handler_test.go b/internal/platform/handler_test.go index 03d3d85..0fa8de6 100644 --- a/internal/platform/handler_test.go +++ b/internal/platform/handler_test.go @@ -7,6 +7,7 @@ import ( "errors" "net/http" "net/http/httptest" + "products/internal" "strconv" "testing" @@ -298,14 +299,14 @@ func TestDeletePlatform(t *testing.T) { id: "1", dbErr: errors.New("db error"), expectedStatus: http.StatusInternalServerError, - expectedBody: "Internal server error\n", + expectedBody: "Internal server error", }, { name: "DB Error (pgx.ErrNoRows)", id: "1", dbErr: pgx.ErrNoRows, expectedStatus: http.StatusNotFound, - expectedBody: "Platform not found\n", + expectedBody: "Platform not found", }, { name: "Invalid ID (Zero)", @@ -351,7 +352,15 @@ func TestDeletePlatform(t *testing.T) { } if tt.expectedBody != "" { - if rr.Body.String() != tt.expectedBody { + if rr.Code >= 400 { + var envelope internal.ErrorEnvelope + if err := json.NewDecoder(rr.Body).Decode(&envelope); err != nil { + t.Fatalf("failed to decode error response: %v", err) + } + if envelope.Detail != tt.expectedBody { + t.Errorf("expected detail %q, got %q", tt.expectedBody, envelope.Detail) + } + } else if rr.Body.String() != tt.expectedBody { t.Errorf("expected body %q, got %q", tt.expectedBody, rr.Body.String()) } } diff --git a/internal/product/handler.go b/internal/product/handler.go index 95dbb8b..318979d 100644 --- a/internal/product/handler.go +++ b/internal/product/handler.go @@ -1,6 +1,7 @@ package product import ( + "bytes" "context" "encoding/json" "errors" @@ -28,13 +29,13 @@ type productHandler struct { func (h *productHandler) CreateProduct(w http.ResponseWriter, r *http.Request) { var req createProductRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Invalid request body", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Invalid request body"}, http.StatusBadRequest) return } req.Name = strings.TrimSpace(req.Name) if req.Name == "" || req.PlatformID == 0 { - http.Error(w, "Name and platform ID are required", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Name and platform ID are required"}, http.StatusBadRequest) return } @@ -42,7 +43,7 @@ func (h *productHandler) CreateProduct(w http.ResponseWriter, r *http.Request) { defer cancel() if err := h.queries.CreateProduct(contextWithTimeOut, req.ToParams()); err != nil { - http.Error(w, "Failed to create product", http.StatusInternalServerError) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Failed to create product"}, http.StatusInternalServerError) return } @@ -52,17 +53,17 @@ func (h *productHandler) CreateProduct(w http.ResponseWriter, r *http.Request) { func (h *productHandler) DeleteProduct(w http.ResponseWriter, r *http.Request) { id, ok := internal.GetIntFromRequestPath("id", r) if !ok { - http.Error(w, "Invalid product ID", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Invalid product ID"}, http.StatusBadRequest) return } contextWithTimeOut, cancel := context.WithTimeout(r.Context(), 5*time.Second) defer cancel() if _, err := h.queries.DeleteProduct(contextWithTimeOut, id); err != nil { if errors.Is(err, pgx.ErrNoRows) { - http.Error(w, "Product not found", http.StatusNotFound) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Product not found"}, http.StatusNotFound) return } - http.Error(w, "Failed to delete product", http.StatusInternalServerError) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Failed to delete product"}, http.StatusInternalServerError) return } w.WriteHeader(http.StatusNoContent) @@ -71,19 +72,19 @@ func (h *productHandler) DeleteProduct(w http.ResponseWriter, r *http.Request) { func (h *productHandler) UpdateProduct(w http.ResponseWriter, r *http.Request) { id, ok := internal.GetIntFromRequestPath("id", r) if !ok { - http.Error(w, "Invalid product ID", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Invalid product ID"}, http.StatusBadRequest) return } var req updateProductRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Invalid request body", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Invalid request body"}, http.StatusBadRequest) return } req.Name = strings.TrimSpace(req.Name) if req.Name == "" || req.PlatformID == 0 { - http.Error(w, "Name and platform ID are required", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Name and platform ID are required"}, http.StatusBadRequest) return } @@ -92,10 +93,10 @@ func (h *productHandler) UpdateProduct(w http.ResponseWriter, r *http.Request) { if _, err := h.queries.UpdateProduct(ctx, req.ToParams(id)); err != nil { if errors.Is(err, pgx.ErrNoRows) { - http.Error(w, "Product not found", http.StatusNotFound) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Product not found"}, http.StatusNotFound) return } - http.Error(w, "Failed to update product", http.StatusInternalServerError) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Failed to update product"}, http.StatusInternalServerError) return } @@ -106,7 +107,7 @@ func (h *productHandler) UpdateProduct(w http.ResponseWriter, r *http.Request) { func (h *productHandler) GetProductsByPlatform(w http.ResponseWriter, r *http.Request) { platformID, ok := internal.GetIntFromRequestPath("platform_id", r) if !ok { - http.Error(w, "Invalid platform ID", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Invalid platform ID"}, http.StatusBadRequest) return } @@ -115,7 +116,7 @@ func (h *productHandler) GetProductsByPlatform(w http.ResponseWriter, r *http.Re products, err := h.queries.GetProductsByPlatform(ctx, platformID) if err != nil { - http.Error(w, "Failed to fetch products", http.StatusInternalServerError) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Failed to fetch products"}, http.StatusInternalServerError) return } @@ -123,18 +124,21 @@ func (h *productHandler) GetProductsByPlatform(w http.ResponseWriter, r *http.Re products = []Product{} } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(products); err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(products); err != nil { + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Failed to encode response"}, http.StatusInternalServerError) return } -} -// GetProductById fetches a single product by ID. + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write(buf.Bytes()); err != nil { + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Failed to write response"}, http.StatusInternalServerError) + } +} func (h *productHandler) GetProductById(w http.ResponseWriter, r *http.Request) { id, ok := internal.GetIntFromRequestPath("id", r) if !ok { - http.Error(w, "Invalid product ID", http.StatusBadRequest) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Invalid product ID"}, http.StatusBadRequest) return } @@ -144,16 +148,21 @@ func (h *productHandler) GetProductById(w http.ResponseWriter, r *http.Request) product, err := h.queries.GetProductById(ctx, id) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - http.Error(w, "Product not found", http.StatusNotFound) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Product not found"}, http.StatusNotFound) return } - http.Error(w, "Failed to fetch product", http.StatusInternalServerError) + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Failed to fetch product"}, http.StatusInternalServerError) return } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(product); err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(product); err != nil { + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Failed to encode response"}, http.StatusInternalServerError) return } + + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write(buf.Bytes()); err != nil { + internal.HandleHttpError(w, internal.ErrorEnvelope{Detail: "Failed to write response"}, http.StatusInternalServerError) + } }