Skip to content

Commit d95eccf

Browse files
committed
feat(Connection): add initializer parameter for server protocol version, default to extended metadata
1 parent 49b3e99 commit d95eccf

File tree

9 files changed: 217 additions and 10 deletions.

redshift_connector/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import typing
44

5+
from redshift_connector.config import DEFAULT_PROTOCOL_VERSION
56
from redshift_connector.core import BINARY, Connection, Cursor
67
from redshift_connector.error import (
78
ArrayContentNotHomogenousError,
@@ -110,6 +111,7 @@ def connect(
110111
db_groups: typing.List[str] = list(),
111112
force_lowercase: bool = False,
112113
allow_db_user_override: bool = False,
114+
client_protocol_version: int = DEFAULT_PROTOCOL_VERSION,
113115
log_level: int = 0,
114116
log_path: str = log_path,
115117
) -> Connection:
@@ -163,6 +165,7 @@ def connect(
163165
db_groups=db_groups,
164166
force_lowercase=force_lowercase,
165167
allow_db_user_override=allow_db_user_override,
168+
client_protocol_version=client_protocol_version,
166169
)
167170

168171
return Connection(
@@ -180,6 +183,7 @@ def connect(
180183
tcp_keepalive=info.tcp_keepalive,
181184
application_name=info.application_name,
182185
replication=info.replication,
186+
client_protocol_version=info.client_protocol_version,
183187
)
184188

185189

redshift_connector/config.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,32 @@
33
from calendar import timegm
44
from datetime import datetime as Datetime
55
from datetime import timezone as Timezone
6+
from enum import IntEnum
67

78
FC_TEXT: int = 0
89
FC_BINARY: int = 1
910
_client_encoding: str = "utf8"
1011

12+
13+
class ClientProtocolVersion(IntEnum):
    """Transfer protocol versions a client may request for a connection.

    Members are integers, so they compare directly against the plain ``int``
    values used for the ``client_protocol_version`` setting.
    """

    BASE_SERVER = 0
    EXTENDED_RESULT_METADATA = 1
    BINARY = 2

    @classmethod
    def list(cls) -> typing.List[int]:
        """Return the integer value of every member, in definition order."""
        return [member.value for member in cls]

    @classmethod
    def get_name(cls, i: int) -> str:
        """Return the member name for ``i``.

        Falls back to ``str(i)`` when no member has that value, so callers can
        always render a version number in a message without handling errors.
        """
        try:
            return cls(i).name
        except ValueError:
            return str(i)
28+
29+
30+
DEFAULT_PROTOCOL_VERSION: int = ClientProtocolVersion.EXTENDED_RESULT_METADATA.value
31+
1132
min_int2: int = -(2 ** 15)
1233
max_int2: int = 2 ** 15
1334
min_int4: int = -(2 ** 31)

redshift_connector/core.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from scramp import ScramClient # type: ignore
1818

1919
from redshift_connector.config import (
20+
DEFAULT_PROTOCOL_VERSION,
21+
ClientProtocolVersion,
2022
_client_encoding,
2123
max_int2,
2224
max_int4,
@@ -99,7 +101,6 @@
99101

100102
__author__ = "Mathieu Fenniak"
101103

102-
103104
ZERO: Timedelta = Timedelta(0)
104105
BINARY: type = bytes
105106

@@ -321,7 +322,6 @@ def create_message(code: bytes, data: bytes = b"") -> bytes:
321322
IDLE_IN_TRANSACTION: bytes = b"T"
322323
IDLE_IN_FAILED_TRANSACTION: bytes = b"E"
323324

324-
325325
arr_trans: typing.Mapping[int, typing.Optional[str]] = dict(zip(map(ord, "[] 'u"), ["{", "}", None, None, None]))
326326

327327

@@ -363,6 +363,7 @@ def __init__(
363363
tcp_keepalive: typing.Optional[bool] = True,
364364
application_name: typing.Optional[str] = None,
365365
replication: typing.Optional[str] = None,
366+
client_protocol_version: int = DEFAULT_PROTOCOL_VERSION,
366367
):
367368

368369
self.merge_socket_read = False
@@ -382,6 +383,7 @@ def __init__(
382383
self.parameter_statuses: deque = deque(maxlen=100)
383384
self.max_prepared_statements: int = int(max_prepared_statements)
384385
self._run_cursor: Cursor = Cursor(self, paramstyle="named")
386+
self._client_protocol_version: int = client_protocol_version
385387

386388
if user is None:
387389
raise InterfaceError("The 'user' connection parameter cannot be None")
@@ -391,6 +393,7 @@ def __init__(
391393
"database": database,
392394
"application_name": application_name,
393395
"replication": replication,
396+
"client_protocol_version": str(self._client_protocol_version),
394397
}
395398

396399
for k, v in tuple(init_params.items()):
@@ -551,6 +554,15 @@ def __init__(
551554
if self.error is not None:
552555
raise self.error
553556

557+
# if we didn't receive a server_protocol_version from the server, default to
558+
# using BASE_SERVER as the server is likely lacking this functionality due to
559+
# being out of date
560+
if (
561+
self._client_protocol_version > ClientProtocolVersion.BASE_SERVER
562+
and not (b"server_protocol_version", str(self._client_protocol_version).encode()) in self.parameter_statuses
563+
):
564+
self._client_protocol_version = ClientProtocolVersion.BASE_SERVER
565+
554566
self.in_transaction = False
555567

556568
def handle_ERROR_RESPONSE(self: "Connection", data, ps):
@@ -845,24 +857,37 @@ def make_params(self: "Connection", values):
845857
# get the metadata of each row in database
846858
# and store these metadata into ps dictionary
847859
def handle_ROW_DESCRIPTION(self: "Connection", data, cursor: Cursor) -> None:
860+
if cursor.ps is None:
861+
raise InterfaceError("Cursor is missing prepared statement")
862+
elif "row_desc" not in cursor.ps:
863+
raise InterfaceError("Prepared Statement is missing row description")
864+
848865
count: int = h_unpack(data)[0]
849866
idx = 2
850867
for i in range(count):
851-
name = data[idx : data.find(NULL_BYTE, idx)]
852-
idx += len(name) + 1
868+
column_label = data[idx : data.find(NULL_BYTE, idx)]
869+
idx += len(column_label) + 1
870+
853871
field: typing.Dict = dict(
854872
zip(
855873
("table_oid", "column_attrnum", "type_oid", "type_size", "type_modifier", "format"),
856874
ihihih_unpack(data, idx),
857875
)
858876
)
859-
field["name"] = name
877+
field["label"] = column_label
860878
idx += 18
861879

862-
if cursor.ps is None:
863-
raise InterfaceError("Cursor is missing prepared statement")
864-
elif "row_desc" not in cursor.ps:
865-
raise InterfaceError("Prepared Statement is missing row description")
880+
if self._client_protocol_version >= ClientProtocolVersion.EXTENDED_RESULT_METADATA:
881+
for entry in ("schema_name", "table_name", "column_name", "catalog_name"):
882+
field[entry] = data[idx : data.find(NULL_BYTE, idx)]
883+
idx += len(field[entry]) + 1
884+
885+
temp: int = h_unpack(data, idx)[0]
886+
field["nullable"] = temp & 0x1
887+
field["autoincrement"] = (temp >> 4) & 0x1
888+
field["read_only"] = (temp >> 8) & 0x1
889+
field["searchable"] = (temp >> 12) & 0x1
890+
idx += 2
866891

867892
cursor.ps["row_desc"].append(field)
868893
field["pg8000_fc"], field["func"] = pg_types[field["type_oid"]]
@@ -1178,6 +1203,18 @@ def handle_PARAMETER_STATUS(self: "Connection", data: bytes, ps) -> None:
11781203
if key == b"client_encoding":
11791204
encoding = value.decode("ascii").lower()
11801205
_client_encoding = pg_to_py_encodings.get(encoding, encoding)
1206+
elif key == b"server_protocol_version":
1207+
# when a mismatch occurs between the client's requested protocol version, and the server's response,
1208+
# warn the user and follow server
1209+
if self._client_protocol_version != int(value):
1210+
warn(
1211+
"Server indicated {} transfer protocol will be used rather than protocol requested by client: {}".format(
1212+
ClientProtocolVersion.get_name(int(value)),
1213+
ClientProtocolVersion.get_name(self._client_protocol_version),
1214+
),
1215+
stacklevel=3,
1216+
)
1217+
self._client_protocol_version = int(value)
11811218
elif key == b"server_version":
11821219
self._server_version: LooseVersion = LooseVersion(value.decode("ascii"))
11831220
if self._server_version < LooseVersion("8.2.0"):

redshift_connector/cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def _getDescription(self: "Cursor") -> typing.Optional[typing.List[typing.Option
112112
return None
113113
columns: typing.List[typing.Optional[typing.Tuple]] = []
114114
for col in row_desc:
115-
columns.append((col["name"], col["type_oid"], None, None, None, None, None))
115+
columns.append((col["label"], col["type_oid"], None, None, None, None, None))
116116
return columns
117117

118118
##

redshift_connector/iam_helper.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import boto3 # type: ignore
66
import botocore # type: ignore
77

8+
from redshift_connector.config import ClientProtocolVersion
89
from redshift_connector.credentials_holder import CredentialsHolder
910
from redshift_connector.error import InterfaceError
1011
from redshift_connector.plugin import (
@@ -65,6 +66,7 @@ def set_iam_properties(
6566
db_groups: typing.List[str],
6667
force_lowercase: bool,
6768
allow_db_user_override: bool,
69+
client_protocol_version: int,
6870
) -> None:
6971
if info is None:
7072
raise InterfaceError("Invalid connection property setting. info must be specified")
@@ -108,6 +110,13 @@ def set_iam_properties(
108110
if password is None:
109111
raise InterfaceError("Invalid connection property setting. password must be specified")
110112

113+
if client_protocol_version not in ClientProtocolVersion.list():
114+
raise InterfaceError(
115+
"Invalid connection property setting. client_protocol_version must be in: {}".format(
116+
ClientProtocolVersion.list()
117+
)
118+
)
119+
111120
# basic driver parameters
112121
info.user_name = user
113122
info.host = host
@@ -121,6 +130,7 @@ def set_iam_properties(
121130
info.tcp_keepalive = tcp_keepalive
122131
info.application_name = application_name
123132
info.replication = replication
133+
info.client_protocol_version = client_protocol_version
124134

125135
# Idp parameters
126136
info.idp_host = idp_host

redshift_connector/redshift_property.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import typing
22
from typing import TYPE_CHECKING
33

4+
from redshift_connector.config import DEFAULT_PROTOCOL_VERSION
5+
46
if TYPE_CHECKING:
57
from redshift_connector.iam_helper import SSLMode
68

@@ -71,6 +73,8 @@ class RedshiftProperty:
7173
max_prepared_statements: int = 1000
7274
# Use this property to enable or disable TCP keepalives. The following values are possible:
7375
tcp_keepalive: bool = True
76+
# client's requested transfer protocol version. See config.py for supported protocols
77+
client_protocol_version: int = DEFAULT_PROTOCOL_VERSION
7478
# application name
7579
application_name: typing.Optional[str] = None
7680
# Used to run in streaming replication mode. If your server character encoding is not ascii or utf8,

test/integration/test_connection.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest # type: ignore
88

99
import redshift_connector
10+
from redshift_connector.config import ClientProtocolVersion
1011

1112
conf = configparser.ConfigParser()
1213
root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -181,3 +182,18 @@ def test_scram_sha_256(db_kwargs):
181182
# Should only raise an exception saying db doesn't exist
182183
with pytest.raises(redshift_connector.ProgrammingError, match="3D000"):
183184
redshift_connector.connect(**db_kwargs)
185+
186+
187+
@pytest.mark.parametrize("_input", ClientProtocolVersion.list()[:-1])
def test_client_protocol_version_is_used(db_kwargs, _input):
    """Connecting with a given protocol version leaves the connection on that version.

    The last entry of ``ClientProtocolVersion.list()`` is excluded from the
    parametrization.
    """
    db_kwargs.update(client_protocol_version=_input)

    with redshift_connector.connect(**db_kwargs) as connection:
        negotiated_version = connection._client_protocol_version
        assert negotiated_version == _input
193+
194+
195+
def test_client_protocol_version_invalid_warns_user(db_kwargs):
    """Requesting a protocol version above every known one emits a ``UserWarning``.

    ``Connection`` is constructed directly rather than through ``connect()``.
    """
    db_kwargs["client_protocol_version"] = max(ClientProtocolVersion.list()) + 1

    conn = None
    try:
        with pytest.warns(UserWarning):
            conn = redshift_connector.Connection(**db_kwargs)
    finally:
        # Release the connection even when the expected warning is missing,
        # so a failing assertion does not leak a socket into later tests.
        if conn is not None:
            conn.close()

0 commit comments

Comments
 (0)