Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,14 @@ public static ImmutableMap<String, String> buildPropertiesMap(
if (!isNullOrEmpty(connectionParamString)) {
String[] urlParts = connectionParamString.split(DatabricksJdbcConstants.URL_DELIMITER);
for (String urlPart : urlParts) {
String[] pair = urlPart.split(DatabricksJdbcConstants.PAIR_DELIMITER);
if (pair.length == 1) {
pair = new String[] {pair[0], ""};
}
if (pair[0].startsWith(DatabricksJdbcUrlParams.HTTP_HEADERS.getParamName())) {
parametersBuilder.put(pair[0], pair[1]);
// Split on first '=' only — values (like httpPath) may contain '=' (e.g. ?o=123)
int delimIdx = urlPart.indexOf(DatabricksJdbcConstants.PAIR_DELIMITER);
String key = delimIdx >= 0 ? urlPart.substring(0, delimIdx) : urlPart;
String value = delimIdx >= 0 ? urlPart.substring(delimIdx + 1) : "";
if (key.startsWith(DatabricksJdbcUrlParams.HTTP_HEADERS.getParamName())) {
parametersBuilder.put(key, value);
} else {
parametersBuilder.put(pair[0].toLowerCase(), pair[1]);
parametersBuilder.put(key.toLowerCase(), value);
}
}
}
Expand Down Expand Up @@ -1167,14 +1167,39 @@ private String getParameter(DatabricksJdbcUrlParams key, String defaultValue) {
return this.parameters.getOrDefault(key.getParamName().toLowerCase(), defaultValue);
}

private static final String ORG_ID_HEADER = "x-databricks-org-id";

private Map<String, String> parseCustomHeaders(ImmutableMap<String, String> parameters) {
String filterPrefix = DatabricksJdbcUrlParams.HTTP_HEADERS.getParamName();

return parameters.entrySet().stream()
.filter(entry -> entry.getKey().startsWith(filterPrefix))
.collect(
Collectors.toMap(
entry -> entry.getKey().substring(filterPrefix.length()), Map.Entry::getValue));
Map<String, String> headers =
new HashMap<>(
parameters.entrySet().stream()
.filter(entry -> entry.getKey().startsWith(filterPrefix))
.collect(
Collectors.toMap(
entry -> entry.getKey().substring(filterPrefix.length()),
Map.Entry::getValue)));

// Extract org ID from ?o= in httpPath for SPOG routing
if (!headers.containsKey(ORG_ID_HEADER)) {
String httpPath =
parameters.getOrDefault(
DatabricksJdbcUrlParams.HTTP_PATH.getParamName().toLowerCase(), "");
int queryStart = httpPath.indexOf('?');
if (queryStart >= 0) {
String queryString = httpPath.substring(queryStart + 1);
for (String param : queryString.split("&")) {
String[] kv = param.split("=", 2);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use some Url parser utility for this

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to use URI builder. Thanks

if (kv.length == 2 && "o".equals(kv[0]) && !kv[1].isEmpty()) {
headers.put(ORG_ID_HEADER, kv[1]);
break;
}
}
}
}

return headers;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,8 @@ CreateUploadUrlResponse getCreateUploadUrlResponse(String objectPath)
CreateUploadUrlRequest request = new CreateUploadUrlRequest(objectPath);
try {
Request req = new Request(Request.POST, CREATE_UPLOAD_URL_PATH, apiClient.serialize(request));
req.withHeaders(JSON_HTTP_HEADERS);
req.withHeaders(JSON_HTTP_HEADERS)
.withHeaders(connectionContext != null ? connectionContext.getCustomHeaders() : Map.of());
return apiClient.execute(req, CreateUploadUrlResponse.class);
} catch (IOException | DatabricksException e) {
String errorMessage =
Expand All @@ -514,7 +515,8 @@ CreateDownloadUrlResponse getCreateDownloadUrlResponse(String objectPath)
try {
Request req =
new Request(Request.POST, CREATE_DOWNLOAD_URL_PATH, apiClient.serialize(request));
req.withHeaders(JSON_HTTP_HEADERS);
req.withHeaders(JSON_HTTP_HEADERS)
.withHeaders(connectionContext != null ? connectionContext.getCustomHeaders() : Map.of());
return apiClient.execute(req, CreateDownloadUrlResponse.class);
} catch (IOException | DatabricksException e) {
String errorMessage =
Expand All @@ -534,7 +536,8 @@ CreateDeleteUrlResponse getCreateDeleteUrlResponse(String objectPath)

try {
Request req = new Request(Request.POST, CREATE_DELETE_URL_PATH, apiClient.serialize(request));
req.withHeaders(JSON_HTTP_HEADERS);
req.withHeaders(JSON_HTTP_HEADERS)
.withHeaders(connectionContext != null ? connectionContext.getCustomHeaders() : Map.of());
return apiClient.execute(req, CreateDeleteUrlResponse.class);
} catch (IOException | DatabricksException e) {
String errorMessage =
Expand All @@ -551,7 +554,8 @@ ListResponse getListResponse(String listPath) throws DatabricksVolumeOperationEx
ListRequest request = new ListRequest(listPath);
try {
Request req = new Request(Request.GET, LIST_PATH);
req.withHeaders(JSON_HTTP_HEADERS);
req.withHeaders(JSON_HTTP_HEADERS)
.withHeaders(connectionContext != null ? connectionContext.getCustomHeaders() : Map.of());
ApiClient.setQuery(req, request);
return apiClient.execute(req, ListResponse.class);
} catch (IOException | DatabricksException e) {
Expand Down Expand Up @@ -888,6 +892,9 @@ private CompletableFuture<CreateUploadUrlResponse> requestPresignedUrlWithRetry(
Map<String, String> authHeaders = workspaceClient.config().authenticate();
authHeaders.forEach(requestBuilder::addHeader);
JSON_HTTP_HEADERS.forEach(requestBuilder::addHeader);
if (connectionContext != null) {
connectionContext.getCustomHeaders().forEach(requestBuilder::addHeader);
}

requestBuilder.setEntity(
AsyncEntityProducers.create(requestBody.getBytes(), ContentType.APPLICATION_JSON));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ public final class DatabricksJdbcConstants {
"(?:/([^;]*))?"
+ // Optional Schema (captured without /)
"(?:;(.*))?"); // Optional Property=Value pairs (captured without leading ;)
public static final Pattern HTTP_WAREHOUSE_PATH_PATTERN = Pattern.compile(".*/warehouses/(.+)");
public static final Pattern HTTP_ENDPOINT_PATH_PATTERN = Pattern.compile(".*/endpoints/(.+)");
public static final Pattern HTTP_WAREHOUSE_PATH_PATTERN =
Pattern.compile(".*/warehouses/([^?&]+).*");
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also allow some other http paths (e.g. /endpoints/) — make sure those are handled as well.

public static final Pattern HTTP_ENDPOINT_PATH_PATTERN =
Pattern.compile(".*/endpoints/([^?&]+).*");
public static final Pattern HTTP_CLI_PATTERN = Pattern.compile(".*cliservice(.+)");
public static final Pattern HTTP_PATH_CLI_PATTERN = Pattern.compile("cliservice");
public static final Pattern TEST_PATH_PATTERN = Pattern.compile("jdbc:databricks://test");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ private void refreshAllFeatureFlags() {
.getDatabricksConfig()
.authenticate()
.forEach(request::addHeader);
connectionContext.getCustomHeaders().forEach(request::addHeader);
fetchAndSetFlagsFromServer(httpClient, request);
} catch (Exception e) {
LOGGER.trace(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ public void pushEvent(TelemetryRequest request) throws Exception {
Map<String, String> authHeaders =
isAuthenticated ? databricksConfig.authenticate() : Collections.emptyMap();
authHeaders.forEach(post::addHeader);
connectionContext.getCustomHeaders().forEach(post::addHeader);
try (CloseableHttpResponse response = httpClient.execute(post)) {
// TODO: check response and add retry for partial failures
if (!HttpUtil.isSuccessfulHttpResponse(response)) {
Expand Down
13 changes: 13 additions & 0 deletions src/test/java/com/databricks/jdbc/TestConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -321,4 +321,17 @@ public class TestConstants {
public static final List<TSparkArrowBatch> ARROW_BATCH_LIST =
Collections.singletonList(
new TSparkArrowBatch().setRowCount(0).setBatch(new byte[] {65, 66, 67}));

// SPOG URLs with ?o= query parameter in httpPath (used to verify that the org ID embedded in
// httpPath is parsed into the x-databricks-org-id header and stripped from the compute ID).

// Warehouse-style httpPath with a ?o= query parameter, followed by further ;-delimited params.
public static final String VALID_SPOG_URL_WAREHOUSE =
"jdbc:databricks://spog.cloud.databricks.com/default;ssl=1;AuthMech=3;"
+ "httpPath=/sql/1.0/warehouses/abc123?o=6051921418418893;UseThriftClient=1";

// Endpoint-style httpPath with a ?o= query parameter.
public static final String VALID_SPOG_URL_ENDPOINT =
"jdbc:databricks://spog.cloud.databricks.com/default;ssl=1;AuthMech=3;"
+ "httpPath=/sql/1.0/endpoints/abc123?o=6051921418418893;UseThriftClient=0";

// Warehouse-style httpPath where the ?o= parameter is the final token of the whole URL.
public static final String VALID_SPOG_URL_WAREHOUSE_NO_EXTRA_PARAMS =
"jdbc:databricks://spog.cloud.databricks.com/default;ssl=1;AuthMech=3;"
+ "httpPath=/sql/1.0/warehouses/abc123?o=6051921418418893";
}
Original file line number Diff line number Diff line change
Expand Up @@ -1357,4 +1357,83 @@ public void testOAuthWebServerTimeoutCustom() throws DatabricksSQLException {
TestConstants.VALID_URL_1 + ";OAuthWebServerTimeout=300", properties);
assertEquals(300, connectionContext.getOAuthWebServerTimeout());
}

// ==================== SPOG ?o= Tests ====================

@Test
void testBuildPropertiesMap_preservesQueryParamInHttpPath() {
  // The httpPath value must keep everything after its first '=', including "?o=...".
  ImmutableMap<String, String> result =
      buildPropertiesMap(
          "ssl=1;AuthMech=3;httpPath=/sql/1.0/warehouses/abc123?o=999;UseThriftClient=1",
          new Properties());

  assertEquals("/sql/1.0/warehouses/abc123?o=999", result.get("httppath"));
  assertEquals("1", result.get("usethriftclient"));
}

@Test
void testBuildPropertiesMap_handlesValueWithMultipleEquals() {
  // A value containing several '=' characters must survive parsing unmodified.
  ImmutableMap<String, String> result =
      buildPropertiesMap("httpPath=/sql/1.0/warehouses/abc?o=999&other=foo", new Properties());

  assertEquals("/sql/1.0/warehouses/abc?o=999&other=foo", result.get("httppath"));
}

@Test
void testBuildPropertiesMap_handlesValueWithNoEquals() {
  // A bare key with no '=' must map to the empty string instead of failing.
  ImmutableMap<String, String> result = buildPropertiesMap("keyonly", new Properties());

  assertEquals("", result.get("keyonly"));
}

@Test
void testSpogContext_extractsOrgIdFromHttpPath() throws DatabricksSQLException {
  // The "?o=" query parameter inside httpPath should surface as the org-id routing header.
  Properties props = new Properties();
  props.put("user", "token");
  props.put("password", "test-token");

  IDatabricksConnectionContext ctx =
      DatabricksConnectionContext.parse(TestConstants.VALID_SPOG_URL_WAREHOUSE, props);

  assertEquals("6051921418418893", ctx.getCustomHeaders().get("x-databricks-org-id"));
}

@Test
void testSpogContext_extractsCleanWarehouseId() throws DatabricksSQLException {
  Properties props = new Properties();
  props.put("user", "token");
  props.put("password", "test-token");

  IDatabricksConnectionContext ctx =
      DatabricksConnectionContext.parse(TestConstants.VALID_SPOG_URL_WAREHOUSE, props);

  // The query string must be stripped: "abc123", not "abc123?o=6051921418418893".
  assertTrue(ctx.getComputeResource() instanceof Warehouse);
  assertEquals("abc123", ((Warehouse) ctx.getComputeResource()).getWarehouseId());
}

@Test
void testSpogContext_noOrgIdWithoutQueryParam() throws DatabricksSQLException {
  // A URL whose httpPath carries no "?o=" must not produce an org-id header.
  Properties props = new Properties();
  props.put("user", "token");
  props.put("password", "test-token");

  IDatabricksConnectionContext ctx =
      DatabricksConnectionContext.parse(TestConstants.VALID_URL_1, props);

  assertFalse(ctx.getCustomHeaders().containsKey("x-databricks-org-id"));
}

@Test
void testSpogContext_explicitHeaderTakesPrecedence() throws DatabricksSQLException {
  // An explicit http.header.x-databricks-org-id wins over the value derived from httpPath.
  String url =
      "jdbc:databricks://host/default;ssl=1;AuthMech=3;"
          + "httpPath=/sql/1.0/warehouses/abc123?o=frompath;"
          + "http.header.x-databricks-org-id=fromheader";
  Properties props = new Properties();
  props.put("user", "token");
  props.put("password", "test-token");

  IDatabricksConnectionContext ctx = DatabricksConnectionContext.parse(url, props);

  assertEquals("fromheader", ctx.getCustomHeaders().get("x-databricks-org-id"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ private static Stream<Arguments> jdbcUrlValidityTestCases() {
"Valid URL with invalid compression type",
true),
Arguments.of(INVALID_URL_1, "Invalid non-Databricks JDBC URL", false),
Arguments.of(INVALID_URL_2, "Invalid malformed JDBC scheme", false));
Arguments.of(INVALID_URL_2, "Invalid malformed JDBC scheme", false),
Arguments.of(
VALID_SPOG_URL_WAREHOUSE, "Valid SPOG URL with ?o= in warehouse httpPath", true),
Arguments.of(VALID_SPOG_URL_ENDPOINT, "Valid SPOG URL with ?o= in endpoint httpPath", true),
Arguments.of(
VALID_SPOG_URL_WAREHOUSE_NO_EXTRA_PARAMS,
"Valid SPOG URL with ?o= at end of URL",
true));
}
}
Loading