Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableMap;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.*;
import java.util.Collections;
import java.util.regex.Matcher;
import java.util.stream.Collectors;
import org.apache.http.NameValuePair;
import org.apache.http.client.utils.URIBuilder;

public class DatabricksConnectionContext implements IDatabricksConnectionContext {
Expand Down Expand Up @@ -108,14 +110,14 @@ public static ImmutableMap<String, String> buildPropertiesMap(
if (!isNullOrEmpty(connectionParamString)) {
String[] urlParts = connectionParamString.split(DatabricksJdbcConstants.URL_DELIMITER);
for (String urlPart : urlParts) {
String[] pair = urlPart.split(DatabricksJdbcConstants.PAIR_DELIMITER);
if (pair.length == 1) {
pair = new String[] {pair[0], ""};
}
if (pair[0].startsWith(DatabricksJdbcUrlParams.HTTP_HEADERS.getParamName())) {
parametersBuilder.put(pair[0], pair[1]);
// Split on first '=' only — values (like httpPath) may contain '=' (e.g. ?o=123)
int delimIdx = urlPart.indexOf(DatabricksJdbcConstants.PAIR_DELIMITER);
String key = delimIdx >= 0 ? urlPart.substring(0, delimIdx) : urlPart;
String value = delimIdx >= 0 ? urlPart.substring(delimIdx + 1) : "";
if (key.startsWith(DatabricksJdbcUrlParams.HTTP_HEADERS.getParamName())) {
parametersBuilder.put(key, value);
} else {
parametersBuilder.put(pair[0].toLowerCase(), pair[1]);
parametersBuilder.put(key.toLowerCase(), value);
}
}
}
Expand Down Expand Up @@ -1167,14 +1169,41 @@ private String getParameter(DatabricksJdbcUrlParams key, String defaultValue) {
return this.parameters.getOrDefault(key.getParamName().toLowerCase(), defaultValue);
}

private static final String ORG_ID_HEADER = "x-databricks-org-id";

private Map<String, String> parseCustomHeaders(ImmutableMap<String, String> parameters) {
String filterPrefix = DatabricksJdbcUrlParams.HTTP_HEADERS.getParamName();

return parameters.entrySet().stream()
.filter(entry -> entry.getKey().startsWith(filterPrefix))
.collect(
Collectors.toMap(
entry -> entry.getKey().substring(filterPrefix.length()), Map.Entry::getValue));
Map<String, String> headers =
new HashMap<>(
parameters.entrySet().stream()
.filter(entry -> entry.getKey().startsWith(filterPrefix))
.collect(
Collectors.toMap(
entry -> entry.getKey().substring(filterPrefix.length()),
Map.Entry::getValue)));

// Extract org ID from ?o= in httpPath for SPOG routing
if (!headers.containsKey(ORG_ID_HEADER)) {
String httpPath =
parameters.getOrDefault(
DatabricksJdbcUrlParams.HTTP_PATH.getParamName().toLowerCase(), "");
try {
for (NameValuePair param :
new URIBuilder("http://placeholder" + httpPath).getQueryParams()) {
if ("o".equals(param.getName())
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

declare o as constant

&& param.getValue() != null
&& !param.getValue().isEmpty()) {
headers.put(ORG_ID_HEADER, param.getValue());
break;
}
}
} catch (URISyntaxException e) {
// Malformed httpPath — skip SPOG header extraction
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add logging

}
}

return headers;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,8 @@ CreateUploadUrlResponse getCreateUploadUrlResponse(String objectPath)
CreateUploadUrlRequest request = new CreateUploadUrlRequest(objectPath);
try {
Request req = new Request(Request.POST, CREATE_UPLOAD_URL_PATH, apiClient.serialize(request));
req.withHeaders(JSON_HTTP_HEADERS);
req.withHeaders(JSON_HTTP_HEADERS)
.withHeaders(connectionContext != null ? connectionContext.getCustomHeaders() : Map.of());
return apiClient.execute(req, CreateUploadUrlResponse.class);
} catch (IOException | DatabricksException e) {
String errorMessage =
Expand All @@ -514,7 +515,8 @@ CreateDownloadUrlResponse getCreateDownloadUrlResponse(String objectPath)
try {
Request req =
new Request(Request.POST, CREATE_DOWNLOAD_URL_PATH, apiClient.serialize(request));
req.withHeaders(JSON_HTTP_HEADERS);
req.withHeaders(JSON_HTTP_HEADERS)
.withHeaders(connectionContext != null ? connectionContext.getCustomHeaders() : Map.of());
return apiClient.execute(req, CreateDownloadUrlResponse.class);
} catch (IOException | DatabricksException e) {
String errorMessage =
Expand All @@ -534,7 +536,8 @@ CreateDeleteUrlResponse getCreateDeleteUrlResponse(String objectPath)

try {
Request req = new Request(Request.POST, CREATE_DELETE_URL_PATH, apiClient.serialize(request));
req.withHeaders(JSON_HTTP_HEADERS);
req.withHeaders(JSON_HTTP_HEADERS)
.withHeaders(connectionContext != null ? connectionContext.getCustomHeaders() : Map.of());
return apiClient.execute(req, CreateDeleteUrlResponse.class);
} catch (IOException | DatabricksException e) {
String errorMessage =
Expand All @@ -551,7 +554,8 @@ ListResponse getListResponse(String listPath) throws DatabricksVolumeOperationEx
ListRequest request = new ListRequest(listPath);
try {
Request req = new Request(Request.GET, LIST_PATH);
req.withHeaders(JSON_HTTP_HEADERS);
req.withHeaders(JSON_HTTP_HEADERS)
.withHeaders(connectionContext != null ? connectionContext.getCustomHeaders() : Map.of());
ApiClient.setQuery(req, request);
return apiClient.execute(req, ListResponse.class);
} catch (IOException | DatabricksException e) {
Expand Down Expand Up @@ -888,6 +892,9 @@ private CompletableFuture<CreateUploadUrlResponse> requestPresignedUrlWithRetry(
Map<String, String> authHeaders = workspaceClient.config().authenticate();
authHeaders.forEach(requestBuilder::addHeader);
JSON_HTTP_HEADERS.forEach(requestBuilder::addHeader);
if (connectionContext != null) {
connectionContext.getCustomHeaders().forEach(requestBuilder::addHeader);
}

requestBuilder.setEntity(
AsyncEntityProducers.create(requestBody.getBytes(), ContentType.APPLICATION_JSON));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ public final class DatabricksJdbcConstants {
"(?:/([^;]*))?"
+ // Optional Schema (captured without /)
"(?:;(.*))?"); // Optional Property=Value pairs (captured without leading ;)
public static final Pattern HTTP_WAREHOUSE_PATH_PATTERN = Pattern.compile(".*/warehouses/(.+)");
public static final Pattern HTTP_ENDPOINT_PATH_PATTERN = Pattern.compile(".*/endpoints/(.+)");
public static final Pattern HTTP_WAREHOUSE_PATH_PATTERN =
Pattern.compile(".*/warehouses/([^?&]+).*");
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we also allow some other http paths, I guess /endpoints/ or something

public static final Pattern HTTP_ENDPOINT_PATH_PATTERN =
Pattern.compile(".*/endpoints/([^?&]+).*");
public static final Pattern HTTP_CLI_PATTERN = Pattern.compile(".*cliservice(.+)");
public static final Pattern HTTP_PATH_CLI_PATTERN = Pattern.compile("cliservice");
public static final Pattern TEST_PATH_PATTERN = Pattern.compile("jdbc:databricks://test");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ private void refreshAllFeatureFlags() {
.getDatabricksConfig()
.authenticate()
.forEach(request::addHeader);
connectionContext.getCustomHeaders().forEach(request::addHeader);
fetchAndSetFlagsFromServer(httpClient, request);
} catch (Exception e) {
LOGGER.trace(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ public void pushEvent(TelemetryRequest request) throws Exception {
Map<String, String> authHeaders =
isAuthenticated ? databricksConfig.authenticate() : Collections.emptyMap();
authHeaders.forEach(post::addHeader);
connectionContext.getCustomHeaders().forEach(post::addHeader);
try (CloseableHttpResponse response = httpClient.execute(post)) {
// TODO: check response and add retry for partial failures
if (!HttpUtil.isSuccessfulHttpResponse(response)) {
Expand Down
13 changes: 13 additions & 0 deletions src/test/java/com/databricks/jdbc/TestConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -321,4 +321,17 @@ public class TestConstants {
public static final List<TSparkArrowBatch> ARROW_BATCH_LIST =
Collections.singletonList(
new TSparkArrowBatch().setRowCount(0).setBatch(new byte[] {65, 66, 67}));

// SPOG URLs with ?o= query parameter in httpPath
public static final String VALID_SPOG_URL_WAREHOUSE =
"jdbc:databricks://spog.cloud.databricks.com/default;ssl=1;AuthMech=3;"
+ "httpPath=/sql/1.0/warehouses/abc123?o=6051921418418893;UseThriftClient=1";

public static final String VALID_SPOG_URL_ENDPOINT =
"jdbc:databricks://spog.cloud.databricks.com/default;ssl=1;AuthMech=3;"
+ "httpPath=/sql/1.0/endpoints/abc123?o=6051921418418893;UseThriftClient=0";

public static final String VALID_SPOG_URL_WAREHOUSE_NO_EXTRA_PARAMS =
"jdbc:databricks://spog.cloud.databricks.com/default;ssl=1;AuthMech=3;"
+ "httpPath=/sql/1.0/warehouses/abc123?o=6051921418418893";
}
Original file line number Diff line number Diff line change
Expand Up @@ -1357,4 +1357,83 @@ public void testOAuthWebServerTimeoutCustom() throws DatabricksSQLException {
TestConstants.VALID_URL_1 + ";OAuthWebServerTimeout=300", properties);
assertEquals(300, connectionContext.getOAuthWebServerTimeout());
}

// ==================== SPOG ?o= Tests ====================

@Test
void testBuildPropertiesMap_preservesQueryParamInHttpPath() {
String params = "ssl=1;AuthMech=3;httpPath=/sql/1.0/warehouses/abc123?o=999;UseThriftClient=1";
ImmutableMap<String, String> result = buildPropertiesMap(params, new Properties());

assertEquals("/sql/1.0/warehouses/abc123?o=999", result.get("httppath"));
assertEquals("1", result.get("usethriftclient"));
}

@Test
void testBuildPropertiesMap_handlesValueWithMultipleEquals() {
String params = "httpPath=/sql/1.0/warehouses/abc?o=999&other=foo";
ImmutableMap<String, String> result = buildPropertiesMap(params, new Properties());

assertEquals("/sql/1.0/warehouses/abc?o=999&other=foo", result.get("httppath"));
}

@Test
void testBuildPropertiesMap_handlesValueWithNoEquals() {
String params = "keyonly";
ImmutableMap<String, String> result = buildPropertiesMap(params, new Properties());

assertEquals("", result.get("keyonly"));
}

@Test
void testSpogContext_extractsOrgIdFromHttpPath() throws DatabricksSQLException {
Properties props = new Properties();
props.put("user", "token");
props.put("password", "test-token");
IDatabricksConnectionContext ctx =
DatabricksConnectionContext.parse(TestConstants.VALID_SPOG_URL_WAREHOUSE, props);

Map<String, String> headers = ctx.getCustomHeaders();
assertEquals("6051921418418893", headers.get("x-databricks-org-id"));
}

@Test
void testSpogContext_extractsCleanWarehouseId() throws DatabricksSQLException {
Properties props = new Properties();
props.put("user", "token");
props.put("password", "test-token");
IDatabricksConnectionContext ctx =
DatabricksConnectionContext.parse(TestConstants.VALID_SPOG_URL_WAREHOUSE, props);

// Warehouse ID should be "abc123" not "abc123?o=6051921418418893"
assertTrue(ctx.getComputeResource() instanceof Warehouse);
assertEquals("abc123", ((Warehouse) ctx.getComputeResource()).getWarehouseId());
}

@Test
void testSpogContext_noOrgIdWithoutQueryParam() throws DatabricksSQLException {
Properties props = new Properties();
props.put("user", "token");
props.put("password", "test-token");
IDatabricksConnectionContext ctx =
DatabricksConnectionContext.parse(TestConstants.VALID_URL_1, props);

Map<String, String> headers = ctx.getCustomHeaders();
assertFalse(headers.containsKey("x-databricks-org-id"));
}

@Test
void testSpogContext_explicitHeaderTakesPrecedence() throws DatabricksSQLException {
String url =
"jdbc:databricks://host/default;ssl=1;AuthMech=3;"
+ "httpPath=/sql/1.0/warehouses/abc123?o=frompath;"
+ "http.header.x-databricks-org-id=fromheader";
Properties props = new Properties();
props.put("user", "token");
props.put("password", "test-token");
IDatabricksConnectionContext ctx = DatabricksConnectionContext.parse(url, props);

Map<String, String> headers = ctx.getCustomHeaders();
assertEquals("fromheader", headers.get("x-databricks-org-id"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ private static Stream<Arguments> jdbcUrlValidityTestCases() {
"Valid URL with invalid compression type",
true),
Arguments.of(INVALID_URL_1, "Invalid non-Databricks JDBC URL", false),
Arguments.of(INVALID_URL_2, "Invalid malformed JDBC scheme", false));
Arguments.of(INVALID_URL_2, "Invalid malformed JDBC scheme", false),
Arguments.of(
VALID_SPOG_URL_WAREHOUSE, "Valid SPOG URL with ?o= in warehouse httpPath", true),
Arguments.of(VALID_SPOG_URL_ENDPOINT, "Valid SPOG URL with ?o= in endpoint httpPath", true),
Arguments.of(
VALID_SPOG_URL_WAREHOUSE_NO_EXTRA_PARAMS,
"Valid SPOG URL with ?o= at end of URL",
true));
}
}
Loading