Skip to content

Commit 6af4225

Browse files
committed
feat(auth): support IAM credential authentication
1 parent 722197e commit 6af4225

File tree

6 files changed

+380
-78
lines changed

6 files changed

+380
-78
lines changed

redshift_connector/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ def connect(
9494
app_name: str = "amazon_aws_redshift",
9595
preferred_role: typing.Optional[str] = None,
9696
principal_arn: typing.Optional[str] = None,
97+
access_key_id: typing.Optional[str] = None,
98+
secret_access_key: typing.Optional[str] = None,
99+
session_token: typing.Optional[str] = None,
100+
profile: typing.Optional[str] = None,
97101
credentials_provider: typing.Optional[str] = None,
98102
region: typing.Optional[str] = None,
99103
cluster_identifier: typing.Optional[str] = None,
@@ -137,6 +141,10 @@ def connect(
137141
app_name=app_name,
138142
preferred_role=preferred_role,
139143
principal_arn=principal_arn,
144+
access_key_id=access_key_id,
145+
secret_access_key=secret_access_key,
146+
session_token=session_token,
147+
profile=profile,
140148
credentials_provider=credentials_provider,
141149
region=region,
142150
cluster_identifier=cluster_identifier,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .aws_credentials_provider import AWSCredentialsProvider
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import logging
2+
import typing
3+
4+
from redshift_connector.credentials_holder import (
5+
ABCCredentialsHolder,
6+
AWSDirectCredentialsHolder,
7+
AWSProfileCredentialsHolder,
8+
)
9+
from redshift_connector.error import InterfaceError
10+
11+
_logger: logging.Logger = logging.getLogger(__name__)
12+
13+
if typing.TYPE_CHECKING:
14+
import boto3 # type: ignore
15+
16+
from redshift_connector.redshift_property import RedshiftProperty
17+
18+
19+
class AWSCredentialsProvider:
20+
"""
21+
A credential provider class for AWS credentials specified via :func:`~redshift_connector.connect` using `profile` or AWS access keys.
22+
"""
23+
24+
def __init__(self: "AWSCredentialsProvider") -> None:
25+
self.cache: typing.Dict[int, typing.Union[AWSDirectCredentialsHolder, AWSProfileCredentialsHolder]] = {}
26+
27+
self.access_key_id: typing.Optional[str] = None
28+
self.secret_access_key: typing.Optional[str] = None
29+
self.session_token: typing.Optional[str] = None
30+
self.profile: typing.Optional["boto3.Session"] = None
31+
32+
def get_cache_key(self: "AWSCredentialsProvider") -> int:
33+
"""
34+
Creates a cache key using the hash of either the end-user provided AWS credential information.
35+
36+
Returns
37+
-------
38+
An `int` hash representation of the non-secret portion of credential information: `int`
39+
"""
40+
if self.profile:
41+
return hash(self.profile)
42+
else:
43+
return hash(self.access_key_id)
44+
45+
def get_credentials(
46+
self: "AWSCredentialsProvider",
47+
) -> typing.Union[AWSDirectCredentialsHolder, AWSProfileCredentialsHolder]:
48+
"""
49+
Retrieves a :class`ABCCredentialsHolder` from cache or builds one.
50+
51+
Returns
52+
-------
53+
An `AWSCredentialsHolder` object containing end-user specified AWS credential information: :class`ABCAWSCredentialsHolder`
54+
"""
55+
key: int = self.get_cache_key()
56+
if key not in self.cache:
57+
try:
58+
self.refresh()
59+
except Exception as e:
60+
_logger.error("refresh failed: {}".format(str(e)))
61+
raise InterfaceError(e)
62+
63+
credentials: typing.Union[AWSDirectCredentialsHolder, AWSProfileCredentialsHolder] = self.cache[key]
64+
65+
if credentials is None:
66+
raise InterfaceError("Unable to load AWS credentials")
67+
68+
return credentials
69+
70+
def add_parameter(self: "AWSCredentialsProvider", info: "RedshiftProperty") -> None:
71+
"""
72+
Defines instance variables used for creating a :class`ABCCredentialsHolder` object and associated :class:`boto3.Session`
73+
74+
Parameters
75+
----------
76+
info : :class:`RedshiftProperty`
77+
The :class:`RedshiftProperty` object created using end-user specified values passed to :func:`~redshift_connector.connect`
78+
"""
79+
self.access_key_id = info.access_key_id
80+
self.secret_access_key = info.secret_access_key
81+
self.session_token = info.session_token
82+
self.profile = info.profile
83+
84+
def refresh(self: "AWSCredentialsProvider") -> None:
85+
"""
86+
Establishes a :class:`boto3.Session` using end-user specified AWS credential information
87+
"""
88+
import boto3 # type: ignore
89+
90+
args: typing.Dict[str, str] = {}
91+
92+
if self.profile is not None:
93+
args["profile_name"] = self.profile
94+
elif self.access_key_id is not None:
95+
args["aws_access_key_id"] = self.access_key_id
96+
args["aws_secret_access_key"] = typing.cast(str, self.secret_access_key)
97+
if self.session_token is not None:
98+
args["aws_session_token"] = self.session_token
99+
100+
session: boto3.Session = boto3.Session(**args)
101+
credentials: typing.Optional[typing.Union[AWSProfileCredentialsHolder, AWSDirectCredentialsHolder]] = None
102+
103+
if self.profile is not None:
104+
credentials = AWSProfileCredentialsHolder(profile=self.profile, session=session)
105+
else:
106+
credentials = AWSDirectCredentialsHolder(
107+
access_key_id=typing.cast(str, self.access_key_id),
108+
secret_access_key=typing.cast(str, self.secret_access_key),
109+
session_token=self.session_token,
110+
session=session,
111+
)
112+
113+
key = self.get_cache_key()
114+
self.cache[key] = credentials

redshift_connector/credentials_holder.py

Lines changed: 109 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,105 @@
11
import datetime
22
import typing
3+
from abc import ABC, abstractmethod
34

5+
if typing.TYPE_CHECKING:
6+
import boto3 # type: ignore
7+
8+
9+
class ABCCredentialsHolder(ABC):
10+
"""
11+
Abstract base class used to store credentials for establishing a connection to an Amazon Redshift cluster.
12+
"""
13+
14+
@abstractmethod
15+
def get_session_credentials(self: "ABCCredentialsHolder"):
16+
"""
17+
A dictionary mapping end-user specified AWS credential value to :func:`boto3.client` parameters.
18+
19+
Returns
20+
_______
21+
A dictionary mapping parameter names to end-user specified values: `typing.Dict[str,str]`
22+
"""
23+
pass
24+
25+
@property
26+
def has_associated_session(self: "ABCCredentialsHolder") -> bool:
27+
"""
28+
A boolean value indicating if the current class stores AWS credentials in a :class:`boto3.Session`.
29+
30+
Returns
31+
-------
32+
`True` if the current class provides a :class:`boto3.Session` object, otherwise `False` : `bool`
33+
"""
34+
return False
35+
36+
37+
class ABCAWSCredentialsHolder(ABC):
38+
"""
39+
Abstract base class used to store AWS credentials provided by user.
40+
"""
41+
42+
def __init__(self: "ABCAWSCredentialsHolder", session: "boto3.Session"):
43+
self.boto_session = session
44+
45+
@property
46+
def has_associated_session(self: "ABCAWSCredentialsHolder") -> bool:
47+
return True
48+
49+
def get_boto_session(self: "ABCAWSCredentialsHolder") -> "boto3.Session":
50+
"""
51+
The :class:`boto3.Session` created using the end-user's AWS Credentials.
52+
Returns
53+
-------
54+
A boto3 session created with the end-user's AWS Credentials: :class:`boto3.Session`
55+
"""
56+
return self.boto_session
57+
58+
59+
class AWSDirectCredentialsHolder(ABCAWSCredentialsHolder):
60+
"""
61+
Credential class used to store AWS credentials provided in :func:`~redshift_connector.connect`.
62+
"""
63+
64+
def __init__(
65+
self, access_key_id: str, secret_access_key: str, session_token: typing.Optional[str], session: "boto3.Session"
66+
):
67+
super().__init__(session)
68+
self.access_key_id: str = access_key_id
69+
self.secret_access_key: str = secret_access_key
70+
self.session_token: typing.Optional[str] = session_token
71+
self._session: "boto3.Session" = session
72+
73+
def get_session_credentials(self: "AWSDirectCredentialsHolder") -> typing.Dict[str, str]:
74+
creds: typing.Dict[str, str] = {
75+
"aws_access_key_id": self.access_key_id,
76+
"aws_secret_access_key": self.secret_access_key,
77+
}
78+
79+
if self.session_token is not None:
80+
creds["aws_session_token"] = self.session_token
81+
82+
return creds
83+
84+
85+
class AWSProfileCredentialsHolder(ABCAWSCredentialsHolder):
86+
"""
87+
Credential class used to store AWS Credentials provided in environment IAM credentials.
88+
"""
89+
90+
def __init__(self, profile: str, session: "boto3.Session"):
91+
super().__init__(session)
92+
self.profile = profile
93+
94+
def get_session_credentials(self: "AWSProfileCredentialsHolder") -> typing.Dict[str, str]:
95+
return {"profile": self.profile}
96+
97+
98+
class CredentialsHolder(ABCCredentialsHolder):
99+
"""
100+
credentials class used to store credentials and metadata from SAML assertion.
101+
"""
4102

5-
# credentials class used to store credentials
6-
# and metadata from SAML assertion
7-
class CredentialsHolder:
8103
def __init__(self: "CredentialsHolder", credentials: typing.Dict[str, typing.Any]) -> None:
9104
self.metadata: "CredentialsHolder.IamMetadata" = CredentialsHolder.IamMetadata()
10105
self.credentials: typing.Dict[str, typing.Any] = credentials
@@ -28,15 +123,25 @@ def get_aws_secret_key(self: "CredentialsHolder") -> str:
28123
def get_session_token(self: "CredentialsHolder") -> str:
29124
return typing.cast(str, self.credentials["SessionToken"])
30125

126+
def get_session_credentials(self: "CredentialsHolder") -> typing.Dict[str, str]:
127+
return {
128+
"aws_access_key_id": self.get_aws_access_key_id(),
129+
"aws_secret_access_key": self.get_aws_secret_key(),
130+
"aws_session_token": self.get_session_token(),
131+
}
132+
31133
# The date on which the current credentials expire.
32134
def get_expiration(self: "CredentialsHolder") -> datetime.datetime:
33135
return self.expiration
34136

35137
def is_expired(self: "CredentialsHolder") -> bool:
36138
return datetime.datetime.now() > self.expiration.replace(tzinfo=None)
37139

38-
# metadata used to store information from SAML assertion
39140
class IamMetadata:
141+
"""
142+
Metadata used to store information from SAML assertion
143+
"""
144+
40145
def __init__(self: "CredentialsHolder.IamMetadata") -> None:
41146
self.auto_create: bool = False
42147
self.db_user: typing.Optional[str] = None

0 commit comments

Comments
 (0)