diff --git a/bind.go b/bind.go index fa90ade6c3..844ad6b8f8 100644 --- a/bind.go +++ b/bind.go @@ -17,6 +17,11 @@ var ( ErrInvalidTimezone = errors.New("invalid timezone value") ) +func isNilPointer(v any) bool { + rv := reflect.ValueOf(v) + return rv.Kind() == reflect.Pointer && rv.IsNil() +} + func Named(name string, value any) driver.NamedValue { return driver.NamedValue{ Name: name, @@ -124,7 +129,9 @@ func bindPositional(tz *time.Location, query string, args ...any) (_ string, err if argIndex < len(args) { v := args[argIndex] if fn, ok := v.(std_driver.Valuer); ok { - if v, err = fn.Value(); err != nil { + if isNilPointer(v) { + v = nil + } else if v, err = fn.Value(); err != nil { return "", nil } } @@ -167,7 +174,9 @@ func bindNumeric(tz *time.Location, query string, args ...any) (_ string, err er ) for i, v := range args { if fn, ok := v.(std_driver.Valuer); ok { - if v, err = fn.Value(); err != nil { + if isNilPointer(v) { + v = nil + } else if v, err = fn.Value(); err != nil { return "", nil } } @@ -202,7 +211,9 @@ func bindNamed(tz *time.Location, query string, args ...any) (_ string, err erro case driver.NamedValue: value := v.Value if fn, ok := v.Value.(std_driver.Valuer); ok { - if value, err = fn.Value(); err != nil { + if isNilPointer(v.Value) { + value = nil + } else if value, err = fn.Value(); err != nil { return "", err } } @@ -314,9 +325,7 @@ func format(tz *time.Location, scale TimeUnit, v any) (string, error) { } return fmt.Sprintf("[%s]", val), nil case fmt.Stringer: - if v := reflect.ValueOf(v); v.Kind() == reflect.Pointer && - v.IsNil() && - v.Type().Elem().Implements(reflect.TypeOf((*fmt.Stringer)(nil)).Elem()) { + if isNilPointer(v) { return "NULL", nil } return quote(v.String()), nil diff --git a/bind_test.go b/bind_test.go index 55829ac61a..59f8fd9fa4 100644 --- a/bind_test.go +++ b/bind_test.go @@ -1,6 +1,8 @@ package clickhouse import ( + std_driver "database/sql/driver" + "fmt" "testing" "time" @@ -496,3 +498,32 @@ func BenchmarkBindNamed(b *testing.B) { } } } + +// valuerType implements driver.Valuer with a value receiver, +// so a nil *valuerType satisfies the interface but panics on call. +type valuerType [16]byte + +func (v valuerType) Value() (std_driver.Value, error) { + return fmt.Sprintf("%x", v[:]), nil +} + +func TestBindPositionalNilDriverValuer(t *testing.T) { + var nilPtr *valuerType + actual, err := bind(time.Local, "SELECT ?", nilPtr) + require.NoError(t, err) + assert.Equal(t, "SELECT NULL", actual) +} + +func TestBindNumericNilDriverValuer(t *testing.T) { + var nilPtr *valuerType + actual, err := bind(time.Local, "SELECT $1", nilPtr) + require.NoError(t, err) + assert.Equal(t, "SELECT NULL", actual) +} + +func TestBindNamedNilDriverValuer(t *testing.T) { + var nilPtr *valuerType + actual, err := bind(time.Local, "SELECT @col", Named("col", nilPtr)) + require.NoError(t, err) + assert.Equal(t, "SELECT NULL", actual) +} diff --git a/tests/issues/1823_test.go b/tests/issues/1823_test.go new file mode 100644 index 0000000000..db3cc40929 --- /dev/null +++ b/tests/issues/1823_test.go @@ -0,0 +1,102 @@ +package issues + +import ( + "context" + "database/sql" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/ClickHouse/clickhouse-go/v2" + clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests" +) + +// Test1823 verifies that passing a typed nil pointer to a type implementing +// driver.Valuer with a value receiver (e.g. *uuid.UUID) is bound as NULL +// rather than panicking inside fn.Value(). +func Test1823(t *testing.T) { + const ddl = "CREATE TABLE IF NOT EXISTS test_1823 (id UUID, ref_id Nullable(UUID)) Engine Memory" + + t.Run("native_select", func(t *testing.T) { + conn, err := clickhouse_tests.GetConnectionTCP("issues", clickhouse.Settings{ + "max_execution_time": 60, + }, nil, &clickhouse.Compression{Method: clickhouse.CompressionLZ4}) + require.NoError(t, err) + + ctx := context.Background() + var nilUUID *uuid.UUID + + require.NotPanics(t, func() { + var got *uuid.UUID + require.NoError(t, conn.QueryRow(ctx, "SELECT ?", nilUUID).Scan(&got)) + require.Nil(t, got) + }) + }) + + t.Run("native_insert", func(t *testing.T) { + conn, err := clickhouse_tests.GetConnectionTCP("issues", clickhouse.Settings{ + "max_execution_time": 60, + }, nil, &clickhouse.Compression{Method: clickhouse.CompressionLZ4}) + require.NoError(t, err) + + ctx := context.Background() + require.NoError(t, conn.Exec(ctx, ddl)) + defer conn.Exec(ctx, "DROP TABLE IF EXISTS test_1823") + + id := uuid.New() + var nilUUID *uuid.UUID + + require.NotPanics(t, func() { + require.NoError(t, conn.Exec(ctx, "INSERT INTO test_1823 (id, ref_id) VALUES (?, ?)", id, nilUUID)) + }) + + var gotRef *uuid.UUID + require.NoError(t, conn.QueryRow(ctx, "SELECT ref_id FROM test_1823 WHERE id = ?", id).Scan(&gotRef)) + require.Nil(t, gotRef) + }) + + t.Run("std_select", func(t *testing.T) { + env, err := clickhouse_tests.GetTestEnvironment("issues") + require.NoError(t, err) + opts := clickhouse_tests.ClientOptionsFromEnv(env, clickhouse.Settings{}, false) + db, err := sql.Open("clickhouse", clickhouse_tests.OptionsToDSN(&opts)) + require.NoError(t, err) + defer db.Close() + + ctx := context.Background() + var nilUUID *uuid.UUID + + require.NotPanics(t, func() { + var got uuid.NullUUID + require.NoError(t, db.QueryRowContext(ctx, "SELECT ?", nilUUID).Scan(&got)) + require.False(t, got.Valid) + }) + }) + + t.Run("std_insert", func(t *testing.T) { + env, err := clickhouse_tests.GetTestEnvironment("issues") + require.NoError(t, err) + opts := clickhouse_tests.ClientOptionsFromEnv(env, clickhouse.Settings{}, false) + db, err := sql.Open("clickhouse", clickhouse_tests.OptionsToDSN(&opts)) + require.NoError(t, err) + defer db.Close() + + ctx := context.Background() + _, err = db.ExecContext(ctx, ddl) + require.NoError(t, err) + defer db.ExecContext(ctx, "DROP TABLE IF EXISTS test_1823") + + id := uuid.New() + var nilUUID *uuid.UUID + + require.NotPanics(t, func() { + _, err := db.ExecContext(ctx, "INSERT INTO test_1823 (id, ref_id) VALUES (?, ?)", id, nilUUID) + require.NoError(t, err) + }) + + var gotRef uuid.NullUUID + require.NoError(t, db.QueryRowContext(ctx, "SELECT ref_id FROM test_1823 WHERE id = ?", id).Scan(&gotRef)) + require.False(t, gotRef.Valid) + }) +}