Skip to content

Commit 82835e1

Browse files
committed
feat(cursor): Add redshift_rowcount for SELECT rowcount support
1 parent 45dfe66 commit 82835e1

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

redshift_connector/core.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,6 +1241,7 @@ def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None:
12411241

12421242
cursor._cached_rows.clear()
12431243
cursor._row_count = -1
1244+
cursor._redshift_row_count = -1
12441245

12451246
# Byte1('B') - Identifies the Bind command.
12461247
# Int32 - Message length, including self.
@@ -1315,6 +1316,10 @@ def handle_NO_DATA(self: "Connection", msg, ps) -> None:
13151316
# how many SQL commands was completed
13161317
# not support for 'SELECT' and 'COPY' query
13171318
def handle_COMMAND_COMPLETE(self: "Connection", data: bytes, cursor: Cursor) -> None:
1319+
"""
1320+
Modifies the cursor object and prepared statement when receiving COMMAND COMPLETE b'C' message from Redshift
1321+
server.
1322+
"""
13181323
values: typing.List[bytes] = data[:-1].split(b" ")
13191324
command = values[0]
13201325
if command in self._commands_with_count:
@@ -1323,6 +1328,12 @@ def handle_COMMAND_COMPLETE(self: "Connection", data: bytes, cursor: Cursor) ->
13231328
cursor._row_count = row_count
13241329
else:
13251330
cursor._row_count += row_count
1331+
cursor._redshift_row_count = cursor._row_count
1332+
elif command == b"SELECT":
1333+
# Redshift server does not support row count for SELECT statement
1334+
# so we derive this from the size of the rows associated with the
1335+
# cursor object
1336+
cursor._redshift_row_count = len(cursor._cached_rows)
13261337

13271338
if command in (b"ALTER", b"CREATE"):
13281339
for scache in self._caches.values():

redshift_connector/cursor.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def __init__(self: "Cursor", connection: "Connection", paramstyle=None) -> None:
9595
self.arraysize: int = 1
9696
self.ps: typing.Optional[typing.Dict[str, typing.Any]] = None
9797
self._row_count: int = -1
98+
self._redshift_row_count: int = -1
9899
self._cached_rows: deque = deque()
99100
if paramstyle is None:
100101
self.paramstyle: str = redshift_connector.paramstyle
@@ -116,8 +117,28 @@ def connection(self: "Cursor") -> typing.Optional["Connection"]:
116117

117118
@property
118119
def rowcount(self: "Cursor") -> int:
120+
"""
121+
This read-only attribute specifies the number of rows that the last .execute*() produced
122+
(for DQL statements like SELECT) or affected (for DML statements like UPDATE or INSERT).
123+
124+
The attribute is -1 in case no .execute*() has been performed on the cursor or the rowcount of the last
125+
operation is cannot be determined by the interface.
126+
"""
119127
return self._row_count
120128

