Skip to content

Commit 8eacb12

Browse files
committed
fix(datatype): Fix data ingestion to properly handle NUMERIC data types when using the binary transfer protocol to prevent incorrect results. issue #207
1 parent 1608eaf commit 8eacb12

File tree

3 files changed

+161
-11
lines changed

3 files changed

+161
-11
lines changed

redshift_connector/utils/type_utils.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,8 @@ def intervald2s_send_integer(v: IntervalDayToSecond) -> bytes:
245245
def numeric_in_binary(data: bytes, offset: int, length: int, scale: int) -> Decimal:
246246
raw_value: int
247247

248-
if length == 8:
249-
raw_value = q_unpack(data, offset)[0]
250-
elif length == 16:
251-
temp: typing.Tuple[int, int] = qq_unpack(data, offset)
252-
raw_value = (temp[0] << 64) | temp[1]
248+
if length == 8 or length == 16:
249+
raw_value = int.from_bytes(data[offset : offset + length], byteorder="big", signed=True)
253250
else:
254251
raise Exception("Malformed column value of type numeric received")
255252

@@ -259,11 +256,8 @@ def numeric_in_binary(data: bytes, offset: int, length: int, scale: int) -> Deci
259256
def numeric_to_float_binary(data: bytes, offset: int, length: int, scale: int) -> float:
260257
raw_value: int
261258

262-
if length == 8:
263-
raw_value = q_unpack(data, offset)[0]
264-
elif length == 16:
265-
temp: typing.Tuple[int, int] = qq_unpack(data, offset)
266-
raw_value = (temp[0] << 64) | temp[1]
259+
if length == 8 or length == 16:
260+
raw_value = int.from_bytes(data[offset : offset + length], byteorder="big", signed=True)
267261
else:
268262
raise Exception("Malformed column value of type numeric received")
269263

test/integration/datatype/test_datatypes.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,13 @@ def test_abstime(db_kwargs, _input, client_protocol):
242242

243243
numeric_vals: typing.List[typing.Tuple[str, float]] = [
244244
("to_number('12,454.8-', 'S99G999D9')", -12454.8),
245+
("to_number('12,454.8', 'S99G999D9')", 12454.8),
245246
("to_number('8.1-', '9D9S')", -8.1),
247+
("to_number('8.1', '9D9S')", 8.1),
246248
("to_number('897.6', '999D9S')", 897.6),
249+
("to_number('897.6-', '999D9S')", -897.6),
250+
("to_number('$ 2,012,454.88', 'L 9,999,999.99');", 2012454.88),
251+
("to_number('$ -2,012,454.88', 'L 9,999,999.99');", -2012454.88),
247252
]
248253

249254

@@ -263,3 +268,47 @@ def test_numeric_to_float(db_kwargs, _input, client_protocol):
263268
rel_tol=1e-05,
264269
abs_tol=1e-08,
265270
)
271+
272+
273+
numeric_precision_scale_vals: typing.List[typing.Tuple[str, float]] = [
274+
("-135430.11999999999500::numeric(36,14)", -135430.11999999999500),
275+
("135430.11999999999500::numeric(36,14)", 135430.11999999999500),
276+
("-35430.11999999999500::numeric(36,14)", -35430.11999999999500),
277+
("35430.11999999999500::numeric(36,14)", 35430.11999999999500),
278+
("-7872432525245.4577::numeric(36,14)", -7872432525245.4577),
279+
("7872432525245.4577::numeric(36,14)", 7872432525245.4577),
280+
("-252252::numeric(36,14)", -252252),
281+
("252252::numeric(36,14)", 252252),
282+
("-135430.11999999999500::numeric(30,19)", -135430.11999999999500),
283+
("135430.11999999999500::numeric(36,19)", 135430.11999999999500),
284+
("-35430.11999999999500::numeric(36,19)", -35430.11999999999500),
285+
("35430.11999999999500::numeric(36,19)", 35430.11999999999500),
286+
("-252252::numeric(36,11)", -252252),
287+
("252252::numeric(36,11)", 252252),
288+
("-135430.11999999999500::numeric(36,11)", -135430.11999999999500),
289+
("135430.11999999999500::numeric(36,11)", 135430.11999999999500),
290+
("-35430.11999999999500::numeric(36,11)", -35430.11999999999500),
291+
("35430.11999999999500::numeric(36,11)", 35430.11999999999500),
292+
("-7872432525245.4577::numeric(36,11)", -7872432525245.4577),
293+
("7872432525245.4577::numeric(36,11)", 7872432525245.4577),
294+
("-252252::numeric(36,11)", -252252),
295+
("252252::numeric(36,11)", 252252),
296+
]
297+
298+
299+
@pytest.mark.parametrize("client_protocol", ClientProtocolVersion.list())
300+
@pytest.mark.parametrize("_input", numeric_precision_scale_vals)
301+
def test_numeric_precision(db_kwargs, client_protocol, _input):
302+
insert_val, exp_val = _input
303+
db_kwargs["numeric_to_float"] = True
304+
with redshift_connector.connect(**db_kwargs) as conn:
305+
with conn.cursor() as cursor:
306+
cursor.execute("select {}".format(insert_val))
307+
res = cursor.fetchone()
308+
assert isinstance(res[0], float)
309+
assert isclose(
310+
typing.cast(float, res[0]),
311+
typing.cast(float, exp_val),
312+
rel_tol=1e-05,
313+
abs_tol=1e-08,
314+
)

