Skip to content

Commit 8ef9b6f

Browse files
committed
test(idp, JwtCredentialsProvider): derive db_user from jwt response
1 parent 44921f6 commit 8ef9b6f

File tree

2 files changed

+73
-8
lines changed

2 files changed

+73
-8
lines changed

test/integration/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def jwt_google_idp():
153153
"password": conf.get("jwt-google-idp", "password"),
154154
"credentials_provider": conf.get("jwt-google-idp", "credentials_provider"),
155155
"web_identity_token": conf.get("jwt-google-idp", "web_identity_token"),
156-
"preferred_role": conf.get("jwt-google-idp", "preferred_role"),
156+
"role_arn": conf.get("jwt-google-idp", "role_arn"),
157157
}
158158
return {**_get_default_connection_args(), **db_connect}
159159

@@ -166,7 +166,7 @@ def jwt_azure_v2_idp():
166166
"password": conf.get("jwt-azure-v2-idp", "password"),
167167
"credentials_provider": conf.get("jwt-azure-v2-idp", "credentials_provider"),
168168
"web_identity_token": conf.get("jwt-azure-v2-idp", "web_identity_token"),
169-
"preferred_role": conf.get("jwt-azure-v2-idp", "preferred_role"),
169+
"role_arn": conf.get("jwt-azure-v2-idp", "role_arn"),
170170
}
171171
return {**_get_default_connection_args(), **db_connect}
172172

test/unit/plugin/test_jwt_credentials_provider.py

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import base64
2+
import json
23
import typing
34
from test.unit.helpers import make_redshift_property
45
from unittest.mock import MagicMock, patch
@@ -21,12 +22,13 @@ def make_jwtcredentialsprovider() -> JwtCredentialsProvider:
2122

2223
def test_make_jwtcredentialsprovider():
2324
jwtcp: JwtCredentialsProvider = make_jwtcredentialsprovider()
24-
assert hasattr(jwtcp, "role_arn")
2525
assert jwtcp.role_arn is None
26-
assert hasattr(jwtcp, "duration")
26+
assert jwtcp.jwt is None
27+
assert jwtcp.role_session_name is JwtCredentialsProvider.DEFAULT_ROLE_SESSION_NAME
2728
assert jwtcp.duration is None
28-
assert hasattr(jwtcp, "db_groups_filter")
29-
assert jwtcp.db_groups_filter is None
29+
assert jwtcp.db_user is None
30+
assert jwtcp.sts_endpoint is None
31+
assert jwtcp.region is None
3032

3133

3234
def test_jwtcredentialsprovider_add_parameter():
@@ -37,17 +39,20 @@ def test_jwtcredentialsprovider_add_parameter():
3739
_duration: int = 1234
3840
_role: str = "my_role"
3941
_session: str = "my_session"
42+
_region: str = "something"
4043

4144
rp.role_arn = _role
45+
rp.web_identity_token = _wit
4246
rp.role_session_name = _session
4347
rp.duration = _duration
44-
rp.web_identity_token = _wit
48+
rp.region = _region
4549

4650
jwtcp.add_parameter(rp)
4751
assert jwtcp.jwt == _wit
4852
assert jwtcp.duration == _duration
4953
assert jwtcp.role_arn == _role
5054
assert jwtcp.role_session_name == _session
55+
assert jwtcp.region == _region
5156

5257

5358
cache_key_vals: typing.List[typing.Tuple] = [("a", "b", "c", "d"), ()]
@@ -121,7 +126,7 @@ def test_decode_jwt(_input):
121126
assert jwtcp.decode_jwt(jwt) == exp_result
122127

123128

124-
@pytest.mark.parametrize("_input", ["get_saml_assertion", "do_verify_ssl_cert", "get_form_action", "read_metadata"])
129+
@pytest.mark.parametrize("_input", ["get_saml_assertion", "do_verify_ssl_cert", "get_form_action"])
125130
def test_get_saml_assertion_not_implemented(_input):
126131
jwtcp: JwtCredentialsProvider = make_jwtcredentialsprovider()
127132
method_to_call = jwtcp.__getattribute__(_input)
@@ -191,6 +196,10 @@ def test_refresh_passes_jwt_to_boto3(mocker):
191196
mocker.patch(
192197
"redshift_connector.plugin.jwt_credentials_provider.JwtCredentialsProvider.decode_jwt", return_value=None
193198
)
199+
mocker.patch(
200+
"redshift_connector.plugin.jwt_credentials_provider.JwtCredentialsProvider.derive_database_user",
201+
return_value="Mouse",
202+
)
194203