129+
@property
130+
def redshift_rowcount(self: "Cursor") -> int:
131+
"""
132+
Native to ``redshift_connector``, this read-only attribute specifies the number of rows that the last .execute*() produced.
133+
134+
For DQL statements (like SELECT) the number of rows is derived by ``redshift_connector`` rather than
135+
provided by the server. For DML statements (like UPDATE or INSERT) this value is provided by the server.
136+
137+
This property's behavior is subject to change inline with modifications made to query execution.
138+
Use ``rowcount`` as an alternative to this property.
139+
"""
140+
return self._redshift_row_count
141+
121142
@typing.no_type_check
122143
@functools.lru_cache()
123144
def truncated_row_desc(self: "Cursor"):
@@ -232,11 +253,14 @@ def executemany(self: "Cursor", operation, param_sets) -> "Cursor":
232253
The Cursor object used for executing the specified database operation: :class:`Cursor`
233254
"""
234255
rowcounts: typing.List[int] = []
256+
redshift_rowcounts: typing.List[int] = []
235257
for parameters in param_sets:
236258
self.execute(operation, parameters)
237259
rowcounts.append(self._row_count)
260+
redshift_rowcounts.append(self._redshift_row_count)
238261

239262
self._row_count = -1 if -1 in rowcounts else sum(rowcounts)
263+
self._redshift_row_count = -1 if -1 in redshift_rowcounts else sum(rowcounts)
240264
return self
241265

242266
def insert_data_bulk(

test/integration/test_query.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,21 @@ def test_insert_returning(db_table):
104104
# Test INSERT ... RETURNING with one row...
105105
cursor.execute("INSERT INTO t2 VALUES (%s, %s)", (row_id, "test1"))
106106

107-
assert cursor.rowcount == 1
107+
assert 1 == cursor.rowcount
108+
assert 1 == cursor.redshift_rowcount
108109

109110
cursor.execute("SELECT data FROM t2 WHERE id = %s", (row_id,))
111+
assert 1 == cursor.redshift_rowcount
112+
110113
assert "test1" == cursor.fetchone()[0]
111114

112115
# Test with multiple rows...
113116
cursor.execute("INSERT INTO t2 VALUES (2, 'test2'), (3, 'test3'), (4,'test4') ")
114-
assert cursor.rowcount == 3
117+
assert 3 == cursor.rowcount
118+
assert 3 == cursor.redshift_rowcount
119+
115120
cursor.execute("SELECT * FROM t2")
121+
assert 4 == cursor.redshift_rowcount
116122
ids: typing.Tuple[typing.List[typing.Union[int, str], ...]] = cursor.fetchall()
117123
assert len(ids) == 4
118124

@@ -134,12 +140,14 @@ def test_row_count(db_table):
134140

135141
# Check row_count without doing any reading first...
136142
assert -1 == cursor.rowcount
143+
assert expected_count == cursor.redshift_rowcount
137144

138145
# Check rowcount after reading some rows, make sure it still
139146
# works...
140147
for i in range(expected_count // 2):
141148
cursor.fetchone()
142149
assert -1 == cursor.rowcount
150+
assert expected_count == cursor.redshift_rowcount
143151

144152
with db_table.cursor() as cursor:
145153
# Restart the cursor, read a few rows, and then check rowcount
@@ -148,10 +156,12 @@ def test_row_count(db_table):
148156
for i in range(expected_count // 3):
149157
cursor.fetchone()
150158
assert -1 == cursor.rowcount
159+
assert expected_count == cursor.redshift_rowcount
151160

152161
# Should be -1 for a command with no results
153162
cursor.execute("DROP TABLE t1")
154163
assert -1 == cursor.rowcount
164+
assert -1 == cursor.redshift_rowcount
155165

156166

157167
def test_row_count_fetch(db_table):
@@ -160,6 +170,7 @@ def test_row_count_fetch(db_table):
160170
cursor.execute("SELECT * FROM t1")
161171
cursor.fetchall()
162172
assert -1 == cursor.rowcount
173+
assert 1 == cursor.redshift_rowcount
163174

164175

165176
def test_row_count_update(db_table):
@@ -170,7 +181,8 @@ def test_row_count_update(db_table):
170181
cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (4, 1000, None))
171182
cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (5, 10000, None))
172183
cursor.execute("UPDATE t1 SET f3 = %s WHERE f2 > 101", ("Hello!",))
173-
assert cursor.rowcount == 2
184+
assert 2 == cursor.rowcount
185+
assert 2 == cursor.redshift_rowcount
174186

175187

176188
def test_int_oid(cursor):
@@ -199,7 +211,9 @@ def test_transactions(db_table):
199211
with db_table.cursor() as cursor:
200212
cursor.execute("commit")
201213
cursor.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, "Zombie"))
202-
assert cursor.rowcount == 1
214+
assert 1 == cursor.rowcount
215+
assert 1 == cursor.redshift_rowcount
216+
203217
cursor.execute("rollback")
204218
cursor.execute("select * from t1")
205219

0 commit comments

Comments
 (0)