11use chrono:: { self , DateTime , FixedOffset , NaiveDate , NaiveDateTime , NaiveTime } ;
22use macaddr:: { MacAddr6 , MacAddr8 } ;
33use postgres_types:: { Field , FromSql , Kind , ToSql } ;
4+ use rust_decimal:: Decimal ;
45use serde_json:: { json, Map , Value } ;
56use std:: { fmt:: Debug , net:: IpAddr } ;
67use uuid:: Uuid ;
78
89use bytes:: { BufMut , BytesMut } ;
910use postgres_protocol:: types;
1011use pyo3:: {
12+ sync:: GILOnceCell ,
1113 types:: {
1214 PyAnyMethods , PyBool , PyBytes , PyDate , PyDateTime , PyDict , PyDictMethods , PyFloat , PyInt ,
13- PyList , PyListMethods , PyString , PyTime , PyTuple , PyTypeMethods ,
15+ PyList , PyListMethods , PyString , PyTime , PyTuple , PyType , PyTypeMethods ,
1416 } ,
15- Bound , Py , PyAny , Python , ToPyObject ,
17+ Bound , Py , PyAny , PyObject , PyResult , Python , ToPyObject ,
1618} ;
1719use tokio_postgres:: {
1820 types:: { to_sql_checked, Type } ,
@@ -23,13 +25,43 @@ use crate::{
2325 additional_types:: { RustMacAddr6 , RustMacAddr8 } ,
2426 exceptions:: rust_errors:: { RustPSQLDriverError , RustPSQLDriverPyResult } ,
2527 extra_types:: {
26- BigInt , Float32 , Float64 , Integer , PyCustomType , PyJSON , PyJSONB , PyMacAddr6 , PyMacAddr8 ,
27- PyText , PyVarChar , SmallInt ,
28+ BigInt , Float32 , Float64 , Integer , Money , PyCustomType , PyJSON , PyJSONB , PyMacAddr6 ,
29+ PyMacAddr8 , PyText , PyVarChar , SmallInt ,
2830 } ,
2931} ;
3032
33+ static DECIMAL_CLS : GILOnceCell < Py < PyType > > = GILOnceCell :: new ( ) ;
34+
3135pub type QueryParameter = ( dyn ToSql + Sync ) ;
3236
37+ fn get_decimal_cls ( py : Python < ' _ > ) -> PyResult < & Bound < ' _ , PyType > > {
38+ DECIMAL_CLS
39+ . get_or_try_init ( py, || {
40+ let type_object = py
41+ . import_bound ( "decimal" ) ?
42+ . getattr ( "Decimal" ) ?
43+ . downcast_into ( ) ?;
44+ Ok ( type_object. unbind ( ) )
45+ } )
46+ . map ( |ty| ty. bind ( py) )
47+ }
48+
49+ /// Struct for Decimal.
50+ ///
51+ /// It's necessary because we use custom forks and there is
52+ /// no implementation of `ToPyObject` for Decimal.
53+ struct InnerDecimal ( Decimal ) ;
54+
55+ impl ToPyObject for InnerDecimal {
56+ fn to_object ( & self , py : Python < ' _ > ) -> PyObject {
57+ let dec_cls = get_decimal_cls ( py) . expect ( "failed to load decimal.Decimal" ) ;
58+ let ret = dec_cls
59+ . call1 ( ( self . 0 . to_string ( ) , ) )
60+ . expect ( "failed to call decimal.Decimal(value)" ) ;
61+ ret. to_object ( py)
62+ }
63+ }
64+
3365/// Additional type for types come from Python.
3466///
3567/// It's necessary because we need to pass this
@@ -51,6 +83,7 @@ pub enum PythonDTO {
5183 PyIntU64 ( u64 ) ,
5284 PyFloat32 ( f32 ) ,
5385 PyFloat64 ( f64 ) ,
86+ PyMoney ( i64 ) ,
5487 PyDate ( NaiveDate ) ,
5588 PyTime ( NaiveTime ) ,
5689 PyDateTime ( NaiveDateTime ) ,
@@ -62,6 +95,7 @@ pub enum PythonDTO {
6295 PyJson ( Value ) ,
6396 PyMacAddr6 ( MacAddr6 ) ,
6497 PyMacAddr8 ( MacAddr8 ) ,
98+ PyDecimal ( Decimal ) ,
6599 PyCustomType ( Vec < u8 > ) ,
66100}
67101
@@ -89,6 +123,7 @@ impl PythonDTO {
89123 PythonDTO :: PyIntI64 ( _) => Ok ( tokio_postgres:: types:: Type :: INT8_ARRAY ) ,
90124 PythonDTO :: PyFloat32 ( _) => Ok ( tokio_postgres:: types:: Type :: FLOAT4_ARRAY ) ,
91125 PythonDTO :: PyFloat64 ( _) => Ok ( tokio_postgres:: types:: Type :: FLOAT8_ARRAY ) ,
126+ PythonDTO :: PyMoney ( _) => Ok ( tokio_postgres:: types:: Type :: MONEY_ARRAY ) ,
92127 PythonDTO :: PyIpAddress ( _) => Ok ( tokio_postgres:: types:: Type :: INET_ARRAY ) ,
93128 PythonDTO :: PyJsonb ( _) => Ok ( tokio_postgres:: types:: Type :: JSONB_ARRAY ) ,
94129 PythonDTO :: PyJson ( _) => Ok ( tokio_postgres:: types:: Type :: JSON_ARRAY ) ,
@@ -98,6 +133,7 @@ impl PythonDTO {
98133 PythonDTO :: PyDateTimeTz ( _) => Ok ( tokio_postgres:: types:: Type :: TIMESTAMPTZ_ARRAY ) ,
99134 PythonDTO :: PyMacAddr6 ( _) => Ok ( tokio_postgres:: types:: Type :: MACADDR_ARRAY ) ,
100135 PythonDTO :: PyMacAddr8 ( _) => Ok ( tokio_postgres:: types:: Type :: MACADDR8_ARRAY ) ,
136+ PythonDTO :: PyDecimal ( _) => Ok ( tokio_postgres:: types:: Type :: NUMERIC_ARRAY ) ,
101137 _ => Err ( RustPSQLDriverError :: PyToRustValueConversionError (
102138 "Can't process array type, your type doesn't have support yet" . into ( ) ,
103139 ) ) ,
@@ -197,7 +233,7 @@ impl ToSql for PythonDTO {
197233 }
198234 PythonDTO :: PyIntI16 ( int) => out. put_i16 ( * int) ,
199235 PythonDTO :: PyIntI32 ( int) => out. put_i32 ( * int) ,
200- PythonDTO :: PyIntI64 ( int) => out. put_i64 ( * int) ,
236+ PythonDTO :: PyIntI64 ( int) | PythonDTO :: PyMoney ( int ) => out. put_i64 ( * int) ,
201237 PythonDTO :: PyIntU32 ( int) => out. put_u32 ( * int) ,
202238 PythonDTO :: PyIntU64 ( int) => out. put_u64 ( * int) ,
203239 PythonDTO :: PyFloat32 ( float) => out. put_f32 ( * float) ,
@@ -237,6 +273,9 @@ impl ToSql for PythonDTO {
237273 PythonDTO :: PyJsonb ( py_dict) | PythonDTO :: PyJson ( py_dict) => {
238274 <& Value as ToSql >:: to_sql ( & py_dict, ty, out) ?;
239275 }
276+ PythonDTO :: PyDecimal ( py_decimal) => {
277+ <Decimal as ToSql >:: to_sql ( py_decimal, ty, out) ?;
278+ }
240279 }
241280 if return_is_null_true {
242281 Ok ( tokio_postgres:: types:: IsNull :: Yes )
@@ -286,6 +325,7 @@ pub fn convert_parameters(parameters: Py<PyAny>) -> RustPSQLDriverPyResult<Vec<P
286325/// or value of the type is incorrect.
287326#[ allow( clippy:: too_many_lines) ]
288327pub fn py_to_rust ( parameter : & pyo3:: Bound < ' _ , PyAny > ) -> RustPSQLDriverPyResult < PythonDTO > {
328+ println ! ( "{}" , parameter. get_type( ) . name( ) ?) ;
289329 if parameter. is_none ( ) {
290330 return Ok ( PythonDTO :: PyNone ) ;
291331 }
@@ -352,6 +392,12 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
352392 ) ) ;
353393 }
354394
395+ if parameter. is_instance_of :: < Money > ( ) {
396+ return Ok ( PythonDTO :: PyMoney (
397+ parameter. extract :: < Money > ( ) ?. retrieve_value ( ) ,
398+ ) ) ;
399+ }
400+
355401 if parameter. is_instance_of :: < PyInt > ( ) {
356402 return Ok ( PythonDTO :: PyIntI32 ( parameter. extract :: < i32 > ( ) ?) ) ;
357403 }
@@ -443,6 +489,13 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
443489 ) ?) ) ;
444490 }
445491
492+ if parameter. get_type ( ) . name ( ) ? == "decimal.Decimal" {
493+ println ! ( "{}" , parameter. str ( ) ?. extract:: <& str >( ) ?) ;
494+ return Ok ( PythonDTO :: PyDecimal ( Decimal :: from_str_exact (
495+ parameter. str ( ) ?. extract :: < & str > ( ) ?,
496+ ) ?) ) ;
497+ }
498+
446499 if let Ok ( id_address) = parameter. extract :: < IpAddr > ( ) {
447500 return Ok ( PythonDTO :: PyIpAddress ( id_address) ) ;
448501 }
@@ -496,7 +549,7 @@ fn postgres_bytes_to_py(
496549 . to_object ( py) ) ,
497550 // // ---------- String Types ----------
498551 // // Convert TEXT and VARCHAR type into String, then into str
499- Type :: TEXT | Type :: VARCHAR => Ok ( _composite_field_postgres_to_py :: < Option < String > > (
552+ Type :: TEXT | Type :: VARCHAR | Type :: XML => Ok ( _composite_field_postgres_to_py :: < Option < String > > (
500553 type_, buf, is_simple,
501554 ) ?
502555 . to_object ( py) ) ,
@@ -515,7 +568,7 @@ fn postgres_bytes_to_py(
515568 _composite_field_postgres_to_py :: < Option < i32 > > ( type_, buf, is_simple) ?. to_object ( py) ,
516569 ) ,
517570 // Convert BigInt into i64, then into int
518- Type :: INT8 => Ok (
571+ Type :: INT8 | Type :: MONEY => Ok (
519572 _composite_field_postgres_to_py :: < Option < i64 > > ( type_, buf, is_simple) ?. to_object ( py) ,
520573 ) ,
521574 // Convert REAL into f32, then into float
@@ -592,13 +645,21 @@ fn postgres_bytes_to_py(
592645 Ok ( py. None ( ) . to_object ( py) )
593646 }
594647 }
648+ Type :: NUMERIC => {
649+ if let Some ( numeric_) = _composite_field_postgres_to_py :: < Option < Decimal > > (
650+ type_, buf, is_simple,
651+ ) ? {
652+ return Ok ( InnerDecimal ( numeric_) . to_object ( py) ) ;
653+ }
654+ Ok ( py. None ( ) . to_object ( py) )
655+ }
595656 // ---------- Array Text Types ----------
596657 Type :: BOOL_ARRAY => Ok ( _composite_field_postgres_to_py :: < Option < Vec < bool > > > (
597658 type_, buf, is_simple,
598659 ) ?
599660 . to_object ( py) ) ,
600661 // Convert ARRAY of TEXT or VARCHAR into Vec<String>, then into list[str]
601- Type :: TEXT_ARRAY | Type :: VARCHAR_ARRAY => Ok ( _composite_field_postgres_to_py :: <
662+ Type :: TEXT_ARRAY | Type :: VARCHAR_ARRAY | Type :: XML_ARRAY => Ok ( _composite_field_postgres_to_py :: <
602663 Option < Vec < String > > ,
603664 > ( type_, buf, is_simple) ?
604665 . to_object ( py) ) ,
@@ -614,7 +675,7 @@ fn postgres_bytes_to_py(
614675 ) ?
615676 . to_object ( py) ) ,
616677 // Convert ARRAY of BigInt into Vec<i64>, then into list[int]
617- Type :: INT8_ARRAY => Ok ( _composite_field_postgres_to_py :: < Option < Vec < i64 > > > (
678+ Type :: INT8_ARRAY | Type :: MONEY_ARRAY => Ok ( _composite_field_postgres_to_py :: < Option < Vec < i64 > > > (
618679 type_, buf, is_simple,
619680 ) ?
620681 . to_object ( py) ) ,
@@ -686,6 +747,18 @@ fn postgres_bytes_to_py(
686747 None => Ok ( py. None ( ) . to_object ( py) ) ,
687748 }
688749 }
750+ Type :: NUMERIC_ARRAY => {
751+ if let Some ( numeric_array) = _composite_field_postgres_to_py :: < Option < Vec < Decimal > > > (
752+ type_, buf, is_simple,
753+ ) ? {
754+ let py_list = PyList :: empty_bound ( py) ;
755+ for numeric_ in numeric_array {
756+ py_list. append ( InnerDecimal ( numeric_) . to_object ( py) ) ?;
757+ }
758+ return Ok ( py_list. to_object ( py) )
759+ } ;
760+ Ok ( py. None ( ) . to_object ( py) )
761+ } ,
689762 _ => Err ( RustPSQLDriverError :: RustToPyValueConversionError (
690763 format ! ( "Cannot convert {type_} into Python type, please look at the custom_decoders functionality." )
691764 ) ) ,
0 commit comments