1717from scramp import ScramClient # type: ignore
1818
1919from redshift_connector .config import (
20+ DEFAULT_PROTOCOL_VERSION ,
21+ ClientProtocolVersion ,
2022 _client_encoding ,
2123 max_int2 ,
2224 max_int4 ,
99101
100102__author__ = "Mathieu Fenniak"
101103
102-
103104ZERO : Timedelta = Timedelta (0 )
104105BINARY : type = bytes
105106
@@ -321,7 +322,6 @@ def create_message(code: bytes, data: bytes = b"") -> bytes:
321322IDLE_IN_TRANSACTION : bytes = b"T"
322323IDLE_IN_FAILED_TRANSACTION : bytes = b"E"
323324
324-
325325arr_trans : typing .Mapping [int , typing .Optional [str ]] = dict (zip (map (ord , "[] 'u" ), ["{" , "}" , None , None , None ]))
326326
327327
@@ -363,6 +363,7 @@ def __init__(
363363 tcp_keepalive : typing .Optional [bool ] = True ,
364364 application_name : typing .Optional [str ] = None ,
365365 replication : typing .Optional [str ] = None ,
366+ client_protocol_version : int = DEFAULT_PROTOCOL_VERSION ,
366367 ):
367368
368369 self .merge_socket_read = False
@@ -382,6 +383,7 @@ def __init__(
382383 self .parameter_statuses : deque = deque (maxlen = 100 )
383384 self .max_prepared_statements : int = int (max_prepared_statements )
384385 self ._run_cursor : Cursor = Cursor (self , paramstyle = "named" )
386+ self ._client_protocol_version : int = client_protocol_version
385387
386388 if user is None :
387389 raise InterfaceError ("The 'user' connection parameter cannot be None" )
@@ -391,6 +393,7 @@ def __init__(
391393 "database" : database ,
392394 "application_name" : application_name ,
393395 "replication" : replication ,
396+ "client_protocol_version" : str (self ._client_protocol_version ),
394397 }
395398
396399 for k , v in tuple (init_params .items ()):
@@ -551,6 +554,15 @@ def __init__(
551554 if self .error is not None :
552555 raise self .error
553556
557+ # if we didn't receive a server_protocol_version from the server, default to
558+ # using BASE_SERVER as the server is likely lacking this functionality due to
559+ # being out of date
560+ if (
561+ self ._client_protocol_version > ClientProtocolVersion .BASE_SERVER
562+ and not (b"server_protocol_version" , str (self ._client_protocol_version ).encode ()) in self .parameter_statuses
563+ ):
564+ self ._client_protocol_version = ClientProtocolVersion .BASE_SERVER
565+
554566 self .in_transaction = False
555567
556568 def handle_ERROR_RESPONSE (self : "Connection" , data , ps ):
@@ -845,24 +857,37 @@ def make_params(self: "Connection", values):
845857 # get the metadata of each row in database
846858 # and store these metadata into ps dictionary
847859 def handle_ROW_DESCRIPTION (self : "Connection" , data , cursor : Cursor ) -> None :
860+ if cursor .ps is None :
861+ raise InterfaceError ("Cursor is missing prepared statement" )
862+ elif "row_desc" not in cursor .ps :
863+ raise InterfaceError ("Prepared Statement is missing row description" )
864+
848865 count : int = h_unpack (data )[0 ]
849866 idx = 2
850867 for i in range (count ):
851- name = data [idx : data .find (NULL_BYTE , idx )]
852- idx += len (name ) + 1
868+ column_label = data [idx : data .find (NULL_BYTE , idx )]
869+ idx += len (column_label ) + 1
870+
853871 field : typing .Dict = dict (
854872 zip (
855873 ("table_oid" , "column_attrnum" , "type_oid" , "type_size" , "type_modifier" , "format" ),
856874 ihihih_unpack (data , idx ),
857875 )
858876 )
859- field ["name " ] = name
877+ field ["label " ] = column_label
860878 idx += 18
861879
862- if cursor .ps is None :
863- raise InterfaceError ("Cursor is missing prepared statement" )
864- elif "row_desc" not in cursor .ps :
865- raise InterfaceError ("Prepared Statement is missing row description" )
880+ if self ._client_protocol_version >= ClientProtocolVersion .EXTENDED_RESULT_METADATA :
881+ for entry in ("schema_name" , "table_name" , "column_name" , "catalog_name" ):
882+ field [entry ] = data [idx : data .find (NULL_BYTE , idx )]
883+ idx += len (field [entry ]) + 1
884+
885+ temp : int = h_unpack (data , idx )[0 ]
886+ field ["nullable" ] = temp & 0x1
887+ field ["autoincrement" ] = (temp >> 4 ) & 0x1
888+ field ["read_only" ] = (temp >> 8 ) & 0x1
889+ field ["searchable" ] = (temp >> 12 ) & 0x1
890+ idx += 2
866891
867892 cursor .ps ["row_desc" ].append (field )
868893 field ["pg8000_fc" ], field ["func" ] = pg_types [field ["type_oid" ]]
@@ -1178,6 +1203,18 @@ def handle_PARAMETER_STATUS(self: "Connection", data: bytes, ps) -> None:
11781203 if key == b"client_encoding" :
11791204 encoding = value .decode ("ascii" ).lower ()
11801205 _client_encoding = pg_to_py_encodings .get (encoding , encoding )
1206+ elif key == b"server_protocol_version" :
1207+ # when a mismatch occurs between the client's requested protocol version, and the server's response,
1208+ # warn the user and follow server
1209+ if self ._client_protocol_version != int (value ):
1210+ warn (
1211+ "Server indicated {} transfer protocol will be used rather than protocol requested by client: {}" .format (
1212+ ClientProtocolVersion .get_name (int (value )),
1213+ ClientProtocolVersion .get_name (self ._client_protocol_version ),
1214+ ),
1215+ stacklevel = 3 ,
1216+ )
1217+ self ._client_protocol_version = int (value )
11811218 elif key == b"server_version" :
11821219 self ._server_version : LooseVersion = LooseVersion (value .decode ("ascii" ))
11831220 if self ._server_version < LooseVersion ("8.2.0" ):
0 commit comments