diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 7b7310b02..c97d9118a 100755 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -18,6 +18,7 @@ ### Internal Changes * Replace the async-disabling mechanism on token refresh failure with a 1-minute retry backoff. Previously, a single failed async refresh would disable proactive token renewal until the token expired. Now, the SDK waits a short cooldown period and retries, improving resilience to transient errors. * Extract `_resolve_profile` to simplify config file loading and improve `__settings__` error messages. +* Add `host_type` to `HostMetadata` and `HostType.from_api_value()` for normalizing host type strings from the discovery endpoint. ### API Changes * Add `create_catalog()`, `create_synced_table()`, `delete_catalog()`, `delete_synced_table()`, `get_catalog()` and `get_synced_table()` methods for [w.postgres](https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html) workspace-level service. diff --git a/databricks/sdk/client_types.py b/databricks/sdk/client_types.py index 3937cdde0..729a454af 100644 --- a/databricks/sdk/client_types.py +++ b/databricks/sdk/client_types.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import Optional class HostType(Enum): @@ -8,6 +9,24 @@ class HostType(Enum): WORKSPACE = "workspace" UNIFIED = "unified" + @staticmethod + def from_api_value(value: str) -> Optional["HostType"]: + """Normalize a host_type string from the API to a HostType enum value. + + Maps "workspace" -> WORKSPACE, "account" -> ACCOUNTS, "unified" -> UNIFIED. + Returns None for unrecognized or empty values. + """ + if not value: + return None + normalized = value.lower() + if normalized == "workspace": + return HostType.WORKSPACE + if normalized == "account": + return HostType.ACCOUNTS + if normalized == "unified": + return HostType.UNIFIED + return None + class ClientType(Enum): """Enum representing the type of client configuration.""" diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 599dc0efe..cc34d2dbe 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -290,6 +290,7 @@ def __init__( self._header_factory = None self._inner = {} self._user_agent_other_info = [] + self._resolved_host_type = None self._custom_headers = custom_headers or {} if credentials_strategy and credentials_provider: raise ValueError("When providing `credentials_strategy` field, `credential_provider` cannot be specified.") @@ -655,6 +656,11 @@ def _resolve_host_metadata(self) -> None: if not self.cloud and meta.cloud: logger.debug(f"Resolved cloud from host metadata: {meta.cloud.value}") self.cloud = meta.cloud + if self._resolved_host_type is None and meta.host_type: + resolved = HostType.from_api_value(meta.host_type) + if resolved is not None: + logger.debug(f"Resolved host_type from host metadata: {meta.host_type}") + self._resolved_host_type = resolved # Account hosts use account_id as the OIDC token audience instead of the token endpoint. # This is a special case: when the metadata has no workspace_id, the host is acting as an # account-level endpoint and the audience must be scoped to the account. diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index 9f3656fda..85ac05b2a 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -448,6 +448,7 @@ class HostMetadata: account_id: Optional[str] = None workspace_id: Optional[str] = None cloud: Optional[Cloud] = None + host_type: Optional[str] = None @staticmethod def from_dict(d: dict) -> "HostMetadata": @@ -456,6 +457,7 @@ def from_dict(d: dict) -> "HostMetadata": account_id=d.get("account_id"), workspace_id=d.get("workspace_id"), cloud=Cloud.parse(d.get("cloud", "")), + host_type=d.get("host_type"), ) def as_dict(self) -> dict: @@ -464,6 +466,7 @@ def as_dict(self) -> dict: "account_id": self.account_id, "workspace_id": self.workspace_id, "cloud": self.cloud.value if self.cloud else None, + "host_type": self.host_type, } diff --git a/tests/test_config.py b/tests/test_config.py index 0130fab88..d6e0cd4e4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -973,3 +973,119 @@ def test_resolve_host_metadata_does_not_overwrite_token_audience(mocker): token_audience="custom-audience", ) assert config.token_audience == "custom-audience" + + +# --------------------------------------------------------------------------- +# HostType.from_api_value tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "api_value,expected", + [ + ("workspace", HostType.WORKSPACE), + ("Workspace", HostType.WORKSPACE), + ("WORKSPACE", HostType.WORKSPACE), + ("account", HostType.ACCOUNTS), + ("Account", HostType.ACCOUNTS), + ("ACCOUNT", HostType.ACCOUNTS), + ("unified", HostType.UNIFIED), + ("Unified", HostType.UNIFIED), + ("UNIFIED", HostType.UNIFIED), + ("unknown", None), + ("", None), + (None, None), + ], +) +def test_host_type_from_api_value(api_value, expected): + assert HostType.from_api_value(api_value) == expected + + +# --------------------------------------------------------------------------- +# HostMetadata.from_dict with host_type field +# --------------------------------------------------------------------------- + + +def test_host_metadata_from_dict_with_host_type(): + """HostMetadata.from_dict parses the host_type field.""" + d = { + "oidc_endpoint": "https://host/oidc", + "account_id": "acc-1", + "host_type": "workspace", + } + meta = HostMetadata.from_dict(d) + assert meta.host_type == "workspace" + + +def test_host_metadata_from_dict_without_host_type(): + """HostMetadata.from_dict returns None for missing host_type.""" + d = {"oidc_endpoint": "https://host/oidc"} + meta = HostMetadata.from_dict(d) + assert meta.host_type is None + + +# --------------------------------------------------------------------------- +# _resolve_host_metadata populates _resolved_host_type +# --------------------------------------------------------------------------- + + +def test_resolve_host_metadata_populates_resolved_host_type(mocker): + """_resolved_host_type is populated from metadata host_type.""" + mocker.patch( + "databricks.sdk.config.get_host_metadata", + return_value=HostMetadata.from_dict( + { + "oidc_endpoint": f"{_DUMMY_WS_HOST}/oidc", + "host_type": "unified", + } + ), + ) + config = Config(host=_DUMMY_WS_HOST, token="t") + assert config._resolved_host_type == HostType.UNIFIED + + +def test_resolve_host_metadata_does_not_overwrite_existing_resolved_host_type(mocker): + """An already-set _resolved_host_type is not overwritten by metadata.""" + mocker.patch( + "databricks.sdk.config.get_host_metadata", + return_value=HostMetadata.from_dict( + { + "oidc_endpoint": f"{_DUMMY_WS_HOST}/oidc", + "host_type": "account", + } + ), + ) + config = Config(host=_DUMMY_WS_HOST, token="t") + # Manually set resolved host type then re-resolve + config._resolved_host_type = HostType.UNIFIED + config._resolve_host_metadata() + assert config._resolved_host_type == HostType.UNIFIED + + +def test_resolve_host_metadata_resolved_host_type_none_when_missing(mocker): + """_resolved_host_type stays None when metadata has no host_type.""" + mocker.patch( + "databricks.sdk.config.get_host_metadata", + return_value=HostMetadata.from_dict( + { + "oidc_endpoint": f"{_DUMMY_WS_HOST}/oidc", + } + ), + ) + config = Config(host=_DUMMY_WS_HOST, token="t") + assert config._resolved_host_type is None + + +def test_resolve_host_metadata_resolved_host_type_unknown_string(mocker): + """_resolved_host_type stays None for unrecognized host_type strings.""" + mocker.patch( + "databricks.sdk.config.get_host_metadata", + return_value=HostMetadata.from_dict( + { + "oidc_endpoint": f"{_DUMMY_WS_HOST}/oidc", + "host_type": "some_future_type", + } + ), + ) + config = Config(host=_DUMMY_WS_HOST, token="t") + assert config._resolved_host_type is None