Skip to content

Commit 00cc141

Browse files
authored
Merge pull request #53 from qaspen-python/feature/supporting_more_types
Supported more types
2 parents 7dde7fd + 9867520 commit 00cc141

File tree

8 files changed

+458
-22
lines changed

8 files changed

+458
-22
lines changed

Cargo.lock

Lines changed: 329 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@ crate-type = ["cdylib"]
1010

1111
[dependencies]
1212
deadpool-postgres = { git = "https://github.com/chandr-andr/deadpool.git", branch = "master" }
13-
pyo3 = { version = "*", features = ["chrono", "experimental-async"] }
13+
pyo3 = { version = "*", features = [
14+
"chrono",
15+
"experimental-async",
16+
"rust_decimal",
17+
] }
1418
pyo3-asyncio = { git = "https://github.com/chandr-andr/pyo3-asyncio.git", version = "0.20.0", features = [
1519
"tokio-runtime",
1620
] }
@@ -34,3 +38,7 @@ postgres-types = { git = "https://github.com/chandr-andr/rust-postgres.git", bra
3438
"derive",
3539
] }
3640
postgres-protocol = { git = "https://github.com/chandr-andr/rust-postgres.git", branch = "master" }
41+
rust_decimal = { git = "https://github.com/chandr-andr/rust-decimal.git", branch = "psqlpy", features = [
42+
"db-postgres",
43+
"db-tokio-postgres",
44+
] }