195204
jwtcp: JwtCredentialsProvider = make_jwtcredentialsprovider()
196205
mocked_orig_jwt: str = "initial value"
@@ -221,6 +230,62 @@ def test_refresh_passes_jwt_to_boto3(mocker):
221230
assert isinstance(jwtcp.cache[jwtcp.get_cache_key()], CredentialsHolder)
222231

223232

233+
test_jwt_resp_data: typing.List[str] = [
234+
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJPbmxpbmUgSldUIEJ1aWxkZXIiLCJpYXQiOjE2MTgyNTgzNjQsImV4cCI6MTY0OTc5NDM2NCwiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoianJvY2tldEBleGFtcGxlLmNvbSIsIkdpdmVuTmFtZSI6IkpvaG5ueSIsIlN1cm5hbWUiOiJSb2NrZXQiLCJFbWFpbCI6Impyb2NrZXRAZXhhbXBsZS5jb20iLCJSb2xlIjpbIk1hbmFnZXIiLCJQcm9qZWN0IEFkbWluaXN0cmF0b3IiXX0.4lCP0ZgrKo3f6lQ9AtMdFEeDD5fBnszN3Deo18VyJ-o"
235+
]
236+
237+
238+
@pytest.mark.parametrize("encoded_jwt", test_jwt_resp_data)
239+
def test_decode_jwt_resp(encoded_jwt):
240+
bjwtcp: BasicJwtCredentialsProvider = BasicJwtCredentialsProvider()
241+
242+
decoded = bjwtcp.decode_jwt(encoded_jwt)
243+
assert len(decoded) == 3
244+
for idx, entry in enumerate(decoded[0:2]):
245+
str_entry = entry.decode("utf-8")
246+
data = json.loads(str_entry)
247+
assert len(data) > 1
248+
249+
if idx == 0:
250+
for exp_key in ("alg", "typ"):
251+
assert exp_key in data
252+
elif idx == 1:
253+
for exp_key in ("iss", "iat", "exp", "aud", "sub", "GivenName", "Surname", "Email", "Role"):
254+
assert exp_key in data
255+
256+
257+
@pytest.mark.parametrize("db_param", ["DbUser", "upn", "preferred_username", "email"])
258+
def test_derive_database_user(db_param):
259+
data: typing.Dict[str, typing.Union[str, int]] = {
260+
"iss": "Online JWT Builder",
261+
"iat": 1618258364,
262+
"exp": 1649794364,
263+
"aud": "www.example.com",
264+
"sub": "jrocket@example.com",
265+
"GivenName": "Johnny",
266+
"Surname": "Rocket",
267+
"Role": ["Manager", "Project Administrator"],
268+
}
269+
DB_USER: str = "mr.bear@forest.com"
270+
data[db_param] = DB_USER
271+
272+
mock_jwt_resp: typing.List[typing.Union[str, bytes]] = ["", json.dumps(data), "mocked resp"]
273+
274+
bjwtcp: BasicJwtCredentialsProvider = BasicJwtCredentialsProvider()
275+
assert bjwtcp.derive_database_user(mock_jwt_resp) == DB_USER
276+
277+
278+
@pytest.mark.parametrize(
279+
"decoded_data",
280+
[[""], ["" * 4], ["", json.dumps({"dbuser": "invalid"}), ""], ["", json.dumps({"Email": "invalid"}), ""]],
281+
)
282+
def test_derive_database_user_not_found(decoded_data):
283+
bjwtcp: BasicJwtCredentialsProvider = BasicJwtCredentialsProvider()
284+
285+
with pytest.raises(InterfaceError):
286+
bjwtcp.derive_database_user(decoded_data)
287+
288+
224289
def test_basic_jwt_credential_provider(mocker):
225290
bjwtcp: BasicJwtCredentialsProvider = BasicJwtCredentialsProvider()
226291
bjwtcp.jwt = "hi"

0 commit comments

Comments
 (0)