diff --git a/common_client.go b/common_client.go index 3d13a0b..1f7aeec 100644 --- a/common_client.go +++ b/common_client.go @@ -120,9 +120,7 @@ func (o *httpClientOption) newRetryableHTTPClient() (*retryablehttp.Client, erro retryClient.HTTPClient.Timeout = o.timeout } - if retryClient.Logger == nil { - retryClient.Logger = o.logger - } + retryClient.Logger = o.logger transport, ok := retryClient.HTTPClient.Transport.(*http.Transport) if !ok { diff --git a/common_client_test.go b/common_client_test.go index fb16edf..10a59bd 100644 --- a/common_client_test.go +++ b/common_client_test.go @@ -3,6 +3,7 @@ package scaleset import ( "encoding/json" "io" + "log/slog" "net/http" "net/http/httptest" "net/url" @@ -156,6 +157,59 @@ func TestUserAgent(t *testing.T) { assert.Equal(t, want, got) } +func TestWithLogger(t *testing.T) { + newJSONHandler := func() slog.Handler { + return slog.NewJSONHandler( + io.Discard, + &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelError, + }, + ) + } + t.Run("WithLogger(nil) sets a discard logger on raw httpClientOption", func(t *testing.T) { + opts := httpClientOption{} + WithLogger(nil)(&opts) + require.NotNil(t, opts.logger, "WithLogger(nil) should set a discard logger, not leave it nil") + assert.Equal(t, slog.DiscardHandler, opts.logger.Handler(), "WithLogger(nil) should set a discard logger handler") + }) + + t.Run("WithLogger(customLogger) assigns the provided logger", func(t *testing.T) { + handler := newJSONHandler() + customLogger := slog.New(handler) + opts := httpClientOption{} + WithLogger(customLogger)(&opts) + require.Equal(t, customLogger, opts.logger, "WithLogger should assign the provided logger") + assert.Equal(t, handler, opts.logger.Handler(), "WithLogger should set the provided logger handler") + }) + + t.Run("WithLogger(nil) propagates discard logger to retryable HTTP client", func(t *testing.T) { + opts := httpClientOption{} + WithLogger(nil)(&opts) + client, err := opts.newRetryableHTTPClient() + require.NoError(t, err) + assert.NotNil(t, client.Logger, "retryable client should have logger set from WithLogger(nil)") + + logger, ok := client.Logger.(*slog.Logger) + require.True(t, ok, "retryable client logger should be a *slog.Logger") + assert.Same(t, opts.logger, logger, "retryable client logger should be the same logger set by WithLogger(nil)") + assert.Equal(t, slog.DiscardHandler, logger.Handler(), "retryable client logger should be a discard logger from WithLogger(nil)") + }) + + t.Run("WithLogger(customLogger) propagates custom logger to retryable HTTP client", func(t *testing.T) { + handler := newJSONHandler() + customLogger := slog.New(handler) + opts := httpClientOption{} + WithLogger(customLogger)(&opts) + client, err := opts.newRetryableHTTPClient() + require.NoError(t, err) + assert.NotNil(t, client.Logger, "retryable client should have logger set") + logger, ok := client.Logger.(*slog.Logger) + require.True(t, ok, "retryable client logger should be a *slog.Logger") + assert.Equal(t, handler, logger.Handler(), "retryable client logger should be the custom logger from WithLogger") + }) +} + // TestWithRetryableHTTPClient verifies that a custom retryable HTTP client // provided via WithRetryableHTTPClient is actually used instead of the built-in one func TestWithRetryableHTTPClient(t *testing.T) {