Skip to content

Commit f4e1820

Browse files
committed
feat(connection, auth-profile): support Redshift authentication profile
1 parent e023fa7 commit f4e1820

File tree

3 files changed

+95
-3
lines changed

3 files changed

+95
-3
lines changed

redshift_connector/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def connect(
172172
role_session_name: typing.Optional[str] = None,
173173
role_arn: typing.Optional[str] = None,
174174
iam_disable_cache: typing.Optional[bool] = None,
175+
auth_profile: typing.Optional[str] = None,
176+
endpoint_url: typing.Optional[str] = None,
175177
) -> Connection:
176178
"""
177179
Establishes a :class:`Connection` to an Amazon Redshift cluster. This function validates user input, optionally authenticates using an identity provider plugin, then constructs a :class:`Connection` object.
@@ -257,6 +259,10 @@ def connect(
257259
The role ARN used for authentication with JWT. This parameter is required when using a JWTCredentialsProvider.
258260
iam_disable_cache: Optional[bool]
259261
This option specifies whether the IAM credentials are cached. By default caching is enabled.
262+
auth_profile: Optional[str]
263+
The name of an Amazon Redshift Authentication profile having connection properties as JSON. See :class:RedshiftProperty to learn how connection properties should be named.
264+
endpoint_url: Optional[str]
265+
The Amazon Redshift endpoint url. This option is only used by AWS internal teams.
260266
Returns
261267
-------
262268
A Connection object associated with the specified Amazon Redshift cluster: :class:`Connection`
@@ -267,6 +273,7 @@ def connect(
267273
info.put("app_id", app_id)
268274
info.put("app_name", app_name)
269275
info.put("application_name", application_name)
276+
info.put("auth_profile", auth_profile)
270277
info.put("auto_create", auto_create)
271278
info.put("client_id", client_id)
272279
info.put("client_protocol_version", client_protocol_version)
@@ -277,6 +284,7 @@ def connect(
277284
info.put("db_groups", db_groups)
278285
info.put("db_name", database)
279286
info.put("db_user", db_user)
287+
info.put("endpoint_url", endpoint_url)
280288
info.put("force_lowercase", force_lowercase)
281289
info.put("host", host)
282290
info.put("iam", iam)

redshift_connector/iam_helper.py

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,36 @@ def set_iam_properties(info: RedshiftProperty) -> RedshiftProperty:
5353
"""
5454
if info is None:
5555
raise InterfaceError("Invalid connection property setting. info must be specified")
56+
57+
# Check for IAM keys and AuthProfile first
58+
if info.auth_profile is not None:
59+
import pkg_resources
60+
from packaging.version import Version
61+
62+
if Version(pkg_resources.get_distribution("boto3").version) < Version("1.17.111"):
63+
raise pkg_resources.VersionConflict(
64+
"boto3 >= 1.17.111 required for authentication via Amazon Redshift authentication profile. "
65+
"Please upgrade the installed version of boto3 to use this functionality."
66+
)
67+
68+
if not all((info.access_key_id, info.secret_access_key, info.region)):
69+
raise InterfaceError(
70+
"Invalid connection property setting. access_key_id, secret_access_key, and region are required "
71+
"for authentication via Redshift auth_profile"
72+
)
73+
else:
74+
# info.put("region", info.region)
75+
# info.put("endpoint_url", info.endpoint_url)
76+
77+
resp = IamHelper.read_auth_profile(
78+
auth_profile=typing.cast(str, info.auth_profile),
79+
iam_access_key_id=typing.cast(str, info.access_key_id),
80+
iam_secret_key=typing.cast(str, info.secret_access_key),
81+
iam_session_token=info.session_token,
82+
info=info,
83+
)
84+
info.put_all(resp)
85+
5686
# IAM requires an SSL connection to work.
5787
# Make sure that is set to SSL level VERIFY_CA or higher.
5888
if info.ssl is True:
@@ -87,14 +117,15 @@ def set_iam_properties(info: RedshiftProperty) -> RedshiftProperty:
87117
info.secret_access_key,
88118
info.session_token,
89119
info.profile,
120+
info.auth_profile,
90121
)
91122
):
92123
raise InterfaceError(
93124
"Invalid connection property setting. Credentials provider, AWS credentials, Redshift auth profile "
94125
"or AWS profile must be provided when IAM is enabled"
95126
)
96127

