Skip to content

Commit 94c5ebb

Browse files
committed
fix(auth): include DbGroups when getting temp credentials from boto
1 parent d2e5e89 commit 94c5ebb

File tree

5 files changed

+12
-13
lines changed

5 files changed

+12
-13
lines changed

redshift_connector/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def connect(
107107
listen_port: int = 7890,
108108
login_url: typing.Optional[str] = None,
109109
auto_create: bool = False,
110-
db_groups: typing.Optional[typing.List[str]] = None,
110+
db_groups: typing.List[str] = list(),
111111
force_lowercase: bool = False,
112112
allow_db_user_override: bool = False,
113113
log_level: int = 0,

redshift_connector/credentials_holder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self: "CredentialsHolder.IamMetadata") -> None:
4242
self.db_user: typing.Optional[str] = None
4343
self.saml_db_user: typing.Optional[str] = None
4444
self.profile_db_user: typing.Optional[str] = None
45-
self.db_groups: typing.Optional[str] = None
45+
self.db_groups: typing.List[str] = list()
4646
self.allow_db_user_override: bool = False
4747
self.force_lowercase: bool = False
4848

@@ -73,10 +73,10 @@ def get_profile_db_user(self: "CredentialsHolder.IamMetadata") -> typing.Optiona
7373
def set_profile_db_user(self: "CredentialsHolder.IamMetadata", profile_db_user: str) -> None:
7474
self.profile_db_user = profile_db_user
7575

76-
def get_db_groups(self: "CredentialsHolder.IamMetadata") -> typing.Optional[str]:
76+
def get_db_groups(self: "CredentialsHolder.IamMetadata") -> typing.List[str]:
7777
return self.db_groups
7878

79-
def set_db_groups(self: "CredentialsHolder.IamMetadata", db_groups: str) -> None:
79+
def set_db_groups(self: "CredentialsHolder.IamMetadata", db_groups: typing.List[str]) -> None:
8080
self.db_groups = db_groups
8181

8282
def get_allow_db_user_override(self: "CredentialsHolder.IamMetadata") -> bool:

redshift_connector/iam_helper.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def set_iam_properties(
6262
listen_port: int,
6363
login_url: typing.Optional[str],
6464
auto_create: bool,
65-
db_groups: typing.Optional[typing.List[str]],
65+
db_groups: typing.List[str],
6666
force_lowercase: bool,
6767
allow_db_user_override: bool,
6868
) -> None:
@@ -190,7 +190,7 @@ def set_iam_credentials(info: RedshiftProperty) -> None:
190190
db_user: typing.Optional[str] = metadata.get_db_user()
191191
saml_db_user: typing.Optional[str] = metadata.get_saml_db_user()
192192
profile_db_user: typing.Optional[str] = metadata.get_profile_db_user()
193-
db_groups: typing.Optional[str] = metadata.get_db_groups()
193+
db_groups: typing.List[str] = metadata.get_db_groups()
194194
force_lowercase: bool = metadata.get_force_lowercase()
195195
allow_db_user_override: bool = metadata.get_allow_db_user_override()
196196
if auto_create is True:
@@ -214,9 +214,8 @@ def set_iam_credentials(info: RedshiftProperty) -> None:
214214
if saml_db_user is not None:
215215
info.db_user = saml_db_user
216216

217-
if (info.db_groups is None) and (db_groups is not None):
218-
tmp: typing.List[str] = db_groups.split(",")
219-
info.db_groups = [group.lower() for group in tmp]
217+
if (len(info.db_groups) == 0) and (len(db_groups) > 0):
218+
info.db_groups = db_groups
220219

221220
set_cluster_credentials(provider, info)
222221

@@ -240,6 +239,7 @@ def set_cluster_credentials(cred_provider: SamlCredentialsProvider, info: Redshi
240239
cred: dict = client.get_cluster_credentials(
241240
DbUser=info.db_user,
242241
DbName=info.db_name,
242+
DbGroups=info.db_groups,
243243
ClusterIdentifier=info.cluster_identifier,
244244
AutoCreate=info.auto_create,
245245
)

redshift_connector/plugin/saml_credentials_provider.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self: "SamlCredentialsProvider") -> None:
2626
self.preferred_role: typing.Optional[str] = None
2727
self.sslInsecure: typing.Optional[bool] = None
2828
self.db_user: typing.Optional[str] = None
29-
self.db_groups: typing.Optional[typing.List[str]] = None
29+
self.db_groups: typing.List[str] = list()
3030
self.force_lowercase: typing.Optional[bool] = None
3131
self.auto_create: typing.Optional[bool] = None
3232
self.region: typing.Optional[str] = None
@@ -212,8 +212,7 @@ def read_metadata(self: "SamlCredentialsProvider", doc: bytes) -> CredentialsHol
212212
elif name == "https://redshift.amazon.com/SAML/Attributes/AutoCreate":
213213
metadata.set_auto_create(value)
214214
elif name == "https://redshift.amazon.com/SAML/Attributes/DbGroups":
215-
groups = ",".join([value.contents[0] for value in values])
216-
metadata.set_db_groups(groups)
215+
metadata.set_db_groups([value.contents[0].lower() for value in values])
217216
elif name == "https://redshift.amazon.com/SAML/Attributes/ForceLowercase":
218217
metadata.set_force_lowercase(value)
219218

redshift_connector/redshift_property.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class RedshiftProperty:
2727
credentials_provider: typing.Optional[str] = None
2828
# A comma-separated list of existing database group names that the DbUser joins for the current session.
2929
# If not specified, defaults to PUBLIC.
30-
db_groups: typing.Optional[typing.List[str]] = None
30+
db_groups: typing.List[str] = list()
3131
# Forces the database group names to be lower case.
3232
force_lowercase: bool = False
3333
# This option specifies whether the driver uses the DbUser value from the SAML assertion

0 commit comments

Comments
 (0)