Skip to content

Commit df7fc1d

Browse files
committed
fix(cursor, execute): bind param parsing for multiline comments, colon character issue
1 parent f9af45b commit df7fc1d

File tree

7 files changed

+158
-59
lines changed

7 files changed

+158
-59
lines changed

redshift_connector/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
import typing
33

44
from redshift_connector import plugin
5-
from redshift_connector.config import DEFAULT_PROTOCOL_VERSION, ClientProtocolVersion
5+
from redshift_connector.config import (
6+
DEFAULT_PROTOCOL_VERSION,
7+
ClientProtocolVersion,
8+
DbApiParamstyle,
9+
)
610
from redshift_connector.core import BINARY, Connection, Cursor
711
from redshift_connector.error import (
812
ArrayContentNotHomogenousError,

redshift_connector/config.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from calendar import timegm
44
from datetime import datetime as Datetime
55
from datetime import timezone as Timezone
6-
from enum import IntEnum
6+
from enum import Enum, IntEnum
77

88
FC_TEXT: int = 0
99
FC_BINARY: int = 1
@@ -29,6 +29,19 @@ def get_name(cls, i: int) -> str:
2929

3030
DEFAULT_PROTOCOL_VERSION: int = ClientProtocolVersion.BINARY.value
3131

32+
33+
class DbApiParamstyle(Enum):
    """Bind-parameter marker styles, keyed by their DB-API ``paramstyle`` name.

    Each member's ``value`` is the plain string form of the style (for
    example ``"qmark"``), so callers can compare it directly against a
    string ``paramstyle`` setting.
    """

    QMARK = "qmark"  # e.g. ``values (?)``
    NUMERIC = "numeric"  # e.g. ``values (:1)``
    NAMED = "named"  # e.g. ``values (:beer)``
    FORMAT = "format"  # e.g. ``values (%s)``
    PYFORMAT = "pyformat"  # e.g. ``values (%(beer)s)``

    @classmethod
    def list(cls) -> typing.List[str]:
        """Return the string value of every paramstyle, in definition order.

        The method name shadows the ``list`` builtin only inside the class
        namespace; it is kept unchanged because callers already use
        ``DbApiParamstyle.list()``.
        """
        return [style.value for style in cls]
43+
44+
3245
min_int2: int = -(2 ** 15)
3346
max_int2: int = 2 ** 15
3447
min_int4: int = -(2 ** 31)

redshift_connector/core.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from redshift_connector.config import (
2121
DEFAULT_PROTOCOL_VERSION,
2222
ClientProtocolVersion,
23+
DbApiParamstyle,
2324
_client_encoding,
2425
max_int2,
2526
max_int4,
@@ -145,6 +146,7 @@ def convert_paramstyle(style: str, query) -> typing.Tuple[str, typing.Any]:
145146
INSIDE_ES: int = 3 # inside escaped single-quote string, E'...'
146147
INSIDE_PN: int = 4 # inside parameter name eg. :name
147148
INSIDE_CO: int = 5 # inside inline comment eg. --
149+
INSIDE_MC: int = 6 # inside multiline comment eg. /*
148150

149151
in_quote_escape: bool = False
150152
in_param_escape: bool = False
@@ -173,23 +175,27 @@ def convert_paramstyle(style: str, query) -> typing.Tuple[str, typing.Any]:
173175
output_query.append(c)
174176
if prev_c == "-":
175177
state = INSIDE_CO
176-
elif style == "qmark" and c == "?":
178+
elif c == "*":
179+
output_query.append(c)
180+
if prev_c == "/":
181+
state = INSIDE_MC
182+
elif style == DbApiParamstyle.QMARK.value and c == "?":
177183
output_query.append(next(param_idx))
178-
elif style == "numeric" and c == ":" and next_c not in ":=" and prev_c != ":":
184+
elif style == DbApiParamstyle.NUMERIC.value and c == ":" and next_c not in ":=" and prev_c != ":":
179185
# Treat : as beginning of parameter name if and only
180186
# if it's the only : around
181187
# Needed to properly process type conversions
182188
# i.e. sum(x)::float
183189
output_query.append("$")
184-
elif style == "named" and c == ":" and next_c not in ":=" and prev_c != ":":
190+
elif style == DbApiParamstyle.NAMED.value and c == ":" and next_c not in ":=" and prev_c != ":":
185191
# Same logic for : as in numeric parameters
186192
state = INSIDE_PN
187193
placeholders.append("")
188-
elif style == "pyformat" and c == "%" and next_c == "(":
194+
elif style == DbApiParamstyle.PYFORMAT.value and c == "%" and next_c == "(":
189195
state = INSIDE_PN
190196
placeholders.append("")
191-
elif style in ("format", "pyformat") and c == "%":
192-
style = "format"
197+
elif style in (DbApiParamstyle.FORMAT.value, DbApiParamstyle.PYFORMAT.value) and c == "%":
198+
style = DbApiParamstyle.FORMAT.value
193199
if in_param_escape:
194200
in_param_escape = False
195201
output_query.append(c)
@@ -227,7 +233,7 @@ def convert_paramstyle(style: str, query) -> typing.Tuple[str, typing.Any]:
227233
output_query.append(c)
228234

229235
elif state == INSIDE_PN:
230-
if style == "named":
236+
if style == DbApiParamstyle.NAMED.value:
231237
placeholders[-1] += c
232238
if next_c is None or (not next_c.isalnum() and next_c != "_"):
233239
state = OUTSIDE
@@ -237,7 +243,7 @@ def convert_paramstyle(style: str, query) -> typing.Tuple[str, typing.Any]:
237243
del placeholders[-1]
238244
except ValueError:
239245
output_query.append("$" + str(len(placeholders)))
240-
elif style == "pyformat":
246+
elif style == DbApiParamstyle.PYFORMAT.value:
241247
if prev_c == ")" and c == "s":
242248
state = OUTSIDE
243249
try:
@@ -250,17 +256,22 @@ def convert_paramstyle(style: str, query) -> typing.Tuple[str, typing.Any]:
250256
pass
251257
else:
252258
placeholders[-1] += c
253-
elif style == "format":
259+
elif style == DbApiParamstyle.FORMAT.value:
254260
state = OUTSIDE
255261

256262
elif state == INSIDE_CO:
257263
output_query.append(c)
258264
if c == "\n":
259265
state = OUTSIDE
260266

267+
elif state == INSIDE_MC:
268+
output_query.append(c)
269+
if c == "/" and prev_c == "*":
270+
state = OUTSIDE
271+
261272
prev_c = c
262273

263-
if style in ("numeric", "qmark", "format"):
274+
if style in (DbApiParamstyle.NUMERIC.value, DbApiParamstyle.QMARK.value, DbApiParamstyle.FORMAT.value):
264275

265276
def make_args(vals):
266277
return vals
@@ -466,7 +477,7 @@ def __init__(
466477
self.notices: deque = deque(maxlen=100)
467478
self.parameter_statuses: deque = deque(maxlen=100)
468479
self.max_prepared_statements: int = int(max_prepared_statements)
469-
self._run_cursor: Cursor = Cursor(self, paramstyle="named")
480+
self._run_cursor: Cursor = Cursor(self, paramstyle=DbApiParamstyle.NAMED.value)
470481
self._client_protocol_version: int = client_protocol_version
471482
self._database = database
472483
self.py_types = deepcopy(PY_TYPES)
@@ -1600,7 +1611,7 @@ def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None:
16001611
args: typing.Tuple[typing.Optional[typing.Tuple[str, typing.Any]], ...] = ()
16011612
# transforms user provided bind parameters to server friendly bind parameters
16021613
params: typing.Tuple[typing.Optional[typing.Tuple[int, int, typing.Callable]], ...] = ()
1603-
1614+
has_bind_parameters: bool = False if vals is None else True
16041615
# multi dimensional dictionary to store the data
16051616
# cache = self._caches[cursor.paramstyle][pid]
16061617
# cache = {'statement': {}, 'ps': {}}
@@ -1622,12 +1633,12 @@ def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None:
16221633
try:
16231634
statement, make_args = cache["statement"][operation]
16241635
except KeyError:
1625-
if vals:
1636+
if has_bind_parameters:
16261637
statement, make_args = cache["statement"][operation] = convert_paramstyle(cursor.paramstyle, operation)
16271638
else:
16281639
# use a no-op make_args in lieu of parsing the sql statement
16291640
statement, make_args = cache["statement"][operation] = operation, lambda p: ()
1630-
if vals:
1641+
if has_bind_parameters:
16311642
args = make_args(vals)
16321643
# change the args to the format that the DB will identify
16331644
# take reference from self.py_types

redshift_connector/cursor.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import redshift_connector
1111
from redshift_connector.config import (
1212
ClientProtocolVersion,
13+
DbApiParamstyle,
1314
_client_encoding,
1415
table_type_clauses,
1516
)
@@ -355,7 +356,7 @@ def __has_valid_columns(self: "Cursor", table: str, columns: typing.List[str]) -
355356
else:
356357
param_list = [[split_table_name[0], c] for c in columns]
357358
temp = self.paramstyle
358-
self.paramstyle = "qmark"
359+
self.paramstyle = DbApiParamstyle.QMARK.value
359360
try:
360361
for params in param_list:
361362
self.execute(q, params)
@@ -376,7 +377,7 @@ def callproc(self, procname, parameters=None):
376377
from redshift_connector.core import convert_paramstyle
377378

378379
try:
379-
statement, make_args = convert_paramstyle("format", operation)
380+
statement, make_args = convert_paramstyle(DbApiParamstyle.FORMAT.value, operation)
380381
vals = make_args(args)
381382
self.execute(statement, vals)
382383

@@ -534,7 +535,7 @@ def __is_valid_table(self: "Cursor", table: str) -> bool:
534535
q: str = "select 1 from information_schema.tables where table_name = ?"
535536

536537
temp = self.paramstyle
537-
self.paramstyle = "qmark"
538+
self.paramstyle = DbApiParamstyle.QMARK.value
538539

539540
try:
540541
if len(split_table_name) == 2:
@@ -643,7 +644,7 @@ def get_procedures(
643644
if len(query_args) > 0:
644645
# temporarily use qmark paramstyle
645646
temp = self.paramstyle
646-
self.paramstyle = "qmark"
647+
self.paramstyle = DbApiParamstyle.QMARK.value
647648

648649
try:
649650
self.execute(sql, tuple(query_args))
@@ -721,7 +722,7 @@ def get_schemas(
721722
if len(query_args) == 1:
722723
# temporarily use qmark paramstyle
723724
temp = self.paramstyle
724-
self.paramstyle = "qmark"
725+
self.paramstyle = DbApiParamstyle.QMARK.value
725726
try:
726727
self.execute(sql, tuple(query_args))
727728
except:
@@ -774,7 +775,7 @@ def get_primary_keys(
774775
if len(query_args) > 0:
775776
# temporarily use qmark paramstyle
776777
temp = self.paramstyle
777-
self.paramstyle = "qmark"
778+
self.paramstyle = DbApiParamstyle.QMARK.value
778779
try:
779780
self.execute(sql, tuple(query_args))
780781
except:
@@ -855,7 +856,7 @@ def get_tables(
855856

856857
if len(sql_args) > 0:
857858
temp = self.paramstyle
858-
self.paramstyle = "qmark"
859+
self.paramstyle = DbApiParamstyle.QMARK.value
859860
try:
860861
self.execute(sql, sql_args)
861862
except:

test/integration/test_dbapi20.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -184,15 +184,15 @@ def _paraminsert(cur):
184184
cur.execute("insert into %sbooze values ('Victoria Bitter')" % (table_prefix))
185185
assert cur.rowcount in (-1, 1)
186186

187-
if driver.paramstyle == "qmark":
187+
if driver.paramstyle == redshift_connector.config.DbApiParamstyle.QMARK.value:
188188
cur.execute("insert into %sbooze values (?)" % table_prefix, ("Cooper's",))
189-
elif driver.paramstyle == "numeric":
189+
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.NUMERIC.value:
190190
cur.execute("insert into %sbooze values (:1)" % table_prefix, ("Cooper's",))
191-
elif driver.paramstyle == "named":
191+
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.NAMED.value:
192192
cur.execute("insert into %sbooze values (:beer)" % table_prefix, {"beer": "Cooper's"})
193-
elif driver.paramstyle == "format":
193+
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.FORMAT.value:
194194
cur.execute("insert into %sbooze values (%%s)" % table_prefix, ("Cooper's",))
195-
elif driver.paramstyle == "pyformat":
195+
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.PYFORMAT.value:
196196
cur.execute("insert into %sbooze values (%%(beer)s)" % table_prefix, {"beer": "Cooper's"})
197197
else:
198198
assert False, "Invalid paramstyle"
@@ -212,15 +212,15 @@ def test_executemany(cursor):
212212
execute_ddl_1(cursor)
213213
largs: typing.List[typing.Tuple[str]] = [("Cooper's",), ("Boag's",)]
214214
margs: typing.List[typing.Dict[str, str]] = [{"beer": "Cooper's"}, {"beer": "Boag's"}]
215-
if driver.paramstyle == "qmark":
215+
if driver.paramstyle == redshift_connector.config.DbApiParamstyle.QMARK.value:
216216
cursor.executemany("insert into %sbooze values (?)" % table_prefix, largs)
217-
elif driver.paramstyle == "numeric":
217+
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.NUMERIC.value:
218218
cursor.executemany("insert into %sbooze values (:1)" % table_prefix, largs)
219-
elif driver.paramstyle == "named":
219+
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.NAMED.value:
220220
cursor.executemany("insert into %sbooze values (:beer)" % table_prefix, margs)
221-
elif driver.paramstyle == "format":
221+
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.FORMAT.value:
222222
cursor.executemany("insert into %sbooze values (%%s)" % table_prefix, largs)
223-
elif driver.paramstyle == "pyformat":
223+
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.PYFORMAT.value:
224224
cursor.executemany("insert into %sbooze values (%%(beer)s)" % (table_prefix), margs)
225225
else:
226226
assert False, "Unknown paramstyle"

test/unit/test_dbapi20.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@ def test_threadsafety():
2424

2525

2626
def test_paramstyle():
27+
from redshift_connector.config import DbApiParamstyle
28+
2729
try:
2830
# Must exist
2931
paramstyle: str = driver.paramstyle
3032
# Must be a valid value
31-
assert paramstyle in ("qmark", "numeric", "named", "format", "pyformat")
33+
assert paramstyle in DbApiParamstyle.list()
3234
except AttributeError:
3335
assert False, "Driver doesn't define paramstyle"
3436

0 commit comments

Comments
 (0)