-
Notifications
You must be signed in to change notification settings - Fork 196
Expand file tree
/
Copy pathtest_notebook_oauth.py
More file actions
276 lines (204 loc) · 9.71 KB
/
test_notebook_oauth.py
File metadata and controls
276 lines (204 loc) · 9.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
"""Tests for runtime OAuth authentication in notebook environments."""
import os
import sys
import types
from datetime import datetime, timedelta
from typing import Dict
import pytest
from databricks.sdk import oauth
from databricks.sdk.config import Config
from databricks.sdk.credentials_provider import (CredentialsProvider,
CredentialsStrategy,
DefaultCredentials,
runtime_oauth)
@pytest.fixture
def mock_runtime_env(monkeypatch):
    """Set up a mock Databricks runtime environment (DBR version 14.3).

    Sets ``DATABRICKS_RUNTIME_VERSION`` for the duration of the test. The
    ``monkeypatch`` fixture automatically restores the variable's previous
    value (or removes it if it was unset) at teardown, so no manual cleanup
    is required here.
    """
    monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "14.3")
@pytest.fixture
def mock_runtime_native_auth():
    """Install a fake ``databricks.sdk.runtime`` module exposing the native-auth hook.

    ``init_runtime_native_auth`` returns the test workspace host together with
    a header factory that yields a static notebook PAT. The legacy and REPL
    hooks are no-op stubs. ``init_runtime_native_unified`` is ``None`` to model
    the unified hook being unavailable.

    The fake is registered in ``sys.modules`` only while the test runs. The
    previous entry, or its absence, is restored at teardown so the stub does
    not leak into other tests or test modules.
    """
    module_name = "databricks.sdk.runtime"
    fake_runtime = types.ModuleType(module_name)

    def fake_init_runtime_native_auth():
        def inner():
            return {"Authorization": "Bearer test-notebook-pat-token"}

        return "https://test.cloud.databricks.com", inner

    def fake_init_runtime_legacy_auth():
        pass

    def fake_init_runtime_repl_auth():
        pass

    fake_runtime.init_runtime_native_auth = fake_init_runtime_native_auth
    fake_runtime.init_runtime_native_unified = None
    fake_runtime.init_runtime_legacy_auth = fake_init_runtime_legacy_auth
    fake_runtime.init_runtime_repl_auth = fake_init_runtime_repl_auth

    # Sentinel distinguishes "key absent" from "key present with value None".
    missing = object()
    previous = sys.modules.get(module_name, missing)
    sys.modules[module_name] = fake_runtime
    try:
        yield
    finally:
        # Undo the sys.modules override so later imports see the real state.
        if previous is missing:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
@pytest.fixture
def mock_pat_exchange(mocker):
    """Patch ``oauth.PATOAuthTokenExchange`` so no real HTTP token exchange happens.

    The patched class returns one shared mock, whatever its constructor
    arguments. That mock's ``token()`` yields a Bearer token named
    ``exchanged-oauth-token`` that expires one hour from now. The shared mock
    is also the fixture's value.
    """
    one_hour_later = datetime.now() + timedelta(hours=1)
    fake_token = oauth.Token(
        access_token="exchanged-oauth-token",
        token_type="Bearer",
        expiry=one_hour_later,
    )

    exchange = mocker.Mock(spec=oauth.PATOAuthTokenExchange)
    exchange.token.return_value = fake_token

    mocker.patch("databricks.sdk.oauth.PATOAuthTokenExchange", return_value=exchange)
    return exchange
class MockCredentialsStrategy(CredentialsStrategy):
    """Minimal ``CredentialsStrategy`` stub that always produces a placeholder header."""

    def auth_type(self) -> str:
        return "mock_credentials_strategy"

    def __call__(self, cfg) -> CredentialsProvider:
        # A fresh dict is built on every call of the returned provider.
        # NOTE(review): "Bearer: no_token" is not a well-formed bearer header;
        # it is only a placeholder and no test inspects it.
        return lambda: {"Authorization": "Bearer: no_token"}