test/unit/datatype/test_data_in.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class Datatypes(Enum):
2525
float8: typing.Callable = type_utils.float8_recv
2626
timestamp: typing.Callable = type_utils.timestamp_recv_integer
2727
numeric_binary: typing.Callable = type_utils.numeric_in_binary
28+
numeric_to_float_binary: typing.Callable = type_utils.numeric_to_float_binary
2829
numeric: typing.Callable = type_utils.numeric_in
2930
timetz_binary: typing.Callable = type_utils.timetz_recv_binary
3031
time_binary: typing.Callable = type_utils.time_recv_binary
@@ -117,6 +118,73 @@ class Datatypes(Enum):
117118
datetime(year=2008, month=6, day=1, hour=9, minute=59, second=59),
118119
)
119120
],
121+
Datatypes.numeric_to_float_binary: [
122+
# 8
123+
(
124+
b"\x00\x01\x00\x00\x00\x08\x01\xb6\x9bK\xac\xd0_\x15",
125+
6,
126+
8,
127+
0,
128+
Decimal(123456789123456789),
129+
),
130+
(
131+
b"\x00\x01\x00\x00\x00\x08\x00\x00\x00\x00I\x96\x02\xd3",
132+
6,
133+
8,
134+
5,
135+
Decimal(12345.67891),
136+
),
137+
(
138+
b"\x00\x01\x00\x00\x00\x08\xff\xff\xff\xff\xb6i\xfd-",
139+
6,
140+
8,
141+
5,
142+
Decimal(-12345.67891),
143+
),
144+
(
145+
b"\x00\x01\x00\x00\x00\x08\x00\x00\x00\x00\x00\x0009",
146+
6,
147+
8,
148+
8,
149+
Decimal(0.00012345),
150+
),
151+
(
152+
b"\x00\x01\x00\x00\x00\x08\x00\x00\x01\x1fq\xfb\x088",
153+
6,
154+
8,
155+
8,
156+
Decimal(12345.67891),
157+
),
158+
(
159+
b"\x00\x01\x00\x00\x00\x08\x00\x18\x1e\xab\xfb\xfaj\x9e",
160+
6,
161+
8,
162+
3,
163+
Decimal(6789123456789.15),
164+
),
165+
(
166+
b'\x00\x01\x00\x00\x00\x08\x11"\x10\xf4\xc0#\xb6\xd4',
167+
6,
168+
8,
169+
1,
170+
Decimal(123456789123456789.2),
171+
),
172+
# 16
173+
(
174+
b"\x00\x01\x00\x00\x00\x10\tI\xb0\xf7\x13\xe9\x18_~\x8f\x1a\x99\xa9\x9b\xb6\xdb",
175+
6,
176+
16,
177+
0,
178+
Decimal(12345678912345678991234567891234567899),
179+
),
180+
(
181+
b"\x00\x01\x00\x00\x00\x10\x00\x02`\xb0`\x05\x18\xdb<\xd5\x01\x15\xd9\x8ek\x8d",
182+
6,
183+
16,
184+
26,
185+
Decimal(123456789.12345679104328155517578125),
186+
),
187+
],
120188
Datatypes.numeric_binary: [
121189
# 8
122190
(
@@ -154,6 +222,20 @@ class Datatypes(Enum):
154222
8,
155223
Decimal(12345.67891),
156224
),
225+
(
226+
b"\x00\x03\x00\x00\x00\x08\x00\x18\x1e\xab\xfb\xfaj\x9e\x00\x00\x00\x01a\x00\x00\x00\x08\x01\xb6\x9bK\xac\xd0_\x15",
227+
6,
228+
8,
229+
3,
230+
Decimal(6789123456789.15),
231+
),
232+
(
233+
b"\x00\x03\x00\x00\x00\x08\x00\x18\x1e\xab\xfb\xfaj\x9e\x00\x00\x00\x01a\x00\x00\x00\x08\x01\xb6\x9bK\xac\xd0_\x15",
234+
23,
235+
8,
236+
0,
237+
Decimal(123456789123456789.2),
238+
),
157239
# 16
158240
(
159241
b"\x00\x05\x00\x00\x00\x10\tI\xb0\xf7\x13\xe9\x18_~\x8f\x1a\x99\xa9\x9b\xb6\xdb\x00\x00\x00\x10\tI\xb0\xf7\x13\xe9\x18_~\x8f\x1a\x99\xa9\x9b\xb6\xdb\x00\x00\x00\x10\tI\xb0\xf7\x13\xe9\x18_~\x8f\x1a\x99\xa9\x9b\xb6\xdb\x00\x00\x00\x10\tI\xb0\xf7\x13\xe9\x18_~\x8f\x1a\x99\xa9\x9b\xb6\xdb\x00\x00\x00\x10\tI\xb0\xf7\x13\xe9\x18_~\x8f\x1a\x99\xa9\x9b\xb6\xdb",
@@ -376,7 +458,7 @@ def get_test_cases() -> typing.Generator:
376458
@pytest.mark.parametrize("_input", get_test_cases(), ids=[k.__name__ for k, v in get_test_cases()])
377459
def test_datatype_recv(_input):
378460
test_func, test_args = _input
379-
if len(test_args) == 5: # numeric_in_binary
461+
if len(test_args) == 5: # numeric_in_binary or numeric_to_float_binary
380462
_data, _offset, _length, scale, exp_result = test_args
381463
assert isclose(test_func(_data, _offset, _length, scale), exp_result, rel_tol=1e-6)
382464
else:
@@ -385,3 +467,28 @@ def test_datatype_recv(_input):
385467
assert isclose(test_func(_data, _offset, _length), exp_result, rel_tol=1e-6)
386468
else:
387469
assert test_func(_data, _offset, _length) == exp_result
470+
471+
472+
invalid_numeric_lengths: typing.List[int] = [-1, 0, 7, 9, 15, 17, 99]
473+
474+
475+
@pytest.mark.parametrize("length", invalid_numeric_lengths)
476+
def test_numeric_in_binary_raises_for_invalid_length(length):
477+
with pytest.raises(Exception, match="Malformed column value of type numeric received"):
478+
Datatypes.numeric_binary(
479+
b"\x00\x02\x00\x00\x00\x0c-12345.67891\x00\x00\x00\x08\xff\xff\xfe\xe0\x8e\x04\xf7\xc8",
480+
22,
481+
length, # invalid length
482+
8,
483+
)
484+
485+
486+
@pytest.mark.parametrize("length", invalid_numeric_lengths)
487+
def test_numeric_to_float_binary_raises_for_invalid_length(length):
488+
with pytest.raises(Exception, match="Malformed column value of type numeric received"):
489+
Datatypes.numeric_to_float_binary(
490+
b"\x00\x02\x00\x00\x00\x0c-12345.67891\x00\x00\x00\x08\xff\xff\xfe\xe0\x8e\x04\xf7\xc8",
491+
22,
492+
length, # invalid length
493+
8,
494+
)

0 commit comments

Comments
 (0)