diff --git a/clickhouse_std.go b/clickhouse_std.go index a223ea7a06..ddc74d5220 100644 --- a/clickhouse_std.go +++ b/clickhouse_std.go @@ -66,7 +66,7 @@ func (o *stdConnOpener) Connect(ctx context.Context) (_ driver.Conn, err error) } } - if o.opt.Addr == nil || len(o.opt.Addr) == 0 { + if len(o.opt.Addr) == 0 { return nil, ErrAcquireConnNoAddress } @@ -342,19 +342,31 @@ func (std *stdDriver) PrepareContext(ctx context.Context, query string) (driver. return nil, driver.ErrBadConn } - batch, err := std.conn.prepareBatch(ctx, func(nativeTransport, error) {}, func(context.Context) (nativeTransport, error) { return nil, nil }, query, chdriver.PrepareBatchOptions{}) - if err != nil { - if isConnBrokenError(err) { - std.debugf("PrepareContext got a fatal error, resetting connection: %v\n", err) - return nil, driver.ErrBadConn + // Detect INSERT to decide between batch mode and read/exec prepared stmt + // We keep the heuristic simple and case-insensitive: leading non-space must start with INSERT + trimmed := strings.TrimLeft(query, " \t\n\r") + if len(trimmed) >= 6 && strings.EqualFold(trimmed[:6], "insert") { + batch, err := std.conn.prepareBatch(ctx, func(nativeTransport, error) {}, func(context.Context) (nativeTransport, error) { return nil, nil }, query, chdriver.PrepareBatchOptions{}) + if err != nil { + if isConnBrokenError(err) { + std.debugf("PrepareContext got a fatal error, resetting connection: %v\n", err) + return nil, driver.ErrBadConn + } + std.debugf("PrepareContext error: %v\n", err) + return nil, err } - std.debugf("PrepareContext error: %v\n", err) - return nil, err + std.commit = batch.Send + return &stdBatch{ + batch: batch, + debugf: std.debugf, + }, nil } - std.commit = batch.Send - return &stdBatch{ - batch: batch, + + // For non-INSERT, return stdStmt that supports QueryContext and ExecContext + return &stdStmt{ + query: query, debugf: std.debugf, + conn: std.conn, }, nil } @@ -405,6 +417,92 @@ func (s *stdBatch) Query(args []driver.Value) (driver.Rows, error) { func (s *stdBatch) Close() error { return nil } +// stdStmt supports prepared statements for non-INSERT queries using the same connection. +type stdStmt struct { + query string + debugf func(format string, v ...any) + conn stdConnect +} + +func (s *stdStmt) NumInput() int { return -1 } + +// Exec executes non-INSERT statements prepared via stdStmt. +func (s *stdStmt) Exec(args []driver.Value) (driver.Result, error) { + values := make([]any, 0, len(args)) + for _, v := range args { + values = append(values, v) + } + if s.conn.isBad() { + s.debugf("[stmt][exec] connection is bad") + return nil, driver.ErrBadConn + } + if err := s.conn.exec(context.Background(), s.query, values...); err != nil { + if isConnBrokenError(err) { + s.debugf("[stmt][exec] fatal error, resetting connection: %v", err) + return nil, driver.ErrBadConn + } + s.debugf("[stmt][exec] error: %v", err) + return nil, err + } + return driver.RowsAffected(0), nil +} + +func (s *stdStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + if s.conn.isBad() { + s.debugf("[stmt][execctx] connection is bad") + return nil, driver.ErrBadConn + } + if err := s.conn.exec(ctx, s.query, rebind(args)...); err != nil { + if isConnBrokenError(err) { + s.debugf("[stmt][execctx] fatal error, resetting connection: %v", err) + return nil, driver.ErrBadConn + } + s.debugf("[stmt][execctx] error: %v", err) + return nil, err + } + return driver.RowsAffected(0), nil +} + +func (s *stdStmt) Query(args []driver.Value) (driver.Rows, error) { + values := make([]any, 0, len(args)) + for _, v := range args { + values = append(values, v) + } + if s.conn.isBad() { + s.debugf("[stmt][query] connection is bad") + return nil, driver.ErrBadConn + } + r, err := s.conn.query(context.Background(), func(nativeTransport, error) {}, s.query, values...) + if isConnBrokenError(err) { + s.debugf("[stmt][query] fatal error, resetting connection: %v", err) + return nil, driver.ErrBadConn + } + if err != nil { + s.debugf("[stmt][query] error: %v", err) + return nil, err + } + return &stdRows{rows: r, debugf: s.debugf}, nil +} + +func (s *stdStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + if s.conn.isBad() { + s.debugf("[stmt][queryctx] connection is bad") + return nil, driver.ErrBadConn + } + r, err := s.conn.query(ctx, func(nativeTransport, error) {}, s.query, rebind(args)...) + if isConnBrokenError(err) { + s.debugf("[stmt][queryctx] fatal error, resetting connection: %v", err) + return nil, driver.ErrBadConn + } + if err != nil { + s.debugf("[stmt][queryctx] error: %v", err) + return nil, err + } + return &stdRows{rows: r, debugf: s.debugf}, nil +} + +func (s *stdStmt) Close() error { return nil } + type stdRows struct { rows *rows debugf func(format string, v ...any) diff --git a/examples/std/main_test.go b/examples/std/main_test.go index 085296cc26..11f7247267 100644 --- a/examples/std/main_test.go +++ b/examples/std/main_test.go @@ -152,3 +152,7 @@ func TestJSONStringExample(t *testing.T) { clickhouse_tests.SkipOnCloud(t, "cannot modify JSON settings on cloud") require.NoError(t, JSONStringExample()) } + +func TestPreparedSelectExample(t *testing.T) { + require.NoError(t, PreparedSelect()) +} diff --git a/examples/std/prepared.go b/examples/std/prepared.go new file mode 100644 index 0000000000..3857204abe --- /dev/null +++ b/examples/std/prepared.go @@ -0,0 +1,64 @@ +// Licensed to ClickHouse, Inc. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. ClickHouse, Inc. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package std + +import ( + "context" + "database/sql" + "fmt" + + "github.com/ClickHouse/clickhouse-go/v2" +) + +// PreparedSelect demonstrates using database/sql prepared statements for read queries. +func PreparedSelect() error { + conn, err := GetStdOpenDBConnection(clickhouse.Native, nil, nil, nil) + if err != nil { + return err + } + defer func(db *sql.DB) { _ = db.Close() }(conn) + + ctx := context.Background() + if err := conn.PingContext(ctx); err != nil { + return err + } + + stmt, err := conn.PrepareContext(ctx, "SELECT ? + ?") + if err != nil { + return err + } + defer func() { _ = stmt.Close() }() + + rows, err := stmt.QueryContext(ctx, 2, 3) + if err != nil { + return err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return fmt.Errorf("no rows returned from prepared SELECT") + } + var sum int64 + if err := rows.Scan(&sum); err != nil { + return err + } + if sum != 5 { + return fmt.Errorf("unexpected result from prepared SELECT: got %d, want 5", sum) + } + return rows.Err() +} diff --git a/tests/std/prepared_stmt_test.go b/tests/std/prepared_stmt_test.go new file mode 100644 index 0000000000..09bad4f151 --- /dev/null +++ b/tests/std/prepared_stmt_test.go @@ -0,0 +1,97 @@ +// Licensed to ClickHouse, Inc. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. ClickHouse, Inc. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package std + +import ( + "context" + "testing" + + "github.com/ClickHouse/clickhouse-go/v2" + "github.com/stretchr/testify/require" +) + +func TestStdPreparedSelect(t *testing.T) { + db, err := GetStdOpenDBConnection(clickhouse.Native, nil, nil, nil) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + ctx := context.Background() + require.NoError(t, db.PingContext(ctx)) + + stmt, err := db.PrepareContext(ctx, "SELECT ? + ?") + require.NoError(t, err) + t.Cleanup(func() { _ = stmt.Close() }) + + rows, err := stmt.QueryContext(ctx, 10, 5) + require.NoError(t, err) + t.Cleanup(func() { _ = rows.Close() }) + + require.True(t, rows.Next()) + var sum int64 + require.NoError(t, rows.Scan(&sum)) + require.EqualValues(t, 15, sum) + require.NoError(t, rows.Err()) +} + +// Test for prepared selects using both positional and named params. +func TestStdPreparedFunds(t *testing.T) { + db, err := GetStdOpenDBConnection(clickhouse.Native, nil, nil, nil) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + ctx := context.Background() + require.NoError(t, db.PingContext(ctx)) + + _, _ = db.ExecContext(ctx, "DROP TABLE IF EXISTS std_prepared_funds") + _, err = db.ExecContext(ctx, ` + CREATE TABLE std_prepared_funds ( + symbol String, + name String + ) Engine = Memory`) + require.NoError(t, err) + t.Cleanup(func() { _, _ = db.ExecContext(ctx, "DROP TABLE IF EXISTS std_prepared_funds") }) + + _, err = db.ExecContext(ctx, `INSERT INTO std_prepared_funds (symbol, name) VALUES ('abc', 'ABC Fund')`) + require.NoError(t, err) + + // q1: positional placeholder + stmt1, err := db.PrepareContext(ctx, `SELECT name FROM std_prepared_funds WHERE symbol=? LIMIT 1`) + require.NoError(t, err) + t.Cleanup(func() { _ = stmt1.Close() }) + rows1, err := stmt1.QueryContext(ctx, "abc") + require.NoError(t, err) + t.Cleanup(func() { _ = rows1.Close() }) + require.True(t, rows1.Next()) + var name1 string + require.NoError(t, rows1.Scan(&name1)) + require.Equal(t, "ABC Fund", name1) + require.NoError(t, rows1.Err()) + + // q2: named query parameter + stmt2, err := db.PrepareContext(ctx, `SELECT name FROM std_prepared_funds WHERE symbol={symbol: String} LIMIT 1`) + require.NoError(t, err) + t.Cleanup(func() { _ = stmt2.Close() }) + rows2, err := stmt2.QueryContext(ctx, clickhouse.Named("symbol", "abc")) + require.NoError(t, err) + t.Cleanup(func() { _ = rows2.Close() }) + require.True(t, rows2.Next()) + var name2 string + require.NoError(t, rows2.Scan(&name2)) + require.Equal(t, "ABC Fund", name2) + require.NoError(t, rows2.Err()) +} diff --git a/tests/std/prepared_stmt_use_db_test.go b/tests/std/prepared_stmt_use_db_test.go new file mode 100644 index 0000000000..d1885aed18 --- /dev/null +++ b/tests/std/prepared_stmt_use_db_test.go @@ -0,0 +1,54 @@ +// Licensed to ClickHouse, Inc. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. ClickHouse, Inc. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package std + +import ( + "context" + "testing" + + "github.com/ClickHouse/clickhouse-go/v2" + "github.com/stretchr/testify/require" +) + +// Ensures we can execute a USE ; followed by a prepared SELECT. +func TestStdPreparedSelectWithUseDatabase(t *testing.T) { + db, err := GetStdOpenDBConnection(clickhouse.Native, nil, nil, nil) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + ctx := context.Background() + require.NoError(t, db.PingContext(ctx)) + + // Explicit USE should work as Exec on connection + _, err = db.ExecContext(ctx, "USE default") + require.NoError(t, err) + + stmt, err := db.PrepareContext(ctx, "SELECT ? + ?") + require.NoError(t, err) + t.Cleanup(func() { _ = stmt.Close() }) + + rows, err := stmt.QueryContext(ctx, 7, 8) + require.NoError(t, err) + t.Cleanup(func() { _ = rows.Close() }) + + require.True(t, rows.Next()) + var sum int64 + require.NoError(t, rows.Scan(&sum)) + require.EqualValues(t, 15, sum) + require.NoError(t, rows.Err()) +}