11import base64
2+ import json
23import logging
34import re
45import typing
@@ -28,12 +29,13 @@ def __init__(self: "JwtCredentialsProvider"):
2829 # optional params
2930 self .role_session_name = JwtCredentialsProvider .DEFAULT_ROLE_SESSION_NAME
3031 self .duration : typing .Optional [int ] = None
31- # self.db_user: typing.Optional[str] = None
32+ self .db_user : typing .Optional [str ] = None
3233 # self.db_groups: typing.Optional[str] = None
33- self .db_groups_filter : typing .Optional [str ] = None
34+ # self.db_groups_filter: typing.Optional[str] = None
3435 # self.force_lowercase: typing.Optional[bool] = None
3536 # self.auto_create: typing.Optional[bool] = None
36- # self.region: typing.Optional[str] = None
37+ self .sts_endpoint : typing .Optional [str ] = None
38+ self .region : typing .Optional [str ] = None
3739
3840 @abstractmethod
3941 def process_jwt (self : "JwtCredentialsProvider" , jwt : str ) -> str :
@@ -43,9 +45,11 @@ def add_parameter(
4345 self : "JwtCredentialsProvider" ,
4446 info : RedshiftProperty ,
4547 ) -> None :
46- super ().add_parameter (info )
47- self .jwt = info .web_identity_token
4848 self .role_arn = info .role_arn
49+ self .jwt = info .web_identity_token
50+ self .duration = info .duration
51+ # Do not read dbUser from connection, as it derives from token.
52+ self .region = info .region
4953
5054 if info .role_session_name is not None :
5155 self .role_session_name = info .role_session_name
@@ -80,6 +84,8 @@ def refresh(self: "JwtCredentialsProvider") -> None:
8084 jwt : str = self .process_jwt (self .jwt )
8185 decoded_jwt : typing .Optional [typing .List [typing .Union [str , bytes ]]] = self .decode_jwt (self .jwt )
8286
87+ self .db_user = self .derive_database_user (decoded_jwt )
88+
8389 response = client .assume_role_with_web_identity (
8490 RoleArn = self .role_arn ,
8591 RoleSessionName = self .role_session_name ,
@@ -89,6 +95,7 @@ def refresh(self: "JwtCredentialsProvider") -> None:
8995
9096 stscred : typing .Dict [str , typing .Any ] = response ["Credentials" ]
9197 credentials : CredentialsHolder = CredentialsHolder (stscred )
98+ credentials .set_metadata (self .read_metadata ())
9299 key : str = self .get_cache_key ()
93100 self .cache [key ] = credentials
94101
@@ -127,22 +134,52 @@ def decode_jwt(
127134 return None
128135
129136 # base64(JOSE header).base64(payload).base64(signature)
130- header_payload_sig : typing .List [str ] = jwt .split ("\\ ." )
137+ header_payload_sig : typing .List [str ] = jwt .split ("." )
131138
132139 _logger .debug ("Encoded JWT Elements: {}" .format (header_payload_sig ))
133140
134141 if len (header_payload_sig ) == 3 :
135142 decoded_jwt : typing .List [typing .Union [bytes , str ]] = []
136143 # decode the header and payload
137144 for i in range (2 ):
138- decoded_jwt .append (base64 .b64decode (header_payload_sig [i ]))
145+ decoded_jwt .append (base64 .b64decode (header_payload_sig [i ] + "===" ))
139146
140147 decoded_jwt .append (header_payload_sig [2 ])
141148 _logger .debug ("Decoded JWT Elements: {}" .format (header_payload_sig ))
142149 return decoded_jwt
143150 else :
144151 return None
145152
153+ def derive_database_user (
154+ self : "JwtCredentialsProvider" , decoded_jwt : typing .Optional [typing .List [typing .Union [str , bytes ]]]
155+ ) -> str :
156+ database_user : typing .Optional [str ] = None
157+
158+ if decoded_jwt is not None and len (decoded_jwt ) == 3 :
159+ payload : str = typing .cast (str , decoded_jwt [1 ])
160+ claims : typing .Tuple [str , ...] = ("DbUser" , "upn" , "preferred_username" , "email" )
161+
162+ entity_json : typing .Dict = json .loads (payload )
163+ user_token_field : typing .Dict = {}
164+
165+ for claim in claims :
166+ user_token_field = entity_json .get (claim , None )
167+
168+ if user_token_field is not None :
169+ database_user = typing .cast (str , user_token_field )
170+
171+ if database_user is not None and database_user != "" :
172+ _logger .debug (
173+ "JWT claim: {claim} as database user {user}" .format (claim = claim , user = database_user )
174+ )
175+ break
176+
177+ if database_user is None or database_user == "" :
178+ raise InterfaceError ("No database user claim found in JWT" )
179+ return database_user
180+ else :
181+ raise InterfaceError ("JWT decoding error" )
182+
146183 def get_saml_assertion (self : "SamlCredentialsProvider" ):
147184 raise NotImplementedError
148185
@@ -152,8 +189,11 @@ def do_verify_ssl_cert(self: "SamlCredentialsProvider") -> bool:
152189 def get_form_action (self : "SamlCredentialsProvider" , soup ) -> typing .Optional [str ]:
153190 raise NotImplementedError
154191
155- def read_metadata (self : "SamlCredentialsProvider" , doc : bytes ) -> CredentialsHolder .IamMetadata :
156- raise NotImplementedError
192+ def read_metadata (self : "SamlCredentialsProvider" , doc : bytes = b"" ) -> CredentialsHolder .IamMetadata :
193+ metadata : CredentialsHolder .IamMetadata = CredentialsHolder .IamMetadata ()
194+ metadata .set_db_user (typing .cast (str , self .db_user ))
195+ metadata .set_auto_create ("true" )
196+ return metadata
157197
158198
159199class BasicJwtCredentialsProvider (JwtCredentialsProvider ):
0 commit comments