@pytest.mark.parametrize(
    "scopes,auth_details",
    [
        ("sql offline_access", None),
        ("sql offline_access", '{"type": "databricks_resource"}'),
        ("sql", None),
        ("sql offline_access all-apis", None),
    ],
)
def test_runtime_oauth_success_scenarios(
    mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange, scopes, auth_details
):
    """runtime_oauth returns the exchanged OAuth token for each valid scope/details combination.

    The ``Config`` is built with a ``MockCredentialsStrategy``, and
    ``runtime_oauth`` is then invoked directly against it.
    """
    config = Config(
        host="https://test.cloud.databricks.com",
        scopes=scopes,
        authorization_details=auth_details,
        credentials_strategy=MockCredentialsStrategy(),
    )

    provider = runtime_oauth(config)
    assert provider is not None
    assert provider()["Authorization"] == "Bearer exchanged-oauth-token"
@pytest.mark.parametrize("scopes", [None, ""])
def test_runtime_oauth_missing_scopes(mock_runtime_env, mock_runtime_native_auth, scopes):
    """runtime_oauth declines (returns None) when scopes are absent or empty."""
    config = Config(host="https://test.cloud.databricks.com", scopes=scopes)
    assert runtime_oauth(config) is None
def test_runtime_oauth_priority_over_native_auth(mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange):
    """With scopes configured, the default chain picks runtime-oauth over native runtime auth."""
    config = Config(host="https://test.cloud.databricks.com", scopes="sql offline_access")
    chain = DefaultCredentials()

    provider = chain(config)
    assert provider()["Authorization"] == "Bearer exchanged-oauth-token"
    assert chain.auth_type() == "runtime-oauth"
def test_fallback_to_native_auth_without_scopes(mock_runtime_env, mock_runtime_native_auth):
    """Without scopes, the default chain falls back to native runtime auth ("runtime")."""
    config = Config(host="https://test.cloud.databricks.com")
    chain = DefaultCredentials()

    provider = chain(config)
    assert provider()["Authorization"] == "Bearer test-notebook-pat-token"
    assert chain.auth_type() == "runtime"
def test_explicit_runtime_oauth_auth_type(mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange):
    """Setting auth_type="runtime-oauth" on Config makes the default chain use runtime-oauth."""
    config = Config(
        host="https://test.cloud.databricks.com",
        scopes="sql offline_access",
        auth_type="runtime-oauth",
    )
    chain = DefaultCredentials()

    provider = chain(config)
    assert provider()["Authorization"] == "Bearer exchanged-oauth-token"
    assert chain.auth_type() == "runtime-oauth"
@pytest.mark.parametrize(
    "has_scopes,expected_token",
    [
        (True, "exchanged-oauth-token"),
        (False, "test-notebook-pat-token"),
    ],
)
def test_config_authenticate_integration(
    mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange, has_scopes, expected_token
):
    """Config.authenticate() uses runtime-oauth with scopes and native runtime auth without them."""
    scope_kwargs = {"scopes": "sql offline_access"} if has_scopes else {}
    config = Config(host="https://test.cloud.databricks.com", **scope_kwargs)

    assert config.authenticate()["Authorization"] == f"Bearer {expected_token}"
@pytest.mark.parametrize(
    "scopes_input,expected_scopes",
    [(["sql", "offline_access"], ["offline_access", "sql"])],
)
def test_workspace_client_integration(
    mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange, scopes_input, expected_scopes
):
    """WorkspaceClient stores the configured scopes and authenticates via runtime-oauth.

    The parametrized expectation is the input list in sorted order.
    """
    from databricks.sdk import WorkspaceClient

    client = WorkspaceClient(host="https://test.cloud.databricks.com", scopes=scopes_input)
    assert client.config.scopes == expected_scopes

    auth_headers = client.config.authenticate()
    assert auth_headers["Authorization"] == "Bearer exchanged-oauth-token"
