diff --git a/internal/notif/discord/client.go b/internal/notif/discord/client.go index 606e2672c..e76d6f98e 100644 --- a/internal/notif/discord/client.go +++ b/internal/notif/discord/client.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "net/url" "time" @@ -133,27 +134,49 @@ func (c *Client) Send(entry model.NotifEntry) error { } cancelCtx, cancel := context.WithCancelCause(context.Background()) - timeoutCtx, _ := context.WithTimeoutCause(cancelCtx, *c.cfg.Timeout, errors.WithStack(context.DeadlineExceeded)) //nolint:govet // no need to manually cancel this context as we already rely on parent defer func() { cancel(errors.WithStack(context.Canceled)) }() - hc := http.Client{} - req, err := http.NewRequestWithContext(timeoutCtx, "POST", u.String(), dataBuf) - if err != nil { - return err - } + max_retries := 3 + for range max_retries { + timeoutCtx, _ := context.WithTimeoutCause(cancelCtx, *c.cfg.Timeout, errors.WithStack(context.DeadlineExceeded)) //nolint:govet // no need to manually cancel this context as we already rely on parent + + hc := http.Client{} + req, err := http.NewRequestWithContext(timeoutCtx, "POST", u.String(), bytes.NewReader(dataBuf.Bytes())) + if err != nil { + return err + } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", c.meta.UserAgent) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", c.meta.UserAgent) - resp, err := hc.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() + resp, err := hc.Do(req) + if err != nil { + return err + } + + if resp.StatusCode == http.StatusTooManyRequests { + bodyBytes, _ := io.ReadAll(resp.Body) + resp.Body.Close() + var retryData struct { + RetryAfter float64 `json:"retry_after"` + } + if err := json.Unmarshal(bodyBytes, &retryData); err == nil && retryData.RetryAfter > 0 { + time.Sleep(time.Duration(retryData.RetryAfter * float64(time.Second))) + continue + } + time.Sleep(5 * time.Second) + continue + } + + if resp.StatusCode != http.StatusNoContent { + bodyBytes, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return errors.Errorf("unexpected HTTP status %d: %s", resp.StatusCode, string(bodyBytes)) + } - if resp.StatusCode != http.StatusNoContent { - return errors.Errorf("unexpected HTTP status %d: %s", resp.StatusCode, resp.Body) + resp.Body.Close() + return nil } - return nil + return errors.New("max retries exceeded for Discord rate limit") } diff --git a/internal/notif/discord/client_test.go b/internal/notif/discord/client_test.go new file mode 100644 index 000000000..21fbae4e0 --- /dev/null +++ b/internal/notif/discord/client_test.go @@ -0,0 +1,127 @@ +package discord + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/crazy-max/diun/v4/internal/model" + "github.com/crazy-max/diun/v4/pkg/registry" + "github.com/opencontainers/go-digest" +) + +// mockServerConfig holds configuration for the mock Discord server +type mockServerConfig struct { + t *testing.T + handler func(w http.ResponseWriter, r *http.Request, count *atomic.Int32) +} + +// createMockServer creates a mock Discord server with the given handler +func createMockServer(cfg mockServerConfig) (server *httptest.Server, requestCount *atomic.Int32) { + cfg.t.Helper() + var counter atomic.Int32 + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + counter.Add(1) + io.ReadAll(r.Body) + cfg.handler(w, r, &counter) + })) + cfg.t.Cleanup(func() { server.Close() }) + return server, &counter +} + +// createTestClient creates a test client with the given webhook URL +func createTestClient(t *testing.T, webhookURL string) *Client { + t.Helper() + timeout := 10 * time.Second + cfg := &model.NotifDiscord{ + WebhookURL: webhookURL, + RenderEmbeds: ptr(false), + Timeout: &timeout, + } + return &Client{ + cfg: cfg, + meta: model.Meta{ + Name: "Test", + Hostname: "test", + UserAgent: "test", + }, + } +} + +// createTestEntry creates a standard test notification entry +func createTestEntry() model.NotifEntry { + img, _ := registry.ParseImage(registry.ParseImageOptions{ + Name: "test/image:latest", + }) + return model.NotifEntry{ + Provider: "test", + Image: img, + Manifest: registry.Manifest{ + Created: ptr(time.Now()), + Digest: digest.Digest("sha256:test"), + Platform: "linux/amd64", + }, + } +} + +func TestSendWith429Retry(t *testing.T) { + // Create mock server that returns 429 once, then success + server, requestCount := createMockServer(mockServerConfig{ + t: t, + handler: func(w http.ResponseWriter, r *http.Request, count *atomic.Int32) { + if count.Load() == 1 { + w.WriteHeader(http.StatusTooManyRequests) + fmt.Fprintf(w, `{"message": "You are being rate limited.", "retry_after": 0.1, "global": false}`) + t.Logf("Request 1: returning 429 (retry_after: 1.0s)") + } else { + w.WriteHeader(http.StatusNoContent) + t.Logf("Request %d: returning 204 (success)", count.Load()) + } + }, + }) + + client := createTestClient(t, server.URL) + entry := createTestEntry() + + err := client.Send(entry) + + if err != nil { + t.Fatalf("Expected success after retry, got error: %v", err) + } + + // Should have made 2 requests (1 failure + 1 success) + if requestCount.Load() != 2 { + t.Errorf("Expected 2 requests, got %d", requestCount.Load()) + } +} + +func TestSendWith429MaxRetries(t *testing.T) { + server, requestCount := createMockServer(mockServerConfig{ + t: t, + handler: func(w http.ResponseWriter, r *http.Request, count *atomic.Int32) { + w.WriteHeader(http.StatusTooManyRequests) + fmt.Fprintf(w, `{"message": "You are being rate limited.", "retry_after": 0.1, "global": false}`) + }, + }) + + client := createTestClient(t, server.URL) + entry := createTestEntry() + + err := client.Send(entry) + + if err == nil { + t.Fatal("Expected error after max retries, got nil") + } + + if requestCount.Load() != 3 { + t.Errorf("Expected 3 requests (max retries), got %d", requestCount.Load()) + } + + t.Logf("Correctly failed after max retries: %v", err) +} + +func ptr[T any](v T) *T { return &v }