diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index e49f0a7fc..a992ea67a 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -160,8 +160,21 @@ def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]: # runtime and no config variables are set. from databricks.sdk.runtime import (init_runtime_legacy_auth, init_runtime_native_auth, + init_runtime_native_unified, init_runtime_repl_auth) + # Try the unified provider first (returns host, account_id, workspace_id, inner). + if init_runtime_native_unified is not None: + host, account_id, workspace_id, inner = init_runtime_native_unified() + if host is not None: + cfg.host = host + cfg.account_id = account_id + cfg.workspace_id = workspace_id + logger.debug("[init_runtime_native_unified] runtime native auth configured") + return inner + logger.debug("[init_runtime_native_unified] no host detected") + + # Fall back to legacy providers (return host, inner). for init in [ init_runtime_native_auth, init_runtime_repl_auth, diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index adf26c707..ebcab489e 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -23,6 +23,13 @@ ] # DO NOT MOVE THE TRY-CATCH BLOCK BELOW AND DO NOT ADD THINGS BEFORE IT! WILL MAKE TEST FAIL. +try: + from dbruntime.sdk_credential_provider import init_runtime_native_unified + + logger.debug("runtime SDK credential provider (unified) available") +except ImportError: + init_runtime_native_unified = None + try: # We don't want to expose additional entity to user namespace, so # a workaround here for exposing required information in notebook environment @@ -34,6 +41,7 @@ init_runtime_native_auth = None globals()["init_runtime_native_auth"] = init_runtime_native_auth +globals()["init_runtime_native_unified"] = init_runtime_native_unified def init_runtime_repl_auth(): diff --git a/tests/test_notebook_oauth.py b/tests/test_notebook_oauth.py index bde8f2165..eaca8c5de 100644 --- a/tests/test_notebook_oauth.py +++ b/tests/test_notebook_oauth.py @@ -43,6 +43,7 @@ def fake_init_runtime_repl_auth(): pass fake_runtime.init_runtime_native_auth = fake_init_runtime_native_auth + fake_runtime.init_runtime_native_unified = None fake_runtime.init_runtime_legacy_auth = fake_init_runtime_legacy_auth fake_runtime.init_runtime_repl_auth = fake_init_runtime_repl_auth @@ -187,3 +188,89 @@ def test_workspace_client_integration( assert w.config.scopes == expected_scopes headers = w.config.authenticate() assert headers["Authorization"] == "Bearer exchanged-oauth-token" + + +@pytest.fixture +def mock_runtime_native_unified(): + """Mock the runtime module with init_runtime_native_unified returning 4-tuple.""" + fake_runtime = types.ModuleType("databricks.sdk.runtime") + + def fake_init_runtime_native_unified(): + def inner(): + return {"Authorization": "Bearer unified-token"} + + return "https://unified.cloud.databricks.com", "acc-123", "ws-456", inner + + fake_runtime.init_runtime_native_unified = fake_init_runtime_native_unified + fake_runtime.init_runtime_native_auth = None + fake_runtime.init_runtime_legacy_auth = None + fake_runtime.init_runtime_repl_auth = None + + sys.modules["databricks.sdk.runtime"] = fake_runtime + yield + + +@pytest.fixture +def mock_runtime_native_unified_returns_none(): + """Mock the runtime module with init_runtime_native_unified returning None host.""" + fake_runtime = types.ModuleType("databricks.sdk.runtime") + + def fake_init_runtime_native_unified(): + return None, None, None, None + + def fake_init_runtime_native_auth(): + def inner(): + return {"Authorization": "Bearer fallback-token"} + + return "https://fallback.cloud.databricks.com", inner + + fake_runtime.init_runtime_native_unified = fake_init_runtime_native_unified + fake_runtime.init_runtime_native_auth = fake_init_runtime_native_auth + fake_runtime.init_runtime_legacy_auth = None + fake_runtime.init_runtime_repl_auth = None + + sys.modules["databricks.sdk.runtime"] = fake_runtime + yield + + +def test_runtime_unified_auth_sets_host_and_ids(mock_runtime_env, mock_runtime_native_unified): + """Test that init_runtime_native_unified sets host, account_id, and workspace_id on Config.""" + cfg = Config(host="https://unified.cloud.databricks.com") + + headers = cfg.authenticate() + assert headers["Authorization"] == "Bearer unified-token" + assert cfg.host == "https://unified.cloud.databricks.com" + assert cfg.account_id == "acc-123" + assert cfg.workspace_id == "ws-456" + + +def test_runtime_unified_auth_fallback_when_none(mock_runtime_env, mock_runtime_native_unified_returns_none): + """Test fallback to init_runtime_native_auth when unified returns None.""" + cfg = Config(host="https://fallback.cloud.databricks.com") + + headers = cfg.authenticate() + assert headers["Authorization"] == "Bearer fallback-token" + assert cfg.host == "https://fallback.cloud.databricks.com" + + +def test_runtime_unified_auth_fallback_when_not_available(mock_runtime_env, mock_runtime_native_auth): + """Test fallback to init_runtime_native_auth when unified is None (import failed).""" + cfg = Config(host="https://test.cloud.databricks.com") + + headers = cfg.authenticate() + assert headers["Authorization"] == "Bearer test-notebook-pat-token" + assert cfg.host == "https://test.cloud.databricks.com" + + +def test_runtime_unified_auth_priority_over_native(mock_runtime_env, mock_runtime_native_unified): + """Test that unified provider is used over native auth in DefaultCredentials chain.""" + cfg = Config(host="https://unified.cloud.databricks.com") + + default_creds = DefaultCredentials() + creds_provider = default_creds(cfg) + + headers = creds_provider() + assert headers["Authorization"] == "Bearer unified-token" + assert default_creds.auth_type() == "runtime" + assert cfg.account_id == "acc-123" + assert cfg.workspace_id == "ws-456"