Skip to content
Merged
1 change: 1 addition & 0 deletions dbt/adapters/databricks/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __pre_deserialize__(cls, data: dict[Any, Any]) -> dict[Any, Any]:
data.setdefault("connection_parameters", {})
data["connection_parameters"].setdefault("_retry_stop_after_attempts_count", 30)
data["connection_parameters"].setdefault("_retry_delay_max", 60)
data["connection_parameters"].setdefault("_retry_server_directed_only", False)
Comment thread
sd-db marked this conversation as resolved.
Outdated
return data

def __post_init__(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers = [
dependencies = [
"click>=8.2.0, <9.0.0",
"databricks-sdk>=0.68.0, <0.78.0",
"databricks-sql-connector[pyarrow]>=4.1.1, <4.1.4",
"databricks-sql-connector[pyarrow]>=4.1.5, <4.2.0",
Comment thread
sd-db marked this conversation as resolved.
Outdated
"dbt-adapters>=1.22.0, <1.23.0",
"dbt-common>=1.37.0, <1.38.0",
"dbt-core>=1.11.2, <1.11.7",
Expand Down
87 changes: 87 additions & 0 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dbt.adapters.databricks.column import DatabricksColumn
from dbt.adapters.databricks.credentials import (
CATALOG_KEY_IN_SESSION_PROPERTIES,
DatabricksCredentials,
)
from dbt.adapters.databricks.impl import (
DatabricksRelationInfo,
Expand Down Expand Up @@ -241,6 +242,7 @@ def _connect_func(
expected_http_headers=None,
expected_no_token=None,
expected_client_creds=None,
expected_retry_params=None,
):
def connect(
server_hostname,
Expand Down Expand Up @@ -279,6 +281,9 @@ def connect(
assert http_headers is None
else:
assert http_headers == expected_http_headers
if expected_retry_params is not None:
for key, value in expected_retry_params.items():
assert kwargs.get(key) == value
return Mock()

return connect
Expand Down Expand Up @@ -354,6 +359,40 @@ def _test_databricks_sql_connector_http_header_connection(self, http_headers, co
assert connection.credentials.token == "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
assert connection.credentials.schema == "analytics"

def test_databricks_sql_connector_default_retry_params(self):
"""Verify default retry parameters are passed through to dbsql.connect().

Ensures _retry_stop_after_attempts_count, _retry_delay_max, and
_retry_server_directed_only defaults from __pre_deserialize__ reach
the connector.
"""
self._test_databricks_sql_connector_connection(
self._connect_func(
expected_retry_params={
"_retry_stop_after_attempts_count": 30,
"_retry_delay_max": 60,
"_retry_server_directed_only": False,
}
)
)

def test_databricks_sql_connector_retry_server_directed_only_opt_in(self):
"""Verify user can opt in to _retry_server_directed_only via connection_parameters.

When a user sets _retry_server_directed_only: true in their profile,
the connector should only retry requests when the server includes a
Retry-After header, preventing duplicate writes from blind retries.
"""
config = self._get_config(connection_parameters={"_retry_server_directed_only": True})
adapter = DatabricksAdapter(config, get_context("spawn"))

connect = self._connect_func(expected_retry_params={"_retry_server_directed_only": True})
with patch("dbt.adapters.databricks.handle.dbsql.connect", new=connect):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load

assert connection.state == "open"

@patch("dbt.adapters.databricks.api_client.DatabricksApiClient")
def test_list_relations_without_caching__no_relations(self, _):
with patch.object(DatabricksAdapter, "get_relations_without_caching") as mocked:
Expand Down Expand Up @@ -1351,3 +1390,51 @@ def test_is_uniform_with_invalid_materialization_error(
DbtConfigError, match="When table_format is 'iceberg', materialized must be"
):
adapter.is_uniform(mock_config)


class TestDatabricksCredentialsPreDeserialize:
"""Tests for DatabricksCredentials.__pre_deserialize__ retry defaults."""

def test_pre_deserialize__default_retry_params(self):
"""Verify all retry defaults are set when connection_parameters is empty."""
data = {"connection_parameters": {}}
result = DatabricksCredentials.__pre_deserialize__(data)
assert result["connection_parameters"]["_retry_stop_after_attempts_count"] == 30
assert result["connection_parameters"]["_retry_delay_max"] == 60
assert result["connection_parameters"]["_retry_server_directed_only"] is False

def test_pre_deserialize__missing_connection_parameters(self):
"""Verify retry defaults are set even when connection_parameters key is absent."""
data = {}
result = DatabricksCredentials.__pre_deserialize__(data)
assert result["connection_parameters"]["_retry_stop_after_attempts_count"] == 30
assert result["connection_parameters"]["_retry_delay_max"] == 60
assert result["connection_parameters"]["_retry_server_directed_only"] is False

def test_pre_deserialize__user_override_retry_server_directed_only(self):
"""Verify user-provided _retry_server_directed_only is not overridden by default."""
data = {"connection_parameters": {"_retry_server_directed_only": True}}
result = DatabricksCredentials.__pre_deserialize__(data)
assert result["connection_parameters"]["_retry_server_directed_only"] is True

def test_pre_deserialize__user_override_preserves_other_defaults(self):
"""Verify overriding one retry param does not affect the others."""
data = {
"connection_parameters": {
"_retry_stop_after_attempts_count": 5,
"_retry_server_directed_only": True,
}
}
result = DatabricksCredentials.__pre_deserialize__(data)
assert result["connection_parameters"]["_retry_stop_after_attempts_count"] == 5
assert result["connection_parameters"]["_retry_delay_max"] == 60
assert result["connection_parameters"]["_retry_server_directed_only"] is True

def test_pre_deserialize__custom_params_preserve_retry_defaults(self):
"""Verify unrelated connection_parameters don't interfere with retry defaults."""
data = {"connection_parameters": {"custom_param": "value"}}
result = DatabricksCredentials.__pre_deserialize__(data)
assert result["connection_parameters"]["custom_param"] == "value"
assert result["connection_parameters"]["_retry_stop_after_attempts_count"] == 30
assert result["connection_parameters"]["_retry_delay_max"] == 60
assert result["connection_parameters"]["_retry_server_directed_only"] is False
Loading