diff --git a/README.md b/README.md index 8cef9f3bb7..8a3e46d348 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Both support TCP and HTTP transport. When in doubt, use the native interface. * Supports both native ClickHouse TCP and HTTP client-server protocols * Compatibility with [`database/sql`](#std-databasesql-interface) ([slower](#benchmark) than [native interface](#native-interface)!) * [`database/sql`](#std-databasesql-interface) supports both native TCP and HTTP protocols for transport. -* Marshal rows into structs ([ScanStruct](examples/clickhouse_api/scan_struct.go), [Select](examples/clickhouse_api/select_struct.go)) +* Marshal rows into structs ([ScanStruct](examples/clickhouse_api/scan_struct.go), [Select](examples/clickhouse_api/select_struct.go), [StructIter](examples/clickhouse_api/iterators.go) for native `driver.Rows`) * Unmarshal struct to row ([AppendStruct](benchmark/v2/write-native-struct/main.go)) * Connection pool (for both TCP-Native and HTTP) * Failover and load balancing diff --git a/examples/clickhouse_api/iterators.go b/examples/clickhouse_api/iterators.go new file mode 100644 index 0000000000..6636a11cee --- /dev/null +++ b/examples/clickhouse_api/iterators.go @@ -0,0 +1,65 @@ +package clickhouse_api + +import ( + "context" + "fmt" + + chdriver "github.com/ClickHouse/clickhouse-go/v2/lib/driver" +) + +func Iterators() (err error) { + conn, err := GetNativeConnection(nil, nil, nil) + if err != nil { + return err + } + + ctx := context.Background() + defer func() { + if dropErr := conn.Exec(ctx, "DROP TABLE example_iterators"); dropErr != nil && err == nil { + err = fmt.Errorf("drop example_iterators: %w", dropErr) + } + }() + + if err := conn.Exec(ctx, `DROP TABLE IF EXISTS example_iterators`); err != nil { + return err + } + if err := conn.Exec(ctx, ` + CREATE TABLE example_iterators ( + Col1 UInt8, + Col2 String + ) ENGINE = Memory + `); err != nil { + return err + } + + batch, err := conn.PrepareBatch(ctx, "INSERT INTO example_iterators") + if err != nil { + return err + } + for i := 0; i < 3; i++ { + if err := batch.Append(uint8(i), fmt.Sprintf("value_%d", i)); err != nil { + return err + } + } + if err := batch.Send(); err != nil { + return err + } + + type result struct { + Col1 uint8 + Col2 string + } + + rows, err := conn.Query(ctx, "SELECT Col1, Col2 FROM example_iterators ORDER BY Col1") + if err != nil { + return err + } + for value, err := range chdriver.StructIter[result](rows) { + if err != nil { + return err + } + fmt.Printf("struct row: col1=%d, col2=%s\n", value.Col1, value.Col2) + } + + return nil +} diff --git a/examples/clickhouse_api/main_test.go b/examples/clickhouse_api/main_test.go index f94cb38f0b..51654b08d4 100644 --- a/examples/clickhouse_api/main_test.go +++ b/examples/clickhouse_api/main_test.go @@ -201,6 +201,10 @@ func TestQueryRows(t *testing.T) { require.NoError(t, QueryRows()) } +func TestIterators(t *testing.T) { + require.NoError(t, Iterators()) +} + func TestSSL(t *testing.T) { require.NoError(t, SSLVersion()) } diff --git a/lib/driver/iter.go b/lib/driver/iter.go new file mode 100644 index 0000000000..5acb0e1ef7 --- /dev/null +++ b/lib/driver/iter.go @@ -0,0 +1,45 @@ +package driver + +import ( + "errors" + "iter" +) + +// StructIter returns an iterator that scans each row into T with ScanStruct. +// It works with native Rows, not database/sql.Rows. +func StructIter[T any](rows Rows) iter.Seq2[T, error] { + return func(yield func(T, error) bool) { + for rows.Next() { + var value T + if err := rows.ScanStruct(&value); err != nil { + var zero T + if closeErr := rows.Close(); closeErr != nil { + err = errors.Join(err, closeErr) + } + _ = yield(zero, err) + return + } + if !yield(value, nil) { + // The caller stopped iteration, so the protocol forbids yielding a close error. + if closeErr := rows.Close(); closeErr != nil { + return + } + return + } + } + + if err := rows.Err(); err != nil { + var zero T + if closeErr := rows.Close(); closeErr != nil { + err = errors.Join(err, closeErr) + } + _ = yield(zero, err) + return + } + + if err := rows.Close(); err != nil { + var zero T + _ = yield(zero, err) + } + } +} diff --git a/lib/driver/iter_test.go b/lib/driver/iter_test.go new file mode 100644 index 0000000000..0ff036b562 --- /dev/null +++ b/lib/driver/iter_test.go @@ -0,0 +1,183 @@ +package driver + +import ( + "errors" + "io" + "reflect" + "testing" +) + +type testRows struct { + values []int + index int + closeCalls int + err error + closeErr error + scanStructErrAt int +} + +func (r *testRows) Next() bool { + if r.index >= len(r.values) { + return false + } + r.index++ + return true +} + +func (r *testRows) Scan(dest ...any) error { return nil } + +func (r *testRows) ScanStruct(dest any) error { + if r.scanStructErrAt > 0 && r.index == r.scanStructErrAt { + return io.ErrUnexpectedEOF + } + value := reflect.ValueOf(dest) + if value.Kind() != reflect.Ptr || value.Elem().Kind() != reflect.Struct { + return errors.New("expected pointer to struct") + } + field := value.Elem().FieldByName("Value") + if !field.IsValid() || !field.CanSet() || field.Kind() != reflect.Int { + return errors.New("expected struct with settable int Value field") + } + field.SetInt(int64(r.values[r.index-1])) + return nil +} + +func (r *testRows) ColumnTypes() []ColumnType { return nil } + +func (r *testRows) Totals(dest ...any) error { return nil } + +func (r *testRows) Columns() []string { return nil } + +func (r *testRows) Close() error { + r.closeCalls++ + return r.closeErr +} + +func (r *testRows) Err() error { return r.err } + +func (r *testRows) HasData() bool { return r.index < len(r.values) } + +func TestStructIter(t *testing.T) { + type item struct { + Value int + } + + rows := &testRows{values: []int{4, 5, 6}} + + var got []int + for value, err := range StructIter[item](rows) { + if err != nil { + t.Fatalf("unexpected iter error: %v", err) + } + got = append(got, value.Value) + } + + if !reflect.DeepEqual(got, []int{4, 5, 6}) { + t.Fatalf("unexpected values: %#v", got) + } + if rows.closeCalls == 0 { + t.Fatal("expected rows to be closed") + } +} + +func TestStructIterScanError(t *testing.T) { + type item struct { + Value int + } + + rows := &testRows{values: []int{7, 8, 9}, scanStructErrAt: 2} + + var got []int + var gotErr error + for value, err := range StructIter[item](rows) { + if err != nil { + gotErr = err + break + } + got = append(got, value.Value) + } + + if !errors.Is(gotErr, io.ErrUnexpectedEOF) { + t.Fatalf("unexpected error: %v", gotErr) + } + if !reflect.DeepEqual(got, []int{7}) { + t.Fatalf("unexpected values before error: %#v", got) + } +} + +func TestStructIterTerminalRowsError(t *testing.T) { + type item struct { + Value int + } + + rows := &testRows{values: []int{1}, err: io.EOF} + + var got []item + var gotErr error + for value, err := range StructIter[item](rows) { + if err != nil { + gotErr = err + break + } + got = append(got, value) + } + + if !errors.Is(gotErr, io.EOF) { + t.Fatalf("unexpected terminal error: %v", gotErr) + } + if !reflect.DeepEqual(got, []item{{Value: 1}}) { + t.Fatalf("unexpected values before terminal error: %#v", got) + } +} + +func TestStructIterCloseError(t *testing.T) { + type item struct { + Value int + } + + rows := &testRows{values: []int{1}, closeErr: io.ErrClosedPipe} + + var got []item + var gotErr error + for value, err := range StructIter[item](rows) { + if err != nil { + gotErr = err + break + } + got = append(got, value) + } + + if !errors.Is(gotErr, io.ErrClosedPipe) { + t.Fatalf("unexpected close error: %v", gotErr) + } + if !reflect.DeepEqual(got, []item{{Value: 1}}) { + t.Fatalf("unexpected values before close error: %#v", got) + } + if rows.closeCalls != 1 { + t.Fatalf("unexpected close calls: %d", rows.closeCalls) + } +} + +func TestStructIterStopsAfterCallerBreak(t *testing.T) { + type item struct { + Value int + } + + rows := &testRows{values: []int{1, 2, 3}} + + var got []item + for value, err := range StructIter[item](rows) { + if err != nil { + t.Fatalf("unexpected iter error: %v", err) + } + got = append(got, value) + break + } + + if !reflect.DeepEqual(got, []item{{Value: 1}}) { + t.Fatalf("unexpected values before caller break: %#v", got) + } + if rows.closeCalls != 1 { + t.Fatalf("unexpected close calls: %d", rows.closeCalls) + } +} diff --git a/tests/iterator_test.go b/tests/iterator_test.go new file mode 100644 index 0000000000..ffecc34c52 --- /dev/null +++ b/tests/iterator_test.go @@ -0,0 +1,55 @@ +package tests + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ClickHouse/clickhouse-go/v2" + chdriver "github.com/ClickHouse/clickhouse-go/v2/lib/driver" +) + +func TestStructIterProtocols(t *testing.T) { + TestProtocols(t, func(t *testing.T, protocol clickhouse.Protocol) { + conn, err := GetNativeConnection(t, protocol, nil, nil, nil) + require.NoError(t, err) + + ctx := context.Background() + const table = "test_struct_iter" + require.NoError(t, conn.Exec(ctx, "DROP TABLE IF EXISTS "+table)) + t.Cleanup(func() { + require.NoError(t, conn.Exec(ctx, "DROP TABLE IF EXISTS "+table)) + }) + require.NoError(t, conn.Exec(ctx, ` + CREATE TABLE test_struct_iter ( + Col1 UInt8, + Col2 String + ) ENGINE = Memory + `)) + + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_struct_iter") + require.NoError(t, err) + require.NoError(t, batch.Append(uint8(1), "one")) + require.NoError(t, batch.Append(uint8(2), "two")) + require.NoError(t, batch.Send()) + + rows, err := conn.Query(ctx, "SELECT Col1, Col2 FROM test_struct_iter ORDER BY Col1") + require.NoError(t, err) + + type result struct { + Col1 uint8 + Col2 string + } + var got []result + for value, err := range chdriver.StructIter[result](rows) { + require.NoError(t, err) + got = append(got, value) + } + + require.Equal(t, []result{ + {Col1: 1, Col2: "one"}, + {Col1: 2, Col2: "two"}, + }, got) + }) +}