diff --git a/src/crate/client/cursor.py b/src/crate/client/cursor.py index 7ac7f051..5f6cce9c 100644 --- a/src/crate/client/cursor.py +++ b/src/crate/client/cursor.py @@ -22,6 +22,7 @@ import typing as t import warnings from datetime import datetime, timedelta, timezone +from itertools import count from .converter import Converter, DataType from .exceptions import ProgrammingError @@ -29,13 +30,19 @@ _NAMED_PARAM_RE = re.compile(r"%\(([^)]+)\)s") +def _rewrite_pyformat_sql(sql: str) -> str: + """Replace %(name)s placeholders with $N positional markers (1-indexed).""" + counter = count(1) + return _NAMED_PARAM_RE.sub(lambda _: f"${next(counter)}", sql) + + def _convert_named_to_positional( sql: str, params: t.Dict[str, t.Any] ) -> t.Tuple[str, t.List[t.Any]]: - """Convert pyformat-style named parameters to positional qmark parameters. + """Convert pyformat-style named parameters to positional parameters. - Converts ``%(name)s`` placeholders to ``?`` and returns an ordered list - of corresponding values extracted from ``params``. + Converts ``%(name)s`` placeholders to ``$N`` (1-indexed) and returns an + ordered list of corresponding values extracted from ``params``. The same name may appear multiple times; each occurrence appends the value to the positional list independently. @@ -47,7 +54,7 @@ def _convert_named_to_positional( sql = "SELECT * FROM t WHERE a = %(a)s AND b = %(b)s" params = {"a": 1, "b": 2} - # returns: ("SELECT * FROM t WHERE a = ? AND b = ?", [1, 2]) + # returns: ("SELECT * FROM t WHERE a = $1 AND b = $2", [1, 2]) """ positions = {} idx = 1 @@ -136,6 +143,8 @@ def execute(self, sql, parameters=None, bulk_parameters=None): if isinstance(parameters, dict): sql, parameters = _convert_named_to_positional(sql, parameters) + elif bulk_parameters is not None and _NAMED_PARAM_RE.search(sql): + sql = _rewrite_pyformat_sql(sql) self._result = self.connection.client.sql( sql, parameters, bulk_parameters diff --git a/tests/client/test_cursor.py b/tests/client/test_cursor.py index ace23c4a..9252e053 100644 --- a/tests/client/test_cursor.py +++ b/tests/client/test_cursor.py @@ -329,6 +329,34 @@ def test_execute_with_bulk_args(mocked_connection): mocked_connection.client.sql.assert_called_once_with(statement, None, [[1]]) +def test_execute_with_pyformat_sql_and_bulk_parameters(mocked_connection): + """ + cursor.execute() converts %(name)s SQL to $N when bulk_parameters is + provided. Rows are already positional; only the SQL needs conversion. + """ + cursor = mocked_connection.cursor() + sql = "INSERT INTO t (id, val) VALUES (%(id)s, %(val)s)" + bulk = [[1, "hello"], [2, "world"]] + cursor.execute(sql, bulk_parameters=bulk) + mocked_connection.client.sql.assert_called_once_with( + "INSERT INTO t (id, val) VALUES ($1, $2)", None, bulk + ) + + +def test_execute_with_pyformat_sql_and_bulk_parameters_no_placeholders( + mocked_connection, +): + """ + SQL without %(name)s placeholders is passed through unchanged + even when bulk_parameters is provided. + """ + cursor = mocked_connection.cursor() + sql = "INSERT INTO t (id, val) VALUES (?, ?)" + bulk = [[1, "hello"], [2, "world"]] + cursor.execute(sql, bulk_parameters=bulk) + mocked_connection.client.sql.assert_called_once_with(sql, None, bulk) + + def test_execute_custom_converter(mocked_connection): """ Verify that a custom converter is correctly applied when passed to a cursor.