python/psqlpy/_internal/extra_types.pyi

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,16 @@ class BigInt:
3232
- `inner_value`: int object.
3333
"""
3434

35+
class Money:
36+
"""Represent `MONEY` in PostgreSQL and `i64` in Rust."""
37+
38+
def __init__(self: Self, inner_value: int) -> None:
39+
"""Create new instance of class.
40+
41+
### Parameters:
42+
- `inner_value`: int object.
43+
"""
44+
3545
class Float32:
3646
"""Represents `FLOAT4` in `PostgreSQL` and `f32` in Rust."""
3747

python/psqlpy/extra_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
Float32,
44
Float64,
55
Integer,
6+
Money,
67
PyCustomType,
78
PyJSON,
89
PyJSONB,
@@ -26,4 +27,5 @@
2627
"PyCustomType",
2728
"Float32",
2829
"Float64",
30+
"Money",
2931
]

python/tests/test_value_converter.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
import uuid
3+
from decimal import Decimal
34
from enum import Enum
45
from ipaddress import IPv4Address
56
from typing import Any, Dict, List, Union
@@ -15,6 +16,7 @@
1516
Float32,
1617
Float64,
1718
Integer,
19+
Money,
1820
PyJSON,
1921
PyJSONB,
2022
PyMacAddr6,
@@ -72,10 +74,18 @@ async def test_as_class(
7274
("BYTEA", b"Bytes", [66, 121, 116, 101, 115]),
7375
("VARCHAR", "Some String", "Some String"),
7476
("TEXT", "Some String", "Some String"),
77+
(
78+
"XML",
79+
"""<?xml version="1.0"?><book><title>Manual</title><chapter>...</chapter></book>""",
80+
"""<book><title>Manual</title><chapter>...</chapter></book>""",
81+
),
7582
("BOOL", True, True),
7683
("INT2", SmallInt(12), 12),
7784
("INT4", Integer(121231231), 121231231),
7885
("INT8", BigInt(99999999999999999), 99999999999999999),
86+
("MONEY", BigInt(99999999999999999), 99999999999999999),
87+
("MONEY", Money(99999999999999999), 99999999999999999),
88+
("NUMERIC(5, 2)", Decimal("120.12"), Decimal("120.12")),
7989
("FLOAT4", 32.12329864501953, 32.12329864501953),
8090
("FLOAT4", Float32(32.12329864501953), 32.12329864501953),
8191
("FLOAT8", Float64(32.12329864501953), 32.12329864501953),
@@ -145,6 +155,16 @@ async def test_as_class(
145155
[BigInt(99999999999999999), BigInt(99999999999999999)],
146156
[99999999999999999, 99999999999999999],
147157
),
158+
(
159+
"MONEY ARRAY",
160+
[Money(99999999999999999), Money(99999999999999999)],
161+
[99999999999999999, 99999999999999999],
162+
),
163+
(
164+
"NUMERIC(5, 2) ARRAY",
165+
[Decimal("121.23"), Decimal("188.99")],
166+
[Decimal("121.23"), Decimal("188.99")],
167+
),
148168
(
149169
"FLOAT4 ARRAY",
150170
[32.12329864501953, 32.12329864501953],

src/exceptions/rust_errors.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ pub enum RustPSQLDriverError {
7575
RustMacAddrConversionError(#[from] macaddr::ParseError),
7676
#[error("Cannot execute future in Rust: {0}")]
7777
RustRuntimeJoinError(#[from] JoinError),
78+
#[error("Cannot convert python Decimal into rust Decimal")]
79+
DecimalConversionError(#[from] rust_decimal::Error),
7880
}
7981

8082
impl From<RustPSQLDriverError> for pyo3::PyErr {
@@ -92,7 +94,8 @@ impl From<RustPSQLDriverError> for pyo3::PyErr {
9294
RustPSQLDriverError::RustToPyValueConversionError(_) => {
9395
RustToPyValueMappingError::new_err((error_desc,))
9496
}
95-
RustPSQLDriverError::PyToRustValueConversionError(_) => {
97+
RustPSQLDriverError::PyToRustValueConversionError(_)
98+
| RustPSQLDriverError::DecimalConversionError(_) => {
9699
PyToRustValueMappingError::new_err((error_desc,))
97100
}
98101
RustPSQLDriverError::ConnectionPoolConfigurationError(_) => {

src/extra_types.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ macro_rules! build_python_type {
4444
build_python_type!(SmallInt, i16);
4545
build_python_type!(Integer, i32);
4646
build_python_type!(BigInt, i64);
47+
build_python_type!(Money, i64);
4748
build_python_type!(Float32, f32);
4849
build_python_type!(Float64, f64);
4950

@@ -189,6 +190,7 @@ pub fn extra_types_module(_py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyRes
189190
pymod.add_class::<SmallInt>()?;
190191
pymod.add_class::<Integer>()?;
191192
pymod.add_class::<BigInt>()?;
193+
pymod.add_class::<Money>()?;
192194
pymod.add_class::<Float32>()?;
193195
pymod.add_class::<Float64>()?;
194196
pymod.add_class::<PyText>()?;

src/value_converter.rs

Lines changed: 82 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
22
use macaddr::{MacAddr6, MacAddr8};
33
use postgres_types::{Field, FromSql, Kind, ToSql};
4+
use rust_decimal::Decimal;
45
use serde_json::{json, Map, Value};
56
use std::{fmt::Debug, net::IpAddr};
67
use uuid::Uuid;
78

89
use bytes::{BufMut, BytesMut};
910
use postgres_protocol::types;
1011
use pyo3::{
12+
sync::GILOnceCell,
1113
types::{
1214
PyAnyMethods, PyBool, PyBytes, PyDate, PyDateTime, PyDict, PyDictMethods, PyFloat, PyInt,
13-
PyList, PyListMethods, PyString, PyTime, PyTuple, PyTypeMethods,
15+
PyList, PyListMethods, PyString, PyTime, PyTuple, PyType, PyTypeMethods,
1416
},
15-
Bound, Py, PyAny, Python, ToPyObject,
17+
Bound, Py, PyAny, PyObject, PyResult, Python, ToPyObject,
1618
};
1719
use tokio_postgres::{
1820
types::{to_sql_checked, Type},
@@ -23,13 +25,43 @@ use crate::{
2325
additional_types::{RustMacAddr6, RustMacAddr8},
2426
exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult},
2527
extra_types::{
26-
BigInt, Float32, Float64, Integer, PyCustomType, PyJSON, PyJSONB, PyMacAddr6, PyMacAddr8,
27-
PyText, PyVarChar, SmallInt,
28+
BigInt, Float32, Float64, Integer, Money, PyCustomType, PyJSON, PyJSONB, PyMacAddr6,
29+
PyMacAddr8, PyText, PyVarChar, SmallInt,
2830
},
2931
};
3032

33+
static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
34+
3135
pub type QueryParameter = (dyn ToSql + Sync);
3236

37+
fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
38+
DECIMAL_CLS
39+
.get_or_try_init(py, || {
40+
let type_object = py
41+
.import_bound("decimal")?
42+
.getattr("Decimal")?
43+
.downcast_into()?;
44+
Ok(type_object.unbind())
45+
})
46+
.map(|ty| ty.bind(py))
47+
}
48+
49+
/// Struct for Decimal.
50+
///
51+
/// It's necessary because we use custom forks and there is
52+
/// no implementation of `ToPyObject` for Decimal.
53+
struct InnerDecimal(Decimal);
54+
55+
impl ToPyObject for InnerDecimal {
56+
fn to_object(&self, py: Python<'_>) -> PyObject {
57+
let dec_cls = get_decimal_cls(py).expect("failed to load decimal.Decimal");
58+
let ret = dec_cls
59+
.call1((self.0.to_string(),))
60+
.expect("failed to call decimal.Decimal(value)");
61+
ret.to_object(py)
62+
}
63+
}
64+
3365
/// Additional type for types come from Python.
3466
///
3567
/// It's necessary because we need to pass this
@@ -51,6 +83,7 @@ pub enum PythonDTO {
5183
PyIntU64(u64),
5284
PyFloat32(f32),
5385
PyFloat64(f64),
86+
PyMoney(i64),
5487
PyDate(NaiveDate),
5588
PyTime(NaiveTime),
5689
PyDateTime(NaiveDateTime),
@@ -62,6 +95,7 @@ pub enum PythonDTO {
6295
PyJson(Value),
6396
PyMacAddr6(MacAddr6),
6497
PyMacAddr8(MacAddr8),
98+
PyDecimal(Decimal),
6599
PyCustomType(Vec<u8>),
66100
}
67101

@@ -89,6 +123,7 @@ impl PythonDTO {
89123
PythonDTO::PyIntI64(_) => Ok(tokio_postgres::types::Type::INT8_ARRAY),
90124
PythonDTO::PyFloat32(_) => Ok(tokio_postgres::types::Type::FLOAT4_ARRAY),
91125
PythonDTO::PyFloat64(_) => Ok(tokio_postgres::types::Type::FLOAT8_ARRAY),
126+
PythonDTO::PyMoney(_) => Ok(tokio_postgres::types::Type::MONEY_ARRAY),
92127
PythonDTO::PyIpAddress(_) => Ok(tokio_postgres::types::Type::INET_ARRAY),
93128
PythonDTO::PyJsonb(_) => Ok(tokio_postgres::types::Type::JSONB_ARRAY),
94129
PythonDTO::PyJson(_) => Ok(tokio_postgres::types::Type::JSON_ARRAY),
@@ -98,6 +133,7 @@ impl PythonDTO {
98133
PythonDTO::PyDateTimeTz(_) => Ok(tokio_postgres::types::Type::TIMESTAMPTZ_ARRAY),
99134
PythonDTO::PyMacAddr6(_) => Ok(tokio_postgres::types::Type::MACADDR_ARRAY),
100135
PythonDTO::PyMacAddr8(_) => Ok(tokio_postgres::types::Type::MACADDR8_ARRAY),
136+
PythonDTO::PyDecimal(_) => Ok(tokio_postgres::types::Type::NUMERIC_ARRAY),
101137
_ => Err(RustPSQLDriverError::PyToRustValueConversionError(
102138
"Can't process array type, your type doesn't have support yet".into(),
103139
)),
@@ -197,7 +233,7 @@ impl ToSql for PythonDTO {
197233
}
198234
PythonDTO::PyIntI16(int) => out.put_i16(*int),
199235
PythonDTO::PyIntI32(int) => out.put_i32(*int),
200-
PythonDTO::PyIntI64(int) => out.put_i64(*int),
236+
PythonDTO::PyIntI64(int) | PythonDTO::PyMoney(int) => out.put_i64(*int),
201237
PythonDTO::PyIntU32(int) => out.put_u32(*int),
202238
PythonDTO::PyIntU64(int) => out.put_u64(*int),
203239
PythonDTO::PyFloat32(float) => out.put_f32(*float),
@@ -237,6 +273,9 @@ impl ToSql for PythonDTO {
237273
PythonDTO::PyJsonb(py_dict) | PythonDTO::PyJson(py_dict) => {
238274
<&Value as ToSql>::to_sql(&py_dict, ty, out)?;
239275
}
276+
PythonDTO::PyDecimal(py_decimal) => {
277+
<Decimal as ToSql>::to_sql(py_decimal, ty, out)?;
278+
}
240279
}
241280
if return_is_null_true {
242281
Ok(tokio_postgres::types::IsNull::Yes)
@@ -286,6 +325,7 @@ pub fn convert_parameters(parameters: Py<PyAny>) -> RustPSQLDriverPyResult<Vec<P
286325
/// or value of the type is incorrect.
287326
#[allow(clippy::too_many_lines)]
288327
pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<PythonDTO> {
328+
println!("{}", parameter.get_type().name()?);
289329
if parameter.is_none() {
290330
return Ok(PythonDTO::PyNone);
291331
}
@@ -352,6 +392,12 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
352392
));
353393
}
354394

395+
if parameter.is_instance_of::<Money>() {
396+
return Ok(PythonDTO::PyMoney(
397+
parameter.extract::<Money>()?.retrieve_value(),
398+
));
399+
}
400+
355401
if parameter.is_instance_of::<PyInt>() {
356402
return Ok(PythonDTO::PyIntI32(parameter.extract::<i32>()?));
357403
}
@@ -443,6 +489,13 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
443489
)?));
444490
}
445491

492+
if parameter.get_type().name()? == "decimal.Decimal" {
493+
println!("{}", parameter.str()?.extract::<&str>()?);
494+
return Ok(PythonDTO::PyDecimal(Decimal::from_str_exact(
495+
parameter.str()?.extract::<&str>()?,
496+
)?));
497+
}
498+
446499
if let Ok(id_address) = parameter.extract::<IpAddr>() {
447500
return Ok(PythonDTO::PyIpAddress(id_address));
448501
}
@@ -496,7 +549,7 @@ fn postgres_bytes_to_py(
496549
.to_object(py)),
497550
// // ---------- String Types ----------
498551
// // Convert TEXT and VARCHAR type into String, then into str
499-
Type::TEXT | Type::VARCHAR => Ok(_composite_field_postgres_to_py::<Option<String>>(
552+
Type::TEXT | Type::VARCHAR | Type::XML => Ok(_composite_field_postgres_to_py::<Option<String>>(
500553
type_, buf, is_simple,
501554
)?
502555
.to_object(py)),
@@ -515,7 +568,7 @@ fn postgres_bytes_to_py(
515568
_composite_field_postgres_to_py::<Option<i32>>(type_, buf, is_simple)?.to_object(py),
516569
),
517570
// Convert BigInt into i64, then into int
518-
Type::INT8 => Ok(
571+
Type::INT8 | Type::MONEY => Ok(
519572
_composite_field_postgres_to_py::<Option<i64>>(type_, buf, is_simple)?.to_object(py),
520573
),
521574
// Convert REAL into f32, then into float
@@ -592,13 +645,21 @@ fn postgres_bytes_to_py(
592645
Ok(py.None().to_object(py))
593646
}
594647
}
648+
Type::NUMERIC => {
649+
if let Some(numeric_) = _composite_field_postgres_to_py::<Option<Decimal>>(
650+
type_, buf, is_simple,
651+
)? {
652+
return Ok(InnerDecimal(numeric_).to_object(py));
653+
}
654+
Ok(py.None().to_object(py))
655+
}
595656
// ---------- Array Text Types ----------
596657
Type::BOOL_ARRAY => Ok(_composite_field_postgres_to_py::<Option<Vec<bool>>>(
597658
type_, buf, is_simple,
598659
)?
599660
.to_object(py)),
600661
// Convert ARRAY of TEXT or VARCHAR into Vec<String>, then into list[str]
601-
Type::TEXT_ARRAY | Type::VARCHAR_ARRAY => Ok(_composite_field_postgres_to_py::<
662+
Type::TEXT_ARRAY | Type::VARCHAR_ARRAY | Type::XML_ARRAY => Ok(_composite_field_postgres_to_py::<
602663
Option<Vec<String>>,
603664
>(type_, buf, is_simple)?
604665
.to_object(py)),
@@ -614,7 +675,7 @@ fn postgres_bytes_to_py(
614675
)?
615676
.to_object(py)),
616677
// Convert ARRAY of BigInt into Vec<i64>, then into list[int]
617-
Type::INT8_ARRAY => Ok(_composite_field_postgres_to_py::<Option<Vec<i64>>>(
678+
Type::INT8_ARRAY | Type::MONEY_ARRAY => Ok(_composite_field_postgres_to_py::<Option<Vec<i64>>>(
618679
type_, buf, is_simple,
619680
)?
620681
.to_object(py)),
@@ -686,6 +747,18 @@ fn postgres_bytes_to_py(
686747
None => Ok(py.None().to_object(py)),
687748
}
688749
}
750+
Type::NUMERIC_ARRAY => {
751+
if let Some(numeric_array) = _composite_field_postgres_to_py::<Option<Vec<Decimal>>>(
752+
type_, buf, is_simple,
753+
)? {
754+
let py_list = PyList::empty_bound(py);
755+
for numeric_ in numeric_array {
756+
py_list.append(InnerDecimal(numeric_).to_object(py))?;
757+
}
758+
return Ok(py_list.to_object(py))
759+
};
760+
Ok(py.None().to_object(py))
761+
},
689762
_ => Err(RustPSQLDriverError::RustToPyValueConversionError(
690763
format!("Cannot convert {type_} into Python type, please look at the custom_decoders functionality.")
691764
)),

0 commit comments

Comments
 (0)