11import base64
2+ import json
23import typing
34from test .unit .helpers import make_redshift_property
45from unittest .mock import MagicMock , patch
@@ -21,12 +22,13 @@ def make_jwtcredentialsprovider() -> JwtCredentialsProvider:
2122
2223def test_make_jwtcredentialsprovider ():
2324 jwtcp : JwtCredentialsProvider = make_jwtcredentialsprovider ()
24- assert hasattr (jwtcp , "role_arn" )
2525 assert jwtcp .role_arn is None
26- assert hasattr (jwtcp , "duration" )
26+ assert jwtcp .jwt is None
27+ assert jwtcp .role_session_name is JwtCredentialsProvider .DEFAULT_ROLE_SESSION_NAME
2728 assert jwtcp .duration is None
28- assert hasattr (jwtcp , "db_groups_filter" )
29- assert jwtcp .db_groups_filter is None
29+ assert jwtcp .db_user is None
30+ assert jwtcp .sts_endpoint is None
31+ assert jwtcp .region is None
3032
3133
3234def test_jwtcredentialsprovider_add_parameter ():
@@ -37,17 +39,20 @@ def test_jwtcredentialsprovider_add_parameter():
3739 _duration : int = 1234
3840 _role : str = "my_role"
3941 _session : str = "my_session"
42+ _region : str = "something"
4043
4144 rp .role_arn = _role
45+ rp .web_identity_token = _wit
4246 rp .role_session_name = _session
4347 rp .duration = _duration
44- rp .web_identity_token = _wit
48+ rp .region = _region
4549
4650 jwtcp .add_parameter (rp )
4751 assert jwtcp .jwt == _wit
4852 assert jwtcp .duration == _duration
4953 assert jwtcp .role_arn == _role
5054 assert jwtcp .role_session_name == _session
55+ assert jwtcp .region == _region
5156
5257
5358cache_key_vals : typing .List [typing .Tuple ] = [("a" , "b" , "c" , "d" ), ()]
@@ -121,7 +126,7 @@ def test_decode_jwt(_input):
121126 assert jwtcp .decode_jwt (jwt ) == exp_result
122127
123128
124- @pytest .mark .parametrize ("_input" , ["get_saml_assertion" , "do_verify_ssl_cert" , "get_form_action" , "read_metadata" ])
129+ @pytest .mark .parametrize ("_input" , ["get_saml_assertion" , "do_verify_ssl_cert" , "get_form_action" ])
125130def test_get_saml_assertion_not_implemented (_input ):
126131 jwtcp : JwtCredentialsProvider = make_jwtcredentialsprovider ()
127132 method_to_call = jwtcp .__getattribute__ (_input )
@@ -191,6 +196,10 @@ def test_refresh_passes_jwt_to_boto3(mocker):
191196 mocker .patch (
192197 "redshift_connector.plugin.jwt_credentials_provider.JwtCredentialsProvider.decode_jwt" , return_value = None
193198 )
199+ mocker .patch (
200+ "redshift_connector.plugin.jwt_credentials_provider.JwtCredentialsProvider.derive_database_user" ,
201+ return_value = "Mouse" ,
202+ )
194203
195204 jwtcp : JwtCredentialsProvider = make_jwtcredentialsprovider ()
196205 mocked_orig_jwt : str = "initial value"
@@ -221,6 +230,62 @@ def test_refresh_passes_jwt_to_boto3(mocker):
221230 assert isinstance (jwtcp .cache [jwtcp .get_cache_key ()], CredentialsHolder )
222231
223232
233+ test_jwt_resp_data : typing .List [str ] = [
234+ "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJPbmxpbmUgSldUIEJ1aWxkZXIiLCJpYXQiOjE2MTgyNTgzNjQsImV4cCI6MTY0OTc5NDM2NCwiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoianJvY2tldEBleGFtcGxlLmNvbSIsIkdpdmVuTmFtZSI6IkpvaG5ueSIsIlN1cm5hbWUiOiJSb2NrZXQiLCJFbWFpbCI6Impyb2NrZXRAZXhhbXBsZS5jb20iLCJSb2xlIjpbIk1hbmFnZXIiLCJQcm9qZWN0IEFkbWluaXN0cmF0b3IiXX0.4lCP0ZgrKo3f6lQ9AtMdFEeDD5fBnszN3Deo18VyJ-o"
235+ ]
236+
237+
238+ @pytest .mark .parametrize ("encoded_jwt" , test_jwt_resp_data )
239+ def test_decode_jwt_resp (encoded_jwt ):
240+ bjwtcp : BasicJwtCredentialsProvider = BasicJwtCredentialsProvider ()
241+
242+ decoded = bjwtcp .decode_jwt (encoded_jwt )
243+ assert len (decoded ) == 3
244+ for idx , entry in enumerate (decoded [0 :2 ]):
245+ str_entry = entry .decode ("utf-8" )
246+ data = json .loads (str_entry )
247+ assert len (data ) > 1
248+
249+ if idx == 0 :
250+ for exp_key in ("alg" , "typ" ):
251+ assert exp_key in data
252+ elif idx == 1 :
253+ for exp_key in ("iss" , "iat" , "exp" , "aud" , "sub" , "GivenName" , "Surname" , "Email" , "Role" ):
254+ assert exp_key in data
255+
256+
257+ @pytest .mark .parametrize ("db_param" , ["DbUser" , "upn" , "preferred_username" , "email" ])
258+ def test_derive_database_user (db_param ):
259+ data : typing .Dict [str , typing .Union [str , int ]] = {
260+ "iss" : "Online JWT Builder" ,
261+ "iat" : 1618258364 ,
262+ "exp" : 1649794364 ,
263+ "aud" : "www.example.com" ,
264+ "sub" : "jrocket@example.com" ,
265+ "GivenName" : "Johnny" ,
266+ "Surname" : "Rocket" ,
267+ "Role" : ["Manager" , "Project Administrator" ],
268+ }
269+ DB_USER : str = "mr.bear@forest.com"
270+ data [db_param ] = DB_USER
271+
272+ mock_jwt_resp : typing .List [typing .Union [str , bytes ]] = ["" , json .dumps (data ), "mocked resp" ]
273+
274+ bjwtcp : BasicJwtCredentialsProvider = BasicJwtCredentialsProvider ()
275+ assert bjwtcp .derive_database_user (mock_jwt_resp ) == DB_USER
276+
277+
278+ @pytest .mark .parametrize (
279+ "decoded_data" ,
280+ [["" ], ["" * 4 ], ["" , json .dumps ({"dbuser" : "invalid" }), "" ], ["" , json .dumps ({"Email" : "invalid" }), "" ]],
281+ )
282+ def test_derive_database_user_not_found (decoded_data ):
283+ bjwtcp : BasicJwtCredentialsProvider = BasicJwtCredentialsProvider ()
284+
285+ with pytest .raises (InterfaceError ):
286+ bjwtcp .derive_database_user (decoded_data )
287+
288+
224289def test_basic_jwt_credential_provider (mocker ):
225290 bjwtcp : BasicJwtCredentialsProvider = BasicJwtCredentialsProvider ()
226291 bjwtcp .jwt = "hi"
0 commit comments