Skip to content

Commit 1eaac8d

Browse files
committed
feat(auth): Support Redshift custom domain name
1 parent 9a8c094 commit 1eaac8d

File tree

10 files changed

+592
-130
lines changed

10 files changed

+592
-130
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ repos:
1111
- id: isort
1212
args: ["--profile", "black", "."]
1313
- repo: https://github.com/ambv/black
14-
rev: 20.8b1
14+
rev: 22.3.0
1515
hooks:
1616
- id: black
1717
- repo: https://github.com/pre-commit/mirrors-mypy

redshift_connector/core.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,20 @@ def client_os_version(self: "Connection") -> str:
394394
os_version = "unknown"
395395
return os_version
396396

397+
@staticmethod
398+
def __get_host_address_info(host: str, port: int):
399+
"""
400+
Returns IPv4 address and port given a host name and port
401+
"""
402+
# https://docs.python.org/3/library/socket.html#socket.getaddrinfo
403+
response = socket.getaddrinfo(host=host, port=port, family=socket.AF_INET)
404+
_logger.debug("getaddrinfo response {}".format(response))
405+
406+
if not response:
407+
raise InterfaceError("Unable to determine ip for host {} port {}".format(host, port))
408+
409+
return response[0][4]
410+
397411
def __init__(
398412
self: "Connection",
399413
user: str,
@@ -593,7 +607,11 @@ def get_calling_module() -> str:
593607
self._usock.settimeout(timeout)
594608

595609
if unix_sock is None and host is not None:
596-
self._usock.connect((host, port))
610+
hostport: typing.Tuple[str, int] = Connection.__get_host_address_info(host, port)
611+
_logger.debug(
612+
"Attempting to create connection socket with address {} {}".format(hostport[0], str(hostport[1]))
613+
)
614+
self._usock.connect(hostport)
597615
elif unix_sock is not None:
598616
self._usock.connect(unix_sock)
599617

redshift_connector/iam_helper.py

Lines changed: 234 additions & 78 deletions
Large diffs are not rendered by default.

redshift_connector/idp_auth_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def set_auth_properties(info: RedshiftProperty):
108108
)
109109
info.put_all(resp)
110110

111-
if info.cluster_identifier is None and not info._is_serverless:
111+
if info.cluster_identifier is None and not info._is_serverless and not info.is_cname:
112112
raise InterfaceError(
113113
"Invalid connection property setting. cluster_identifier must be provided when IAM is enabled"
114114
)

redshift_connector/redshift_property.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
import logging
12
import typing
23

34
from redshift_connector.config import DEFAULT_PROTOCOL_VERSION
4-
from redshift_connector.error import ProgrammingError
55

66
SERVERLESS_HOST_PATTERN: str = r"(.+)\.(.+).redshift-serverless(-dev)?\.amazonaws\.com(.)*"
77
SERVERLESS_WITH_WORKGROUP_HOST_PATTERN: str = r"(.+)\.(.+)\.(.+).redshift-serverless(-dev)?\.amazonaws\.com(.)*"
88
IAM_URL_PATTERN: str = r"^(https)://[-a-zA-Z0-9+&@#/%?=~_!:,.']*[-a-zA-Z0-9+&@#/%=~_']"
9+
PROVISIONED_HOST_PATTERN: str = r"(.+)\.(.+)\.(.+).redshift(-dev)?\.amazonaws\.com(.)*"
10+
11+
_logger: logging.Logger = logging.getLogger(__name__)
912

1013

1114
class RedshiftProperty:
@@ -118,6 +121,8 @@ def __init__(self: "RedshiftProperty", **kwargs):
118121
self.serverless_acct_id: typing.Optional[str] = None
119122
self.serverless_work_group: typing.Optional[str] = None
120123
self.group_federation: bool = False
124+
# flag indicating if host name and RedshiftProperty indicate Redshift with custom domain name is used
125+
self.is_cname: bool = False
121126

122127
else:
123128
for k, v in kwargs.items():
@@ -162,9 +167,44 @@ def is_serverless_host(self: "RedshiftProperty") -> bool:
162167
)
163168

164169
@property
165-
def _is_serverless(self):
170+
def is_provisioned_host(self: "RedshiftProperty") -> bool:
171+
"""
172+
Returns True if host matches Regex for Redshift provisioned. Otherwise returns False.
173+
"""
174+
if not self.host:
175+
return False
176+
177+
import re
178+
179+
return bool(re.fullmatch(pattern=PROVISIONED_HOST_PATTERN, string=str(self.host)))
180+
181+
def set_is_cname(self: "RedshiftProperty") -> None:
182+
"""
183+
Sets RedshiftProperty is_cname attribute based on RedshiftProperty attribute values and host name Regex matching.
184+
"""
185+
is_cname: bool = False
186+
_logger.debug("determining if host indicates Redshift instance with custom name")
187+
188+
if self.is_provisioned_host:
189+
_logger.debug("cluster identified as Redshift provisioned")
190+
elif self.is_serverless_host:
191+
_logger.debug("cluster identified as Redshift serverless")
192+
elif self.is_serverless:
193+
if self.serverless_work_group is not None:
194+
_logger.debug("cluster identified as Redshift serverless with NLB")
195+
else:
196+
_logger.debug("cluster identified as Redshift serverless with with custom name")
197+
is_cname = True
198+
else:
199+
_logger.debug("cluster identified as Redshift provisioned with with custom name/NLB")
200+
is_cname = True
201+
202+
self.put(key="is_cname", value=is_cname)
203+
204+
@property
205+
def _is_serverless(self: "RedshiftProperty"):
166206
"""
167-
Returns True if host patches serverless pattern or if is_serverless flag set by user
207+
Returns True if host matches serverless pattern or if is_serverless flag set by user. Otherwise returns False.
168208
"""
169209
return self.is_serverless_host or self.is_serverless
170210