@pytest.fixture
def mock_runtime_native_unified():
    """Install a fake ``databricks.sdk.runtime`` whose only hook is ``init_runtime_native_unified``.

    The unified hook returns a 4-tuple of (host, account id, workspace id,
    header factory) with fixed test values. The native, legacy and REPL hooks
    are set to ``None``.

    The fake is registered in ``sys.modules`` only while the test runs. The
    previous entry, or its absence, is restored at teardown so the stub does
    not leak into other tests or test modules.
    """
    module_name = "databricks.sdk.runtime"
    fake_runtime = types.ModuleType(module_name)

    def fake_init_runtime_native_unified():
        def inner():
            return {"Authorization": "Bearer unified-token"}

        return "https://unified.cloud.databricks.com", "acc-123", "ws-456", inner

    fake_runtime.init_runtime_native_unified = fake_init_runtime_native_unified
    fake_runtime.init_runtime_native_auth = None
    fake_runtime.init_runtime_legacy_auth = None
    fake_runtime.init_runtime_repl_auth = None

    # Sentinel distinguishes "key absent" from "key present with value None".
    missing = object()
    previous = sys.modules.get(module_name, missing)
    sys.modules[module_name] = fake_runtime
    try:
        yield
    finally:
        # Undo the sys.modules override so later imports see the real state.
        if previous is missing:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
@pytest.fixture
def mock_runtime_native_unified_returns_none():
    """Install a fake ``databricks.sdk.runtime`` whose unified hook yields no credentials.

    ``init_runtime_native_unified`` returns ``(None, None, None, None)``.
    ``init_runtime_native_auth`` returns a fallback host and a header factory
    with a fallback token. The legacy and REPL hooks are ``None``.

    The fake is registered in ``sys.modules`` only while the test runs. The
    previous entry, or its absence, is restored at teardown so the stub does
    not leak into other tests or test modules.
    """
    module_name = "databricks.sdk.runtime"
    fake_runtime = types.ModuleType(module_name)

    def fake_init_runtime_native_unified():
        return None, None, None, None

    def fake_init_runtime_native_auth():
        def inner():
            return {"Authorization": "Bearer fallback-token"}

        return "https://fallback.cloud.databricks.com", inner

    fake_runtime.init_runtime_native_unified = fake_init_runtime_native_unified
    fake_runtime.init_runtime_native_auth = fake_init_runtime_native_auth
    fake_runtime.init_runtime_legacy_auth = None
    fake_runtime.init_runtime_repl_auth = None

    # Sentinel distinguishes "key absent" from "key present with value None".
    missing = object()
    previous = sys.modules.get(module_name, missing)
    sys.modules[module_name] = fake_runtime
    try:
        yield
    finally:
        # Undo the sys.modules override so later imports see the real state.
        if previous is missing:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
def test_runtime_unified_auth_sets_host_and_ids(mock_runtime_env, mock_runtime_native_unified):
    """The unified hook supplies the token, and Config carries its host, account_id and workspace_id."""
    config = Config(host="https://unified.cloud.databricks.com")
    auth_headers = config.authenticate()

    assert auth_headers["Authorization"] == "Bearer unified-token"
    assert config.host == "https://unified.cloud.databricks.com"
    assert config.account_id == "acc-123"
    assert config.workspace_id == "ws-456"
def test_runtime_unified_auth_fallback_when_none(mock_runtime_env, mock_runtime_native_unified_returns_none):
    """When the unified hook returns all-None values, init_runtime_native_auth is used instead."""
    config = Config(host="https://fallback.cloud.databricks.com")
    auth_headers = config.authenticate()

    assert auth_headers["Authorization"] == "Bearer fallback-token"
    assert config.host == "https://fallback.cloud.databricks.com"
def test_runtime_unified_auth_fallback_when_not_available(mock_runtime_env, mock_runtime_native_auth):
    """With init_runtime_native_unified set to None (a failed import), native auth's PAT is used."""
    config = Config(host="https://test.cloud.databricks.com")
    auth_headers = config.authenticate()

    assert auth_headers["Authorization"] == "Bearer test-notebook-pat-token"
    assert config.host == "https://test.cloud.databricks.com"
def test_runtime_unified_auth_priority_over_native(mock_runtime_env, mock_runtime_native_unified):
    """In the default chain the unified hook wins, reporting "runtime" and filling account/workspace ids."""
    config = Config(host="https://unified.cloud.databricks.com")
    chain = DefaultCredentials()

    provider = chain(config)
    assert provider()["Authorization"] == "Bearer unified-token"
    assert chain.auth_type() == "runtime"
    assert config.account_id == "acc-123"
    assert config.workspace_id == "ws-456"