Skip to content

Commit 2cb3ee3

Browse files
committed
feat(IdP, JwtCredentialsProvider): support JWT SSO IdP
1 parent 26b2888 commit 2cb3ee3

File tree

5 files changed

+188
-1
lines changed

5 files changed

+188
-1
lines changed

redshift_connector/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ def connect(
117117
client_protocol_version: int = DEFAULT_PROTOCOL_VERSION,
118118
database_metadata_current_db_only: bool = True,
119119
ssl_insecure: typing.Optional[bool] = None,
120+
web_identity_token: typing.Optional[str] = None,
121+
role_session_name: typing.Optional[str] = None,
122+
role_arn: typing.Optional[str] = None,
120123
) -> Connection:
121124

122125
info: RedshiftProperty = RedshiftProperty()
@@ -164,6 +167,9 @@ def connect(
164167
client_protocol_version=client_protocol_version,
165168
database_metadata_current_db_only=database_metadata_current_db_only,
166169
ssl_insecure=ssl_insecure,
170+
web_identity_token=web_identity_token,
171+
role_session_name=role_session_name,
172+
role_arn=role_arn,
167173
)
168174

169175
return Connection(

redshift_connector/iam_helper.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def set_iam_properties(
9090
client_protocol_version: int,
9191
database_metadata_current_db_only: bool,
9292
ssl_insecure: typing.Optional[bool],
93+
web_identity_token: typing.Optional[str],
94+
role_session_name: typing.Optional[str],
95+
role_arn: typing.Optional[str],
9396
) -> None:
9497
"""
9598
Helper function to handle IAM connection properties and ensure required parameters are specified.
@@ -241,6 +244,11 @@ def set_iam_properties(
241244
info.login_url = login_url
242245
info.partner_sp_id = partner_sp_id
243246

247+
# Jwt idp parameters
248+
info.web_identity_token = web_identity_token
249+
info.role_session_name = role_session_name
250+
info.role_arn = role_arn
251+
244252
if info.iam is True:
245253
IamHelper.set_iam_credentials(info)
246254
else:
@@ -253,7 +261,7 @@ def set_iam_credentials(info: RedshiftProperty) -> None:
253261
"""
254262
klass: typing.Optional[SamlCredentialsProvider] = None
255263
provider: typing.Union[SamlCredentialsProvider, AWSCredentialsProvider]
256-
# case insensitive comparison
264+
257265
if info.credentials_provider is not None:
258266
try:
259267
klass = dynamic_plugin_import(info.credentials_provider)

redshift_connector/plugin/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
from .azure_credentials_provider import AzureCredentialsProvider
33
from .browser_azure_credentials_provider import BrowserAzureCredentialsProvider
44
from .browser_saml_credentials_provider import BrowserSamlCredentialsProvider
5+
from .jwt_credentials_provider import (
6+
BasicJwtCredentialsProvider,
7+
JwtCredentialsProvider,
8+
)
59
from .okta_credentials_provider import OktaCredentialsProvider
610
from .ping_credentials_provider import PingCredentialsProvider
711
from .saml_credentials_provider import SamlCredentialsProvider
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import base64
2+
import logging
3+
import re
4+
import typing
5+
from abc import ABC, abstractmethod
6+
7+
from redshift_connector.credentials_holder import CredentialsHolder
8+
from redshift_connector.error import InterfaceError
9+
from redshift_connector.plugin.saml_credentials_provider import SamlCredentialsProvider
10+
from redshift_connector.redshift_property import RedshiftProperty
11+
12+
_logger: logging.Logger = logging.getLogger(__name__)
13+
14+
15+
class JwtCredentialsProvider(SamlCredentialsProvider, ABC):
16+
KEY_ROLE_ARN: str = "role_arn"
17+
KEY_WEB_IDENTITY_TOKEN: str = "web_identity_token"
18+
KEY_DURATION: str = "duration"
19+
KEY_ROLE_SESSION_NAME: str = "role_session_name"
20+
DEFAULT_ROLE_SESSION_NAME: str = "jwt_redshift_session"
21+
22+
def __init__(self: "JwtCredentialsProvider"):
23+
super().__init__()
24+
# required params
25+
self.role_arn: typing.Optional[str] = None
26+
self.jwt: typing.Optional[str] = None
27+
28+
# optional params
29+
self.role_session_name = JwtCredentialsProvider.DEFAULT_ROLE_SESSION_NAME
30+
self.duration: typing.Optional[int] = None
31+
# self.db_user: typing.Optional[str] = None
32+
# self.db_groups: typing.Optional[str] = None
33+
self.db_groups_filter: typing.Optional[str] = None
34+
# self.force_lowercase: typing.Optional[bool] = None
35+
# self.auto_create: typing.Optional[bool] = None
36+
# self.region: typing.Optional[str] = None
37+
38+
@abstractmethod
39+
def process_jwt(self: "JwtCredentialsProvider", jwt: str) -> str:
40+
pass # pragma: no cover
41+
42+
def add_parameter(
43+
self: "JwtCredentialsProvider",
44+
info: RedshiftProperty,
45+
) -> None:
46+
super().add_parameter(info)
47+
self.jwt = info.web_identity_token
48+
self.role_arn = info.role_arn
49+
50+
if info.role_session_name is not None:
51+
self.role_session_name = info.role_session_name
52+
53+
def get_cache_key(self: "JwtCredentialsProvider") -> str:
54+
return "{}{}{}{}".format(self.role_arn, self.jwt, self.role_session_name, self.duration)
55+
56+
def get_credentials(self: "SamlCredentialsProvider") -> CredentialsHolder:
57+
key: str = self.get_cache_key()
58+
if key not in self.cache or self.cache[key].is_expired():
59+
try:
60+
self.refresh()
61+
except Exception as e:
62+
_logger.error("refresh failed: {}".format(str(e)))
63+
raise InterfaceError(e)
64+
65+
if key not in self.cache or self.cache[key] is None:
66+
raise InterfaceError("Unable to load AWS credentials from IDP")
67+
68+
return self.cache[key]
69+
70+
def refresh(self: "JwtCredentialsProvider") -> None:
71+
import boto3 # type: ignore
72+
73+
client = boto3.client("sts")
74+
75+
try:
76+
_logger.debug("JWT: {}".format(self.jwt))
77+
if self.jwt is None:
78+
raise InterfaceError("Unable to refresh, no jwt provided")
79+
80+
jwt: str = self.process_jwt(self.jwt)
81+
decoded_jwt: typing.Optional[typing.List[typing.Union[str, bytes]]] = self.decode_jwt(self.jwt)
82+
83+
response = client.assume_role_with_web_identity(
84+
RoleArn=self.role_arn,
85+
RoleSessionName=self.role_session_name,
86+
WebIdentityToken=jwt,
87+
DurationSeconds=self.duration if (self.duration is not None) and (self.duration > 0) else None,
88+
)
89+
90+
stscred: typing.Dict[str, typing.Any] = response["Credentials"]
91+
credentials: CredentialsHolder = CredentialsHolder(stscred)
92+
key: str = self.get_cache_key()
93+
self.cache[key] = credentials
94+
95+
except client.exceptions.MalformedPolicyDocumentException as e:
96+
_logger.error("MalformedPolicyDocumentException: %s", e)
97+
raise e
98+
except client.exceptions.PackedPolicyTooLargeException as e:
99+
_logger.error("PackedPolicyTooLargeException: %s", e)
100+
raise e
101+
except client.exceptions.IDPRejectedClaimException as e:
102+
_logger.error("IDPRejectedClaimException: %s", e)
103+
raise e
104+
except client.exceptions.InvalidIdentityTokenException as e:
105+
_logger.error("InvalidIdentityTokenException: %s", e)
106+
raise e
107+
except client.exceptions.ExpiredTokenException as e:
108+
_logger.error("ExpiredTokenException: %s", e)
109+
raise e
110+
except client.exceptions.RegionDisabledException as e:
111+
_logger.error("RegionDisabledException: %s", e)
112+
raise e
113+
except Exception as e:
114+
_logger.error("other Exception: %s", e)
115+
raise e
116+
117+
def check_required_parameters(self: "JwtCredentialsProvider") -> None:
118+
if self.role_arn is None or self.role_arn == "":
119+
raise InterfaceError("Missing required property: {}".format(JwtCredentialsProvider.KEY_ROLE_ARN))
120+
elif self.jwt is None or self.jwt == "":
121+
raise InterfaceError("Missing required property: {}".format(JwtCredentialsProvider.KEY_WEB_IDENTITY_TOKEN))
122+
123+
def decode_jwt(
124+
self: "JwtCredentialsProvider", jwt: typing.Optional[str]
125+
) -> typing.Optional[typing.List[typing.Union[str, bytes]]]:
126+
if jwt is None:
127+
return None
128+
129+
# base64(JOSE header).base64(payload).base64(signature)
130+
header_payload_sig: typing.List[str] = jwt.split("\\.")
131+
132+
_logger.debug("Encoded JWT Elements: {}".format(header_payload_sig))
133+
134+
if len(header_payload_sig) == 3:
135+
decoded_jwt: typing.List[typing.Union[bytes, str]] = []
136+
# decode the header and payload
137+
for i in range(2):
138+
decoded_jwt.append(base64.b64decode(header_payload_sig[i]))
139+
140+
decoded_jwt.append(header_payload_sig[2])
141+
_logger.debug("Decoded JWT Elements: {}".format(header_payload_sig))
142+
return decoded_jwt
143+
else:
144+
return None
145+
146+
def get_saml_assertion(self: "SamlCredentialsProvider"):
147+
raise NotImplementedError
148+
149+
def do_verify_ssl_cert(self: "SamlCredentialsProvider") -> bool:
150+
raise NotImplementedError
151+
152+
def get_form_action(self: "SamlCredentialsProvider", soup) -> typing.Optional[str]:
153+
raise NotImplementedError
154+
155+
def read_metadata(self: "SamlCredentialsProvider", doc: bytes) -> CredentialsHolder.IamMetadata:
156+
raise NotImplementedError
157+
158+
159+
class BasicJwtCredentialsProvider(JwtCredentialsProvider):
160+
"""
161+
A basic JWT Credential provider class that can be changed and implemented to work with any desired JWT service provider.
162+
"""
163+
164+
def process_jwt(self: "JwtCredentialsProvider", jwt: str) -> str:
165+
self.check_required_parameters()
166+
return self.jwt # type: ignore

redshift_connector/redshift_property.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,6 @@ class RedshiftProperty:
9494
idp_response_timeout: int = 120
9595
listen_port: int = 7890
9696
login_url: typing.Optional[str] = None
97+
web_identity_token: typing.Optional[str] = None
98+
role_session_name: typing.Optional[str] = None
99+
role_arn: typing.Optional[str] = None

0 commit comments

Comments
 (0)