Skip to content

Commit f835aca

Browse files
authored
Add Token Federation Support for Databricks SQL Python Driver (#691)
* token federation for python driver * address comment * address comments * lint * lint fix * nit * change import
1 parent 54dd646 commit f835aca

File tree

6 files changed

+644
-9
lines changed

6 files changed

+644
-9
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,17 @@
88
AzureServicePrincipalCredentialProvider,
99
)
1010
from databricks.sql.auth.common import AuthType, ClientContext
11+
from databricks.sql.auth.token_federation import TokenFederationProvider
1112

1213

1314
def get_auth_provider(cfg: ClientContext, http_client):
15+
# Determine the base auth provider
16+
base_provider: Optional[AuthProvider] = None
17+
1418
if cfg.credentials_provider:
15-
return ExternalAuthProvider(cfg.credentials_provider)
19+
base_provider = ExternalAuthProvider(cfg.credentials_provider)
1620
elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
17-
return ExternalAuthProvider(
21+
base_provider = ExternalAuthProvider(
1822
AzureServicePrincipalCredentialProvider(
1923
cfg.hostname,
2024
cfg.azure_client_id,
@@ -29,7 +33,7 @@ def get_auth_provider(cfg: ClientContext, http_client):
2933
assert cfg.oauth_client_id is not None
3034
assert cfg.oauth_scopes is not None
3135

32-
return DatabricksOAuthProvider(
36+
base_provider = DatabricksOAuthProvider(
3337
cfg.hostname,
3438
cfg.oauth_persistence,
3539
cfg.oauth_redirect_port_range,
@@ -39,17 +43,17 @@ def get_auth_provider(cfg: ClientContext, http_client):
3943
cfg.auth_type,
4044
)
4145
elif cfg.access_token is not None:
42-
return AccessTokenAuthProvider(cfg.access_token)
46+
base_provider = AccessTokenAuthProvider(cfg.access_token)
4347
elif cfg.use_cert_as_auth and cfg.tls_client_cert_file:
4448
# no op authenticator. authentication is performed using ssl certificate outside of headers
45-
return AuthProvider()
49+
base_provider = AuthProvider()
4650
else:
4751
if (
4852
cfg.oauth_redirect_port_range is not None
4953
and cfg.oauth_client_id is not None
5054
and cfg.oauth_scopes is not None
5155
):
52-
return DatabricksOAuthProvider(
56+
base_provider = DatabricksOAuthProvider(
5357
cfg.hostname,
5458
cfg.oauth_persistence,
5559
cfg.oauth_redirect_port_range,
@@ -61,6 +65,17 @@ def get_auth_provider(cfg: ClientContext, http_client):
6165
else:
6266
raise RuntimeError("No valid authentication settings!")
6367

68+
# Always wrap with token federation (falls back gracefully if not needed)
69+
if base_provider:
70+
return TokenFederationProvider(
71+
hostname=cfg.hostname,
72+
external_provider=base_provider,
73+
http_client=http_client,
74+
identity_federation_client_id=cfg.identity_federation_client_id,
75+
)
76+
77+
return base_provider
78+
6479

6580
PYSQL_OAUTH_SCOPES = ["sql", "offline_access"]
6681
PYSQL_OAUTH_CLIENT_ID = "databricks-sql-python"
@@ -114,5 +129,6 @@ def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs)
114129
else redirect_port_range,
115130
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
116131
credentials_provider=kwargs.get("credentials_provider"),
132+
identity_federation_client_id=kwargs.get("identity_federation_client_id"),
117133
)
118134
return get_auth_provider(cfg, http_client)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import logging
2+
import jwt
3+
from datetime import datetime, timedelta
4+
from typing import Optional, Dict, Tuple
5+
from urllib.parse import urlparse
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
def parse_hostname(hostname: str) -> str:
11+
"""
12+
Normalize the hostname to include scheme and trailing slash.
13+
14+
Args:
15+
hostname: The hostname to normalize
16+
17+
Returns:
18+
Normalized hostname with scheme and trailing slash
19+
"""
20+
if not hostname.startswith("http://") and not hostname.startswith("https://"):
21+
hostname = f"https://{hostname}"
22+
if not hostname.endswith("/"):
23+
hostname = f"{hostname}/"
24+
return hostname
25+
26+
27+
def decode_token(access_token: str) -> Optional[Dict]:
    """
    Read the claims of a JWT without verifying its signature.

    The result is only suitable for inspecting claims (e.g. issuer or
    expiry); it must not be treated as proof that the token is authentic.

    Args:
        access_token: The JWT access token to decode

    Returns:
        The decoded claims, or None when the token cannot be decoded
    """
    claims: Optional[Dict] = None
    try:
        claims = jwt.decode(access_token, options={"verify_signature": False})
    except Exception as err:
        # Best effort: opaque (non-JWT) tokens are expected here.
        logger.debug("Failed to decode JWT token: %s", err)
    return claims
42+
43+
44+
def is_same_host(url1: str, url2: str) -> bool:
    """
    Check if two URLs have the same host.

    Only the host component is compared: scheme, port, userinfo, path and
    letter case are ignored. Bracketed IPv6 literals are compared by their
    address.

    Args:
        url1: First URL
        url2: Second URL

    Returns:
        True if both URLs contain a host and the hosts are equal, False
        otherwise (including when either URL has no host or cannot be parsed)
    """
    try:
        # ``hostname`` (unlike ``netloc``) strips userinfo and port, unwraps
        # IPv6 brackets and lower-cases the value, so no manual splitting on
        # ":" is needed.
        host1 = urlparse(url1).hostname
        host2 = urlparse(url2).hostname
    except (ValueError, TypeError, AttributeError) as e:
        logger.debug("Failed to parse URLs: %s", e)
        return False

    # A URL without a scheme (or an empty string) has no host; two such
    # values must not be reported as the same host.
    if not host1 or not host2:
        return False
    return host1 == host2

src/databricks/sql/auth/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
tls_client_cert_file: Optional[str] = None,
3838
oauth_persistence=None,
3939
credentials_provider=None,
40+
identity_federation_client_id: Optional[str] = None,
4041
# HTTP client configuration parameters
4142
ssl_options=None, # SSLOptions type
4243
socket_timeout: Optional[float] = None,
@@ -65,6 +66,7 @@ def __init__(
6566
self.tls_client_cert_file = tls_client_cert_file
6667
self.oauth_persistence = oauth_persistence
6768
self.credentials_provider = credentials_provider
69+
self.identity_federation_client_id = identity_federation_client_id
6870

6971
# HTTP client configuration
7072
self.ssl_options = ssl_options
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
import logging
2+
import json
3+
from datetime import datetime, timedelta
4+
from typing import Optional, Dict, Tuple
5+
from urllib.parse import urlencode
6+
7+
from databricks.sql.auth.authenticators import AuthProvider
8+
from databricks.sql.auth.auth_utils import (
9+
parse_hostname,
10+
decode_token,
11+
is_same_host,
12+
)
13+
from databricks.sql.common.http import HttpMethod
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class Token:
    """
    Represents an OAuth token with expiration management.

    The expiry is computed once, at construction time, from the token's JWT
    ``exp`` claim when one can be read, and is stored as a naive local
    datetime in ``expiry_time``.
    """

    def __init__(self, access_token: str, token_type: str = "Bearer"):
        """
        Initialize a token.

        Args:
            access_token: The access token string
            token_type: The token type (default: Bearer)
        """
        self.access_token = access_token
        self.token_type = token_type
        self.expiry_time = self._calculate_expiry()

    def _calculate_expiry(self) -> datetime:
        """
        Calculate the token expiry time from JWT claims.

        Returns:
            The ``exp`` claim minus a one-minute safety buffer when the claim
            is present and usable; otherwise one hour from now
        """
        decoded = decode_token(self.access_token)
        if decoded and "exp" in decoded:
            try:
                # Use JWT exp claim with 1 minute buffer
                return datetime.fromtimestamp(decoded["exp"]) - timedelta(minutes=1)
            except (TypeError, ValueError, OverflowError, OSError) as e:
                # The claims are unverified, so ``exp`` may be non-numeric or
                # out of range; treat that like a token with no expiry info
                # instead of failing token construction.
                logger.debug("Ignoring unusable 'exp' claim: %s", e)
        # Default to 1 hour if no expiry info
        return datetime.now() + timedelta(hours=1)

    def is_expired(self) -> bool:
        """
        Check if the token is expired.

        Returns:
            True if the current local time has reached ``expiry_time``,
            False otherwise
        """
        return datetime.now() >= self.expiry_time

    def to_dict(self) -> Dict[str, str]:
        """
        Convert token to dictionary format.

        Returns:
            Dictionary with access_token and token_type
        """
        return {
            "access_token": self.access_token,
            "token_type": self.token_type,
        }
69+
70+
71+
class TokenFederationProvider(AuthProvider):
    """
    Implementation of Token Federation for Databricks SQL Python driver.

    This provider wraps another auth provider. It inspects the Authorization
    header that provider produces and, when the token is a JWT whose issuer
    host differs from the Databricks host, exchanges it for a Databricks token
    at the workspace token endpoint. If no exchange is needed, or the exchange
    fails, the wrapped provider's token is used unchanged. The resulting token
    is cached until it expires.
    """

    TOKEN_EXCHANGE_ENDPOINT = "/oidc/v1/token"
    TOKEN_EXCHANGE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange"
    TOKEN_EXCHANGE_SUBJECT_TYPE = "urn:ietf:params:oauth:token-type:jwt"

    def __init__(
        self,
        hostname: str,
        external_provider: AuthProvider,
        http_client,
        identity_federation_client_id: Optional[str] = None,
    ):
        """
        Initialize the Token Federation Provider.

        Args:
            hostname: The Databricks workspace hostname
            external_provider: The external authentication provider
            http_client: HTTP client for making requests (required)
            identity_federation_client_id: Optional client ID for token federation

        Raises:
            ValueError: If ``http_client`` is falsy
        """
        if not http_client:
            raise ValueError("http_client is required for TokenFederationProvider")

        self.hostname = parse_hostname(hostname)
        self.external_provider = external_provider
        self.http_client = http_client
        self.identity_federation_client_id = identity_federation_client_id

        self._cached_token: Optional[Token] = None
        # Headers most recently produced by the wrapped provider; filled in by
        # add_headers() and read by _get_token().
        self._external_headers: Dict[str, str] = {}

    def add_headers(self, request_headers: Dict[str, str]):
        """
        Add authentication headers to the request.

        Args:
            request_headers: Header dictionary to update in place

        Raises:
            ValueError: If the wrapped provider produced a malformed
                Authorization header
        """
        # Fast path: reuse the cached token without consulting the wrapped provider.
        if self._cached_token and not self._cached_token.is_expired():
            request_headers[
                "Authorization"
            ] = f"{self._cached_token.token_type} {self._cached_token.access_token}"
            return

        # Get the external headers first to check if we need token federation
        self._external_headers = {}
        self.external_provider.add_headers(self._external_headers)

        # If no Authorization header from external provider, pass through all headers
        if "Authorization" not in self._external_headers:
            request_headers.update(self._external_headers)
            return

        # NOTE(review): when an Authorization header is present, any other
        # headers produced by the wrapped provider are not forwarded —
        # confirm this is intended.
        token = self._get_token()
        request_headers["Authorization"] = f"{token.token_type} {token.access_token}"

    def _get_token(self) -> Token:
        """
        Get or refresh the authentication token.

        Reads the Authorization header stored in ``self._external_headers``,
        so add_headers() must have populated it first.

        Returns:
            The cached token if still valid; otherwise the exchanged token, or
            the external token when no exchange is needed or the exchange fails

        Raises:
            ValueError: If the stored Authorization header is missing or malformed
        """
        # Check if cached token is still valid
        if self._cached_token and not self._cached_token.is_expired():
            return self._cached_token

        # Extract token from already-fetched headers
        auth_header = self._external_headers.get("Authorization", "")
        token_type, access_token = self._extract_token_from_header(auth_header)

        # Check if token exchange is needed
        if self._should_exchange_token(access_token):
            try:
                token = self._exchange_token(access_token)
                self._cached_token = token
                return token
            except Exception as e:
                # Deliberate best effort: any exchange failure falls back to
                # the external token rather than failing the request.
                logger.warning("Token exchange failed, using external token: %s", e)

        # Use external token directly
        token = Token(access_token, token_type)
        self._cached_token = token
        return token

    def _should_exchange_token(self, access_token: str) -> bool:
        """
        Check if the token should be exchanged based on issuer.

        Args:
            access_token: The external access token

        Returns:
            False when the token cannot be decoded as a JWT; otherwise True
            when the ``iss`` claim's host differs from the Databricks host
            (a missing ``iss`` claim counts as different)
        """
        decoded = decode_token(access_token)
        if not decoded:
            return False

        issuer = decoded.get("iss", "")
        # Check if issuer host is different from Databricks host
        return not is_same_host(issuer, self.hostname)

    def _exchange_token(self, access_token: str) -> Token:
        """
        Exchange the external token for a Databricks token.

        Args:
            access_token: The external access token to use as the subject token

        Returns:
            A Token built from the exchange response

        Raises:
            ValueError: If the response body is not a JSON object containing a
                non-empty ``access_token`` (e.g. an OAuth error response)
        """
        token_url = f"{self.hostname.rstrip('/')}{self.TOKEN_EXCHANGE_ENDPOINT}"

        data = {
            "grant_type": self.TOKEN_EXCHANGE_GRANT_TYPE,
            "subject_token": access_token,
            "subject_token_type": self.TOKEN_EXCHANGE_SUBJECT_TYPE,
            "scope": "sql",
            "return_original_token_if_authenticated": "true",
        }

        if self.identity_federation_client_id:
            data["client_id"] = self.identity_federation_client_id

        headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "*/*",
        }

        body = urlencode(data)

        response = self.http_client.request(
            HttpMethod.POST, url=token_url, body=body, headers=headers
        )

        token_response = json.loads(response.data.decode())

        # Validate before indexing so a failed exchange is reported with a
        # meaningful message instead of an opaque KeyError, and so an empty
        # token is never cached.
        if not isinstance(token_response, dict) or not token_response.get(
            "access_token"
        ):
            detail = ""
            if isinstance(token_response, dict):
                # Standard OAuth error fields; these carry no credentials.
                detail = (
                    token_response.get("error_description")
                    or token_response.get("error")
                    or ""
                )
            message = "Token exchange response does not contain an access_token"
            if detail:
                message = f"{message}: {detail}"
            raise ValueError(message)

        return Token(
            token_response["access_token"], token_response.get("token_type", "Bearer")
        )

    def _extract_token_from_header(self, auth_header: str) -> Tuple[str, str]:
        """
        Extract token type and access token from Authorization header.

        Args:
            auth_header: Header value of the form ``"<type> <token>"``

        Returns:
            A ``(token_type, access_token)`` tuple, split at the first space

        Raises:
            ValueError: If the header is empty or contains no space
        """
        if not auth_header:
            raise ValueError("Authorization header is missing")

        parts = auth_header.split(" ", 1)
        if len(parts) != 2:
            raise ValueError("Invalid Authorization header format")

        return parts[0], parts[1]

tests/unit/test_auth.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,9 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
164164
kwargs = {"credentials_provider": MyProvider()}
165165
mock_http_client = MagicMock()
166166
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
167-
self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider")
167+
168+
self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider")
169+
self.assertEqual(type(auth_provider.external_provider).__name__, "ExternalAuthProvider")
168170

169171
headers = {}
170172
auth_provider.add_headers(headers)
@@ -199,8 +201,11 @@ def test_get_python_sql_connector_default_auth(self, mock__initial_get_token):
199201
hostname = "foo.cloud.databricks.com"
200202
mock_http_client = MagicMock()
201203
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client)
202-
self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider")
203-
self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID)
204+
205+
self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider")
206+
self.assertEqual(type(auth_provider.external_provider).__name__, "DatabricksOAuthProvider")
207+
208+
self.assertEqual(auth_provider.external_provider._client_id, PYSQL_OAUTH_CLIENT_ID)
204209

205210

206211
class TestClientCredentialsTokenSource:

0 commit comments

Comments
 (0)