Skip to content

Commit 44921f6

Browse files
committed
fix(idp, JwtCredentialsProvider): derive db_user from jwt token
1 parent 7a4fb28 commit 44921f6

File tree

1 file changed

+49
-9
lines changed

1 file changed

+49
-9
lines changed

redshift_connector/plugin/jwt_credentials_provider.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import base64
2+
import json
23
import logging
34
import re
45
import typing
@@ -28,12 +29,13 @@ def __init__(self: "JwtCredentialsProvider"):
2829
# optional params
2930
self.role_session_name = JwtCredentialsProvider.DEFAULT_ROLE_SESSION_NAME
3031
self.duration: typing.Optional[int] = None
31-
# self.db_user: typing.Optional[str] = None
32+
self.db_user: typing.Optional[str] = None
3233
# self.db_groups: typing.Optional[str] = None
33-
self.db_groups_filter: typing.Optional[str] = None
34+
# self.db_groups_filter: typing.Optional[str] = None
3435
# self.force_lowercase: typing.Optional[bool] = None
3536
# self.auto_create: typing.Optional[bool] = None
36-
# self.region: typing.Optional[str] = None
37+
self.sts_endpoint: typing.Optional[str] = None
38+
self.region: typing.Optional[str] = None
3739

3840
@abstractmethod
3941
def process_jwt(self: "JwtCredentialsProvider", jwt: str) -> str:
@@ -43,9 +45,11 @@ def add_parameter(
4345
self: "JwtCredentialsProvider",
4446
info: RedshiftProperty,
4547
) -> None:
46-
super().add_parameter(info)
47-
self.jwt = info.web_identity_token
4848
self.role_arn = info.role_arn
49+
self.jwt = info.web_identity_token
50+
self.duration = info.duration
51+
# Do not read dbUser from connection, as it derives from token.
52+
self.region = info.region
4953

5054
if info.role_session_name is not None:
5155
self.role_session_name = info.role_session_name
@@ -80,6 +84,8 @@ def refresh(self: "JwtCredentialsProvider") -> None:
8084
jwt: str = self.process_jwt(self.jwt)
8185
decoded_jwt: typing.Optional[typing.List[typing.Union[str, bytes]]] = self.decode_jwt(self.jwt)
8286

87+
self.db_user = self.derive_database_user(decoded_jwt)
88+
8389
response = client.assume_role_with_web_identity(
8490
RoleArn=self.role_arn,
8591
RoleSessionName=self.role_session_name,
@@ -89,6 +95,7 @@ def refresh(self: "JwtCredentialsProvider") -> None:
8995

9096
stscred: typing.Dict[str, typing.Any] = response["Credentials"]
9197
credentials: CredentialsHolder = CredentialsHolder(stscred)
98+
credentials.set_metadata(self.read_metadata())
9299
key: str = self.get_cache_key()
93100
self.cache[key] = credentials
94101

@@ -127,22 +134,52 @@ def decode_jwt(
127134
return None
128135

129136
# base64(JOSE header).base64(payload).base64(signature)
130-
header_payload_sig: typing.List[str] = jwt.split("\\.")
137+
header_payload_sig: typing.List[str] = jwt.split(".")
131138

132139
_logger.debug("Encoded JWT Elements: {}".format(header_payload_sig))
133140

134141
if len(header_payload_sig) == 3:
135142
decoded_jwt: typing.List[typing.Union[bytes, str]] = []
136143
# decode the header and payload
137144
for i in range(2):
138-
decoded_jwt.append(base64.b64decode(header_payload_sig[i]))
145+
decoded_jwt.append(base64.b64decode(header_payload_sig[i] + "==="))
139146

140147
decoded_jwt.append(header_payload_sig[2])
141148
_logger.debug("Decoded JWT Elements: {}".format(header_payload_sig))
142149
return decoded_jwt
143150
else:
144151
return None
145152

153+
def derive_database_user(
154+
self: "JwtCredentialsProvider", decoded_jwt: typing.Optional[typing.List[typing.Union[str, bytes]]]
155+
) -> str:
156+
database_user: typing.Optional[str] = None
157+
158+
if decoded_jwt is not None and len(decoded_jwt) == 3:
159+
payload: str = typing.cast(str, decoded_jwt[1])
160+
claims: typing.Tuple[str, ...] = ("DbUser", "upn", "preferred_username", "email")
161+
162+
entity_json: typing.Dict = json.loads(payload)
163+
user_token_field: typing.Dict = {}
164+
165+
for claim in claims:
166+
user_token_field = entity_json.get(claim, None)
167+
168+
if user_token_field is not None:
169+
database_user = typing.cast(str, user_token_field)
170+
171+
if database_user is not None and database_user != "":
172+
_logger.debug(
173+
"JWT claim: {claim} as database user {user}".format(claim=claim, user=database_user)
174+
)
175+
break
176+
177+
if database_user is None or database_user == "":
178+
raise InterfaceError("No database user claim found in JWT")
179+
return database_user
180+
else:
181+
raise InterfaceError("JWT decoding error")
182+
146183
def get_saml_assertion(self: "SamlCredentialsProvider"):
147184
raise NotImplementedError
148185

@@ -152,8 +189,11 @@ def do_verify_ssl_cert(self: "SamlCredentialsProvider") -> bool:
152189
def get_form_action(self: "SamlCredentialsProvider", soup) -> typing.Optional[str]:
153190
raise NotImplementedError
154191

155-
def read_metadata(self: "SamlCredentialsProvider", doc: bytes) -> CredentialsHolder.IamMetadata:
156-
raise NotImplementedError
192+
def read_metadata(self: "SamlCredentialsProvider", doc: bytes = b"") -> CredentialsHolder.IamMetadata:
193+
metadata: CredentialsHolder.IamMetadata = CredentialsHolder.IamMetadata()
194+
metadata.set_db_user(typing.cast(str, self.db_user))
195+
metadata.set_auto_create("true")
196+
return metadata
157197

158198

159199
class BasicJwtCredentialsProvider(JwtCredentialsProvider):

0 commit comments

Comments
 (0)