|
| 1 | +import base64 |
| 2 | +import logging |
| 3 | +import re |
| 4 | +import typing |
| 5 | +from abc import ABC, abstractmethod |
| 6 | + |
| 7 | +from redshift_connector.credentials_holder import CredentialsHolder |
| 8 | +from redshift_connector.error import InterfaceError |
| 9 | +from redshift_connector.plugin.saml_credentials_provider import SamlCredentialsProvider |
| 10 | +from redshift_connector.redshift_property import RedshiftProperty |
| 11 | + |
| 12 | +_logger: logging.Logger = logging.getLogger(__name__) |
| 13 | + |
| 14 | + |
| 15 | +class JwtCredentialsProvider(SamlCredentialsProvider, ABC): |
| 16 | + KEY_ROLE_ARN: str = "role_arn" |
| 17 | + KEY_WEB_IDENTITY_TOKEN: str = "web_identity_token" |
| 18 | + KEY_DURATION: str = "duration" |
| 19 | + KEY_ROLE_SESSION_NAME: str = "role_session_name" |
| 20 | + DEFAULT_ROLE_SESSION_NAME: str = "jwt_redshift_session" |
| 21 | + |
| 22 | + def __init__(self: "JwtCredentialsProvider"): |
| 23 | + super().__init__() |
| 24 | + # required params |
| 25 | + self.role_arn: typing.Optional[str] = None |
| 26 | + self.jwt: typing.Optional[str] = None |
| 27 | + |
| 28 | + # optional params |
| 29 | + self.role_session_name = JwtCredentialsProvider.DEFAULT_ROLE_SESSION_NAME |
| 30 | + self.duration: typing.Optional[int] = None |
| 31 | + # self.db_user: typing.Optional[str] = None |
| 32 | + # self.db_groups: typing.Optional[str] = None |
| 33 | + self.db_groups_filter: typing.Optional[str] = None |
| 34 | + # self.force_lowercase: typing.Optional[bool] = None |
| 35 | + # self.auto_create: typing.Optional[bool] = None |
| 36 | + # self.region: typing.Optional[str] = None |
| 37 | + |
| 38 | + @abstractmethod |
| 39 | + def process_jwt(self: "JwtCredentialsProvider", jwt: str) -> str: |
| 40 | + pass # pragma: no cover |
| 41 | + |
| 42 | + def add_parameter( |
| 43 | + self: "JwtCredentialsProvider", |
| 44 | + info: RedshiftProperty, |
| 45 | + ) -> None: |
| 46 | + super().add_parameter(info) |
| 47 | + self.jwt = info.web_identity_token |
| 48 | + self.role_arn = info.role_arn |
| 49 | + |
| 50 | + if info.role_session_name is not None: |
| 51 | + self.role_session_name = info.role_session_name |
| 52 | + |
| 53 | + def get_cache_key(self: "JwtCredentialsProvider") -> str: |
| 54 | + return "{}{}{}{}".format(self.role_arn, self.jwt, self.role_session_name, self.duration) |
| 55 | + |
| 56 | + def get_credentials(self: "SamlCredentialsProvider") -> CredentialsHolder: |
| 57 | + key: str = self.get_cache_key() |
| 58 | + if key not in self.cache or self.cache[key].is_expired(): |
| 59 | + try: |
| 60 | + self.refresh() |
| 61 | + except Exception as e: |
| 62 | + _logger.error("refresh failed: {}".format(str(e))) |
| 63 | + raise InterfaceError(e) |
| 64 | + |
| 65 | + if key not in self.cache or self.cache[key] is None: |
| 66 | + raise InterfaceError("Unable to load AWS credentials from IDP") |
| 67 | + |
| 68 | + return self.cache[key] |
| 69 | + |
| 70 | + def refresh(self: "JwtCredentialsProvider") -> None: |
| 71 | + import boto3 # type: ignore |
| 72 | + |
| 73 | + client = boto3.client("sts") |
| 74 | + |
| 75 | + try: |
| 76 | + _logger.debug("JWT: {}".format(self.jwt)) |
| 77 | + if self.jwt is None: |
| 78 | + raise InterfaceError("Unable to refresh, no jwt provided") |
| 79 | + |
| 80 | + jwt: str = self.process_jwt(self.jwt) |
| 81 | + decoded_jwt: typing.Optional[typing.List[typing.Union[str, bytes]]] = self.decode_jwt(self.jwt) |
| 82 | + |
| 83 | + response = client.assume_role_with_web_identity( |
| 84 | + RoleArn=self.role_arn, |
| 85 | + RoleSessionName=self.role_session_name, |
| 86 | + WebIdentityToken=jwt, |
| 87 | + DurationSeconds=self.duration if (self.duration is not None) and (self.duration > 0) else None, |
| 88 | + ) |
| 89 | + |
| 90 | + stscred: typing.Dict[str, typing.Any] = response["Credentials"] |
| 91 | + credentials: CredentialsHolder = CredentialsHolder(stscred) |
| 92 | + key: str = self.get_cache_key() |
| 93 | + self.cache[key] = credentials |
| 94 | + |
| 95 | + except client.exceptions.MalformedPolicyDocumentException as e: |
| 96 | + _logger.error("MalformedPolicyDocumentException: %s", e) |
| 97 | + raise e |
| 98 | + except client.exceptions.PackedPolicyTooLargeException as e: |
| 99 | + _logger.error("PackedPolicyTooLargeException: %s", e) |
| 100 | + raise e |
| 101 | + except client.exceptions.IDPRejectedClaimException as e: |
| 102 | + _logger.error("IDPRejectedClaimException: %s", e) |
| 103 | + raise e |
| 104 | + except client.exceptions.InvalidIdentityTokenException as e: |
| 105 | + _logger.error("InvalidIdentityTokenException: %s", e) |
| 106 | + raise e |
| 107 | + except client.exceptions.ExpiredTokenException as e: |
| 108 | + _logger.error("ExpiredTokenException: %s", e) |
| 109 | + raise e |
| 110 | + except client.exceptions.RegionDisabledException as e: |
| 111 | + _logger.error("RegionDisabledException: %s", e) |
| 112 | + raise e |
| 113 | + except Exception as e: |
| 114 | + _logger.error("other Exception: %s", e) |
| 115 | + raise e |
| 116 | + |
| 117 | + def check_required_parameters(self: "JwtCredentialsProvider") -> None: |
| 118 | + if self.role_arn is None or self.role_arn == "": |
| 119 | + raise InterfaceError("Missing required property: {}".format(JwtCredentialsProvider.KEY_ROLE_ARN)) |
| 120 | + elif self.jwt is None or self.jwt == "": |
| 121 | + raise InterfaceError("Missing required property: {}".format(JwtCredentialsProvider.KEY_WEB_IDENTITY_TOKEN)) |
| 122 | + |
| 123 | + def decode_jwt( |
| 124 | + self: "JwtCredentialsProvider", jwt: typing.Optional[str] |
| 125 | + ) -> typing.Optional[typing.List[typing.Union[str, bytes]]]: |
| 126 | + if jwt is None: |
| 127 | + return None |
| 128 | + |
| 129 | + # base64(JOSE header).base64(payload).base64(signature) |
| 130 | + header_payload_sig: typing.List[str] = jwt.split("\\.") |
| 131 | + |
| 132 | + _logger.debug("Encoded JWT Elements: {}".format(header_payload_sig)) |
| 133 | + |
| 134 | + if len(header_payload_sig) == 3: |
| 135 | + decoded_jwt: typing.List[typing.Union[bytes, str]] = [] |
| 136 | + # decode the header and payload |
| 137 | + for i in range(2): |
| 138 | + decoded_jwt.append(base64.b64decode(header_payload_sig[i])) |
| 139 | + |
| 140 | + decoded_jwt.append(header_payload_sig[2]) |
| 141 | + _logger.debug("Decoded JWT Elements: {}".format(header_payload_sig)) |
| 142 | + return decoded_jwt |
| 143 | + else: |
| 144 | + return None |
| 145 | + |
| 146 | + def get_saml_assertion(self: "SamlCredentialsProvider"): |
| 147 | + raise NotImplementedError |
| 148 | + |
| 149 | + def do_verify_ssl_cert(self: "SamlCredentialsProvider") -> bool: |
| 150 | + raise NotImplementedError |
| 151 | + |
| 152 | + def get_form_action(self: "SamlCredentialsProvider", soup) -> typing.Optional[str]: |
| 153 | + raise NotImplementedError |
| 154 | + |
| 155 | + def read_metadata(self: "SamlCredentialsProvider", doc: bytes) -> CredentialsHolder.IamMetadata: |
| 156 | + raise NotImplementedError |
| 157 | + |
| 158 | + |
| 159 | +class BasicJwtCredentialsProvider(JwtCredentialsProvider): |
| 160 | + """ |
| 161 | + A basic JWT Credential provider class that can be changed and implemented to work with any desired JWT service provider. |
| 162 | + """ |
| 163 | + |
| 164 | + def process_jwt(self: "JwtCredentialsProvider", jwt: str) -> str: |
| 165 | + self.check_required_parameters() |
| 166 | + return self.jwt # type: ignore |
0 commit comments