Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/crate/client/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,27 @@
import typing as t
import warnings
from datetime import datetime, timedelta, timezone
from itertools import count

from .converter import Converter, DataType
from .exceptions import ProgrammingError

_NAMED_PARAM_RE = re.compile(r"%\(([^)]+)\)s")


def _rewrite_pyformat_sql(sql: str) -> str:
"""Replace %(name)s placeholders with $N positional markers (1-indexed)."""
counter = count(1)
return _NAMED_PARAM_RE.sub(lambda _: f"${next(counter)}", sql)


def _convert_named_to_positional(
sql: str, params: t.Dict[str, t.Any]
) -> t.Tuple[str, t.List[t.Any]]:
"""Convert pyformat-style named parameters to positional qmark parameters.
"""Convert pyformat-style named parameters to positional parameters.

Converts ``%(name)s`` placeholders to ``?`` and returns an ordered list
of corresponding values extracted from ``params``.
Converts ``%(name)s`` placeholders to ``$N`` (1-indexed) and returns an
ordered list of corresponding values extracted from ``params``.

The same name may appear multiple times; each occurrence appends the
value to the positional list independently.
Expand All @@ -47,7 +54,7 @@ def _convert_named_to_positional(

sql = "SELECT * FROM t WHERE a = %(a)s AND b = %(b)s"
params = {"a": 1, "b": 2}
# returns: ("SELECT * FROM t WHERE a = ? AND b = ?", [1, 2])
# returns: ("SELECT * FROM t WHERE a = $1 AND b = $2", [1, 2])
"""
positions = {}
idx = 1
Expand Down Expand Up @@ -136,6 +143,8 @@ def execute(self, sql, parameters=None, bulk_parameters=None):

if isinstance(parameters, dict):
sql, parameters = _convert_named_to_positional(sql, parameters)
elif bulk_parameters is not None and _NAMED_PARAM_RE.search(sql):
sql = _rewrite_pyformat_sql(sql)
Comment on lines +146 to +147

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't the bulk_parameters also need some sort of conversion and be in dict format for this to work reliably?

Or does pyformat imply that the bulk parameters order matches the order of named placeholders? If so - is that documented somewhere?

Do other clients allow this?


self._result = self.connection.client.sql(
sql, parameters, bulk_parameters
Expand Down
28 changes: 28 additions & 0 deletions tests/client/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,34 @@ def test_execute_with_bulk_args(mocked_connection):
mocked_connection.client.sql.assert_called_once_with(statement, None, [[1]])


def test_execute_with_pyformat_sql_and_bulk_parameters(mocked_connection):
"""
cursor.execute() converts %(name)s SQL to $N when bulk_parameters is
provided. Rows are already positional; only the SQL needs conversion.
"""
cursor = mocked_connection.cursor()
sql = "INSERT INTO t (id, val) VALUES (%(id)s, %(val)s)"
bulk = [[1, "hello"], [2, "world"]]
cursor.execute(sql, bulk_parameters=bulk)
mocked_connection.client.sql.assert_called_once_with(
"INSERT INTO t (id, val) VALUES ($1, $2)", None, bulk
)


def test_execute_with_pyformat_sql_and_bulk_parameters_no_placeholders(
mocked_connection,
):
"""
SQL without %(name)s placeholders is passed through unchanged
even when bulk_parameters is provided.
"""
cursor = mocked_connection.cursor()
sql = "INSERT INTO t (id, val) VALUES (?, ?)"
bulk = [[1, "hello"], [2, "world"]]
cursor.execute(sql, bulk_parameters=bulk)
mocked_connection.client.sql.assert_called_once_with(sql, None, bulk)


def test_execute_custom_converter(mocked_connection):
"""
Verify that a custom converter is correctly applied when passed to a cursor.
Expand Down
Loading