Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,21 @@ def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]:
# runtime and no config variables are set.
from databricks.sdk.runtime import (init_runtime_legacy_auth,
init_runtime_native_auth,
init_runtime_native_unified,
init_runtime_repl_auth)

# Try the unified provider first (returns host, account_id, workspace_id, inner).
if init_runtime_native_unified is not None:
host, account_id, workspace_id, inner = init_runtime_native_unified()
if host is not None:
cfg.host = host
cfg.account_id = account_id
cfg.workspace_id = workspace_id
logger.debug("[init_runtime_native_unified] runtime native auth configured")
return inner
logger.debug("[init_runtime_native_unified] no host detected")

# Fall back to legacy providers (return host, inner).
for init in [
init_runtime_native_auth,
init_runtime_repl_auth,
Expand Down
8 changes: 8 additions & 0 deletions databricks/sdk/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@
]

# DO NOT MOVE THE TRY-CATCH BLOCK BELOW AND DO NOT ADD THINGS BEFORE IT! WILL MAKE TEST FAIL.
try:
from dbruntime.sdk_credential_provider import init_runtime_native_unified

logger.debug("runtime SDK credential provider (unified) available")
except ImportError:
init_runtime_native_unified = None

try:
# We don't want to expose additional entity to user namespace, so
# a workaround here for exposing required information in notebook environment
Expand All @@ -34,6 +41,7 @@
init_runtime_native_auth = None

globals()["init_runtime_native_auth"] = init_runtime_native_auth
globals()["init_runtime_native_unified"] = init_runtime_native_unified


def init_runtime_repl_auth():
Expand Down
87 changes: 87 additions & 0 deletions tests/test_notebook_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def fake_init_runtime_repl_auth():
pass

fake_runtime.init_runtime_native_auth = fake_init_runtime_native_auth
fake_runtime.init_runtime_native_unified = None
fake_runtime.init_runtime_legacy_auth = fake_init_runtime_legacy_auth
fake_runtime.init_runtime_repl_auth = fake_init_runtime_repl_auth

Expand Down Expand Up @@ -187,3 +188,89 @@ def test_workspace_client_integration(
assert w.config.scopes == expected_scopes
headers = w.config.authenticate()
assert headers["Authorization"] == "Bearer exchanged-oauth-token"


@pytest.fixture
def mock_runtime_native_unified():
"""Mock the runtime module with init_runtime_native_unified returning 4-tuple."""
fake_runtime = types.ModuleType("databricks.sdk.runtime")

def fake_init_runtime_native_unified():
def inner():
return {"Authorization": "Bearer unified-token"}

return "https://unified.cloud.databricks.com", "acc-123", "ws-456", inner

fake_runtime.init_runtime_native_unified = fake_init_runtime_native_unified
fake_runtime.init_runtime_native_auth = None
fake_runtime.init_runtime_legacy_auth = None
fake_runtime.init_runtime_repl_auth = None

sys.modules["databricks.sdk.runtime"] = fake_runtime
yield


@pytest.fixture
def mock_runtime_native_unified_returns_none():
"""Mock the runtime module with init_runtime_native_unified returning None host."""
fake_runtime = types.ModuleType("databricks.sdk.runtime")

def fake_init_runtime_native_unified():
return None, None, None, None

def fake_init_runtime_native_auth():
def inner():
return {"Authorization": "Bearer fallback-token"}

return "https://fallback.cloud.databricks.com", inner

fake_runtime.init_runtime_native_unified = fake_init_runtime_native_unified
fake_runtime.init_runtime_native_auth = fake_init_runtime_native_auth
fake_runtime.init_runtime_legacy_auth = None
fake_runtime.init_runtime_repl_auth = None

sys.modules["databricks.sdk.runtime"] = fake_runtime
yield


def test_runtime_unified_auth_sets_host_and_ids(mock_runtime_env, mock_runtime_native_unified):
"""Test that init_runtime_native_unified sets host, account_id, and workspace_id on Config."""
cfg = Config(host="https://unified.cloud.databricks.com")

headers = cfg.authenticate()
assert headers["Authorization"] == "Bearer unified-token"
assert cfg.host == "https://unified.cloud.databricks.com"
assert cfg.account_id == "acc-123"
assert cfg.workspace_id == "ws-456"


def test_runtime_unified_auth_fallback_when_none(mock_runtime_env, mock_runtime_native_unified_returns_none):
"""Test fallback to init_runtime_native_auth when unified returns None."""
cfg = Config(host="https://fallback.cloud.databricks.com")

headers = cfg.authenticate()
assert headers["Authorization"] == "Bearer fallback-token"
assert cfg.host == "https://fallback.cloud.databricks.com"


def test_runtime_unified_auth_fallback_when_not_available(mock_runtime_env, mock_runtime_native_auth):
"""Test fallback to init_runtime_native_auth when unified is None (import failed)."""
cfg = Config(host="https://test.cloud.databricks.com")

headers = cfg.authenticate()
assert headers["Authorization"] == "Bearer test-notebook-pat-token"
assert cfg.host == "https://test.cloud.databricks.com"


def test_runtime_unified_auth_priority_over_native(mock_runtime_env, mock_runtime_native_unified):
"""Test that unified provider is used over native auth in DefaultCredentials chain."""
cfg = Config(host="https://unified.cloud.databricks.com")

default_creds = DefaultCredentials()
creds_provider = default_creds(cfg)

headers = creds_provider()
assert headers["Authorization"] == "Bearer unified-token"
assert default_creds.auth_type() == "runtime"
assert cfg.account_id == "acc-123"
assert cfg.workspace_id == "ws-456"
Loading