Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ var (
ErrInvalidTimezone = errors.New("invalid timezone value")
)

func isNilPointer(v any) bool {
rv := reflect.ValueOf(v)
return rv.Kind() == reflect.Pointer && rv.IsNil()
}

func Named(name string, value any) driver.NamedValue {
return driver.NamedValue{
Name: name,
Expand Down Expand Up @@ -124,7 +129,9 @@ func bindPositional(tz *time.Location, query string, args ...any) (_ string, err
if argIndex < len(args) {
v := args[argIndex]
if fn, ok := v.(std_driver.Valuer); ok {
if v, err = fn.Value(); err != nil {
if isNilPointer(v) {
v = nil
} else if v, err = fn.Value(); err != nil {
return "", nil
}
}
Expand Down Expand Up @@ -167,7 +174,9 @@ func bindNumeric(tz *time.Location, query string, args ...any) (_ string, err er
)
for i, v := range args {
if fn, ok := v.(std_driver.Valuer); ok {
if v, err = fn.Value(); err != nil {
if isNilPointer(v) {
v = nil
} else if v, err = fn.Value(); err != nil {
return "", nil
}
}
Expand Down Expand Up @@ -202,7 +211,9 @@ func bindNamed(tz *time.Location, query string, args ...any) (_ string, err erro
case driver.NamedValue:
value := v.Value
if fn, ok := v.Value.(std_driver.Valuer); ok {
if value, err = fn.Value(); err != nil {
if isNilPointer(v.Value) {
value = nil
} else if value, err = fn.Value(); err != nil {
return "", err
}
}
Expand Down Expand Up @@ -314,9 +325,7 @@ func format(tz *time.Location, scale TimeUnit, v any) (string, error) {
}
return fmt.Sprintf("[%s]", val), nil
case fmt.Stringer:
if v := reflect.ValueOf(v); v.Kind() == reflect.Pointer &&
v.IsNil() &&
v.Type().Elem().Implements(reflect.TypeOf((*fmt.Stringer)(nil)).Elem()) {
if isNilPointer(v) {
return "NULL", nil
}
return quote(v.String()), nil
Expand Down
31 changes: 31 additions & 0 deletions bind_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package clickhouse

import (
std_driver "database/sql/driver"
"fmt"
"testing"
"time"

Expand Down Expand Up @@ -496,3 +498,32 @@ func BenchmarkBindNamed(b *testing.B) {
}
}
}

// valuerType implements driver.Valuer with a value receiver,
// so a nil *valuerType satisfies the interface but panics on call.
type valuerType [16]byte

func (v valuerType) Value() (std_driver.Value, error) {
return fmt.Sprintf("%x", v[:]), nil
}

func TestBindPositionalNilDriverValuer(t *testing.T) {
var nilPtr *valuerType
actual, err := bind(time.Local, "SELECT ?", nilPtr)
require.NoError(t, err)
assert.Equal(t, "SELECT NULL", actual)
}

func TestBindNumericNilDriverValuer(t *testing.T) {
var nilPtr *valuerType
actual, err := bind(time.Local, "SELECT $1", nilPtr)
require.NoError(t, err)
assert.Equal(t, "SELECT NULL", actual)
}

func TestBindNamedNilDriverValuer(t *testing.T) {
var nilPtr *valuerType
actual, err := bind(time.Local, "SELECT @col", Named("col", nilPtr))
require.NoError(t, err)
assert.Equal(t, "SELECT NULL", actual)
}
102 changes: 102 additions & 0 deletions tests/issues/1823_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package issues

import (
"context"
"database/sql"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/ClickHouse/clickhouse-go/v2"
clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests"
)

// Test1823 verifies that passing a typed nil pointer to a type implementing
// driver.Valuer with a value receiver (e.g. *uuid.UUID) is bound as NULL
// rather than panicking inside fn.Value().
func Test1823(t *testing.T) {
const ddl = "CREATE TABLE IF NOT EXISTS test_1823 (id UUID, ref_id Nullable(UUID)) Engine Memory"

t.Run("native_select", func(t *testing.T) {
conn, err := clickhouse_tests.GetConnectionTCP("issues", clickhouse.Settings{
"max_execution_time": 60,
}, nil, &clickhouse.Compression{Method: clickhouse.CompressionLZ4})
require.NoError(t, err)

ctx := context.Background()
var nilUUID *uuid.UUID

require.NotPanics(t, func() {
var got *uuid.UUID
require.NoError(t, conn.QueryRow(ctx, "SELECT ?", nilUUID).Scan(&got))
require.Nil(t, got)
})
})

t.Run("native_insert", func(t *testing.T) {
conn, err := clickhouse_tests.GetConnectionTCP("issues", clickhouse.Settings{
"max_execution_time": 60,
}, nil, &clickhouse.Compression{Method: clickhouse.CompressionLZ4})
require.NoError(t, err)

ctx := context.Background()
require.NoError(t, conn.Exec(ctx, ddl))
defer conn.Exec(ctx, "DROP TABLE IF EXISTS test_1823")

id := uuid.New()
var nilUUID *uuid.UUID

require.NotPanics(t, func() {
require.NoError(t, conn.Exec(ctx, "INSERT INTO test_1823 (id, ref_id) VALUES (?, ?)", id, nilUUID))
})

var gotRef *uuid.UUID
require.NoError(t, conn.QueryRow(ctx, "SELECT ref_id FROM test_1823 WHERE id = ?", id).Scan(&gotRef))
require.Nil(t, gotRef)
})

t.Run("std_select", func(t *testing.T) {
env, err := clickhouse_tests.GetTestEnvironment("issues")
require.NoError(t, err)
opts := clickhouse_tests.ClientOptionsFromEnv(env, clickhouse.Settings{}, false)
db, err := sql.Open("clickhouse", clickhouse_tests.OptionsToDSN(&opts))
require.NoError(t, err)
defer db.Close()

ctx := context.Background()
var nilUUID *uuid.UUID

require.NotPanics(t, func() {
var got uuid.NullUUID
require.NoError(t, db.QueryRowContext(ctx, "SELECT ?", nilUUID).Scan(&got))
require.False(t, got.Valid)
})
})

t.Run("std_insert", func(t *testing.T) {
env, err := clickhouse_tests.GetTestEnvironment("issues")
require.NoError(t, err)
opts := clickhouse_tests.ClientOptionsFromEnv(env, clickhouse.Settings{}, false)
db, err := sql.Open("clickhouse", clickhouse_tests.OptionsToDSN(&opts))
require.NoError(t, err)
defer db.Close()

ctx := context.Background()
_, err = db.ExecContext(ctx, ddl)
require.NoError(t, err)
defer db.ExecContext(ctx, "DROP TABLE IF EXISTS test_1823")

id := uuid.New()
var nilUUID *uuid.UUID

require.NotPanics(t, func() {
_, err := db.ExecContext(ctx, "INSERT INTO test_1823 (id, ref_id) VALUES (?, ?)", id, nilUUID)
require.NoError(t, err)
})

var gotRef uuid.NullUUID
require.NoError(t, db.QueryRowContext(ctx, "SELECT ref_id FROM test_1823 WHERE id = ?", id).Scan(&gotRef))
require.False(t, gotRef.Valid)
})
}