97-
if info.cluster_identifier is None and info.cluster_identifier is None:
128+
if info.cluster_identifier is None:
98129
raise InterfaceError(
99130
"Invalid connection property setting. cluster_identifier must be provided when IAM is enabled"
100131
)
@@ -159,6 +190,56 @@ def set_iam_properties(info: RedshiftProperty) -> RedshiftProperty:
159190
IamHelper.set_iam_credentials(info)
160191
return info
161192

193+
@staticmethod
194+
def read_auth_profile(
195+
auth_profile: str,
196+
iam_access_key_id: str,
197+
iam_secret_key: str,
198+
iam_session_token: typing.Optional[str],
199+
info: RedshiftProperty,
200+
) -> RedshiftProperty:
201+
import json
202+
203+
import boto3
204+
from botocore.exceptions import ClientError
205+
206+
# 1st phase - authenticate with boto3 client for Amazon Redshift via IAM
207+
# credentials provided by end user
208+
creds: typing.Dict[str, str] = {
209+
"aws_access_key_id": iam_access_key_id,
210+
"aws_secret_access_key": iam_secret_key,
211+
"region_name": typing.cast(str, info.region),
212+
}
213+
214+
for opt_key, opt_val in (
215+
("aws_session_token", iam_session_token),
216+
("endpoint_url", info.endpoint_url),
217+
):
218+
if opt_val is not None and opt_val != "":
219+
creds[opt_key] = opt_val
220+
221+
try:
222+
_logger.debug("Initial authentication with boto3...")
223+
client = boto3.client(service_name="redshift", **creds)
224+
_logger.debug("Requesting authentication profiles")
225+
# 2nd phase - request Amazon Redshift authentication profiles and record contents for retrieving
226+
# temporary credentials for the Amazon Redshift cluster specified by end user
227+
response = client.describe_authentication_profiles(AuthenticationProfileName=auth_profile)
228+
except ClientError:
229+
raise InterfaceError("Unable to retrieve contents of Redshift authentication profile from server")
230+
231+
_logger.debug("Received {} authentication profiles".format(len(response["AuthenticationProfiles"])))
232+
# the first matching authentication profile will be used
233+
profile_content: typing.Union[str] = response["AuthenticationProfiles"][0]["AuthenticationProfileContent"]
234+
235+
try:
236+
profile_content_dict: typing.Dict = json.loads(profile_content)
237+
return RedshiftProperty(**profile_content_dict)
238+
except ValueError:
239+
raise ProgrammingError(
240+
"Unable to decode the JSON content of the Redshift authentication profile: {}".format(auth_profile)
241+
)
242+
162243
@staticmethod
163244
def set_iam_credentials(info: RedshiftProperty) -> None:
164245
"""
@@ -258,8 +339,9 @@ def set_cluster_credentials(
258339
] = cred_provider.get_credentials()
259340
session_credentials: typing.Dict[str, str] = credentials_holder.get_session_credentials()
260341

261-
if info.region is not None:
262-
session_credentials["region_name"] = info.region
342+
for opt_key, opt_val in (("region_name", info.region), ("endpoint_url", info.endpoint_url)):
343+
if opt_val is not None:
344+
session_credentials[opt_key] = opt_val
263345

264346
# if AWS credentials were used to create a boto3.Session object, use it
265347
if credentials_holder.has_associated_session:

redshift_connector/redshift_property.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(self: "RedshiftProperty", **kwargs):
1919
# The name of the Okta application that you use to authenticate the connection to Redshift.
2020
self.app_name: str = "amazon_aws_redshift"
2121
self.application_name: typing.Optional[str] = None
22+
self.auth_profile: typing.Optional[str] = None
2223
# Indicates whether the user should be created if it does not already exist.
2324
self.auto_create: bool = False
2425
# The client ID associated with the user name in the Azure AD portal. Only used for Azure AD.
@@ -44,6 +45,7 @@ def __init__(self: "RedshiftProperty", **kwargs):
4445
self.db_user: typing.Optional[str] = None
4546
# The length of time, in seconds
4647
self.duration: int = 900
48+
self.endpoint_url: typing.Optional[str] = None
4749
# Forces the database group names to be lower case.
4850
self.force_lowercase: bool = False
4951
# The host to connect to.

0 commit comments

Comments
 (0)