test/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,18 @@ def serverless_iam_db_kwargs() -> typing.Dict[str, typing.Union[str, bool]]:
107107
return db_connect # type: ignore
108108

109109

110+
@pytest.fixture(scope="class")
111+
def provisioned_cname_db_kwargs() -> typing.Dict[str, str]:
112+
db_connect = {
113+
"database": conf.get("redshift-provisioned-cname", "database", fallback="mock_database"),
114+
"host": conf.get("redshift-provisioned-cname", "host", fallback="cname.mytest.com"),
115+
"db_user": conf.get("redshift-provisioned-cname", "db_user", fallback="mock_user"),
116+
"password": conf.get("redshift-provisioned-cname", "password", fallback="mock_password"),
117+
}
118+
119+
return db_connect
120+
121+
110122
@pytest.fixture(scope="class")
111123
def okta_idp() -> typing.Dict[str, typing.Union[str, bool, int]]:
112124
db_connect = {

test/integration/test_connection.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,11 @@ def fin():
6666
request.addfinalizer(fin)
6767

6868

69-
def test_socket_missing():
70-
conn_params = {
71-
"unix_sock": "/file-does-not-exist",
72-
"user": "doesn't-matter",
73-
"password": "hunter2",
74-
"database": "myDb",
75-
}
69+
def test_socket_missing(db_kwargs):
70+
db_kwargs["unix_sock"] = "/file-does-not-exist"
7671

7772
with pytest.raises(redshift_connector.InterfaceError):
78-
redshift_connector.connect(**conn_params)
73+
redshift_connector.connect(**db_kwargs)
7974

8075

8176
def test_database_missing(db_kwargs):
@@ -150,7 +145,7 @@ def test_unicode_database_name(db_kwargs):
150145

151146

152147
def test_bytes_database_name(db_kwargs):
153-
""" Should only raise an exception saying db doesn't exist """
148+
"""Should only raise an exception saying db doesn't exist"""
154149

155150
db_kwargs["database"] = bytes("redshift_connector_sn\uFF6Fw", "utf8")
156151
with pytest.raises(redshift_connector.ProgrammingError, match="3D000"):
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import pytest
2+
3+
import redshift_connector
4+
from redshift_connector.idp_auth_helper import SupportedSSLMode
5+
6+
"""
7+
These functional tests ensure connections to Redshift provisioned customer with custom domain name can be established
8+
when using various authentication methods.
9+
10+
Pre-requisites:
11+
1) Redshift provisioned configuration
12+
2) Existing custom domain association with instance created in step 1
13+
"""
14+
15+
16+
# @pytest.mark.skip(reason="manual")
17+
@pytest.mark.parametrize("sslmode", (SupportedSSLMode.VERIFY_CA, SupportedSSLMode.VERIFY_FULL))
18+
def test_native_connect(provisioned_cname_db_kwargs, sslmode):
19+
# this test requires aws default profile contains valid credentials that provide permissions for
20+
# redshift:GetClusterCredentials ( Only called from this test method)
21+
import boto3
22+
23+
profile = "default"
24+
client = boto3.client(
25+
service_name="redshift",
26+
region_name="eu-north-1",
27+
)
28+
# fetch cluster credentials and pass them as driver connect parameters
29+
response = client.get_cluster_credentials(
30+
CustomDomainName=provisioned_cname_db_kwargs["host"], DbUser=provisioned_cname_db_kwargs["db_user"]
31+
)
32+
33+
provisioned_cname_db_kwargs["password"] = response["DbPassword"]
34+
provisioned_cname_db_kwargs["user"] = response["DbUser"]
35+
provisioned_cname_db_kwargs["profile"] = profile
36+
provisioned_cname_db_kwargs["ssl"] = True
37+
provisioned_cname_db_kwargs["sslmode"] = sslmode.value
38+
39+
with redshift_connector.connect(**provisioned_cname_db_kwargs):
40+
pass
41+
42+
43+
# @pytest.mark.skip(reason="manual")
44+
@pytest.mark.parametrize("sslmode", (SupportedSSLMode.VERIFY_CA, SupportedSSLMode.VERIFY_FULL))
45+
def test_iam_connect(provisioned_cname_db_kwargs, sslmode):
46+
# this test requires aws default profile contains valid credentials that provide permissions for
47+
# redshift:GetClusterCredentials (called from driver)
48+
# redshift:DescribeClusters (called from driver)
49+
# redshift:DescribeCustomDomainAssociations (called from driver)
50+
provisioned_cname_db_kwargs["iam"] = True
51+
provisioned_cname_db_kwargs["profile"] = "default"
52+
provisioned_cname_db_kwargs["auto_create"] = True
53+
provisioned_cname_db_kwargs["region"] = "eu-north-1"
54+
provisioned_cname_db_kwargs["ssl"] = True
55+
provisioned_cname_db_kwargs["sslmode"] = sslmode.value
56+
with redshift_connector.connect(**provisioned_cname_db_kwargs):
57+
pass
58+
59+
60+
def test_idp_connect(okta_idp, provisioned_cname_db_kwargs):
61+
# todo
62+
pass
63+
64+
65+
# @pytest.mark.skip(reason="manual")
66+
def test_nlb_connect():
67+
args = {
68+
"iam": True,
69+
# "access_key_id": "xxx",
70+
# "secret_access_key": "xxx",
71+
"cluster_identifier": "replace-me",
72+
"region": "us-east-1",
73+
"host": "replace-me",
74+
"port": 5439,
75+
"database": "dev",
76+
"db_user": "replace-me",
77+
}
78+
with redshift_connector.connect(**args):
79+
pass

0 commit comments

Comments
 (0)