11#!/usr/bin/env python3
22# -*- coding: utf-8 -*-
3- from typing import Any , Generic , Iterable , Literal , Sequence , Type , TypeVar
3+ from typing import Any , Generic , Iterable , Sequence , Type
44
5- from pydantic import BaseModel
6- from sqlalchemy import Row , RowMapping , and_ , asc , desc , or_ , select
5+ from sqlalchemy import Row , RowMapping , select
76from sqlalchemy import delete as sa_delete
87from sqlalchemy import update as sa_update
98from sqlalchemy .ext .asyncio import AsyncSession
109
11- from sqlalchemy_crud_plus .errors import ModelColumnError , SelectExpressionError
10+ from sqlalchemy_crud_plus .errors import MultipleResultsError
11+ from sqlalchemy_crud_plus .types import CreateSchema , Model , UpdateSchema
12+ from sqlalchemy_crud_plus .utils import apply_sorting , count , parse_filters
1213
13- _Model = TypeVar ('_Model' )
14- _CreateSchema = TypeVar ('_CreateSchema' , bound = BaseModel )
15- _UpdateSchema = TypeVar ('_UpdateSchema' , bound = BaseModel )
1614
17-
18- class CRUDPlus (Generic [_Model ]):
19- def __init__ (self , model : Type [_Model ]):
15+ class CRUDPlus (Generic [Model ]):
16+ def __init__ (self , model : Type [Model ]):
2017 self .model = model
2118
22- async def create_model (self , session : AsyncSession , obj : _CreateSchema , commit : bool = False , ** kwargs ) -> _Model :
19+ async def create_model (self , session : AsyncSession , obj : CreateSchema , commit : bool = False , ** kwargs ) -> Model :
2320 """
2421 Create a new instance of a model
2522
26- :param session:
27- :param obj:
28- :param commit:
29- :param kwargs:
23+ :param session: The SQLAlchemy async session.
24+ :param obj: The Pydantic schema containing data to be saved.
25+ :param commit: If `True`, commits the transaction immediately. Default is `False`.
26+ :param kwargs: Additional model data not included in the pydantic schema.
3027 :return:
3128 """
32- if kwargs :
33- ins = self .model (** obj .model_dump (), ** kwargs )
34- else :
29+ if not kwargs :
3530 ins = self .model (** obj .model_dump ())
31+ else :
32+ ins = self .model (** obj .model_dump (), ** kwargs )
3633 session .add (ins )
3734 if commit :
3835 await session .commit ()
3936 return ins
4037
4138 async def create_models (
42- self , session : AsyncSession , obj : Iterable [_CreateSchema ], commit : bool = False
43- ) -> list [_Model ]:
39+ self , session : AsyncSession , obj : Iterable [CreateSchema ], commit : bool = False
40+ ) -> list [Model ]:
4441 """
4542 Create new instances of a model
4643
47- :param session:
48- :param obj:
49- :param commit:
44+ :param session: The SQLAlchemy async session.
45+ :param obj: The Pydantic schema list containing data to be saved.
46+ :param commit: If `True`, commits the transaction immediately. Default is `False`.
5047 :return:
5148 """
5249 ins_list = []
53- for i in obj :
54- ins_list .append (self .model (** i .model_dump ()))
50+ for ins in obj :
51+ ins_list .append (self .model (** ins .model_dump ()))
5552 session .add_all (ins_list )
5653 if commit :
5754 await session .commit ()
5855 return ins_list
5956
60- async def select_model_by_id (self , session : AsyncSession , pk : int ) -> _Model | None :
57+ async def select_model (self , session : AsyncSession , pk : int ) -> Model | None :
6158 """
6259 Query by ID
6360
64- :param session:
65- :param pk:
61+ :param session: The SQLAlchemy async session.
62+ :param pk: The database primary key value.
6663 :return:
6764 """
6865 stmt = select (self .model ).where (self .model .id == pk )
6966 query = await session .execute (stmt )
7067 return query .scalars ().first ()
7168
72- async def select_model_by_column (self , session : AsyncSession , column : str , column_value : Any ) -> _Model | None :
69+ async def select_model_by_column (self , session : AsyncSession , ** kwargs ) -> Model | None :
7370 """
7471 Query by column
7572
76- :param session:
77- :param column:
78- :param column_value:
79- :return:
80- """
81- if hasattr (self .model , column ):
82- model_column = getattr (self .model , column )
83- stmt = select (self .model ).where (model_column == column_value ) # type: ignore
84- query = await session .execute (stmt )
85- return query .scalars ().first ()
86- else :
87- raise ModelColumnError (f'Column { column } is not found in { self .model } ' )
88-
89- async def select_model_by_columns (
90- self , session : AsyncSession , expression : Literal ['and' , 'or' ] = 'and' , ** conditions
91- ) -> _Model | None :
92- """
93- Query by columns
94-
95- :param session:
96- :param expression:
97- :param conditions: Query conditions, format:column1=value1, column2=value2
73+ :param session: The SQLAlchemy async session.
74+ :param kwargs: Query expressions.
9875 :return:
9976 """
100- where_list = []
101- for column , value in conditions .items ():
102- if hasattr (self .model , column ):
103- model_column = getattr (self .model , column )
104- where_list .append (model_column == value )
105- else :
106- raise ModelColumnError (f'Column { column } is not found in { self .model } ' )
107- match expression :
108- case 'and' :
109- stmt = select (self .model ).where (and_ (* where_list ))
110- query = await session .execute (stmt )
111- case 'or' :
112- stmt = select (self .model ).where (or_ (* where_list ))
113- query = await session .execute (stmt )
114- case _:
115- raise SelectExpressionError (
116- f'Select expression { expression } is not supported, only supports `and`, `or`'
117- )
77+ filters = await parse_filters (self .model , ** kwargs )
78+ stmt = select (self .model ).where (* filters )
79+ query = await session .execute (stmt )
11880 return query .scalars ().first ()
11981
120- async def select_models (self , session : AsyncSession ) -> Sequence [Row [Any ] | RowMapping | Any ]:
82+ async def select_models (self , session : AsyncSession , ** kwargs ) -> Sequence [Row [Any ] | RowMapping | Any ]:
12183 """
12284 Query all rows
12385
124- :param session:
86+ :param session: The SQLAlchemy async session.
87+ :param kwargs: Query expressions.
12588 :return:
12689 """
127- stmt = select (self .model )
90+ filters = await parse_filters (self .model , ** kwargs )
91+ stmt = select (self .model ).where (* filters )
12892 query = await session .execute (stmt )
12993 return query .scalars ().all ()
13094
13195 async def select_models_order (
132- self ,
133- session : AsyncSession ,
134- * columns ,
135- model_sort : Literal ['asc' , 'desc' ] = 'desc' ,
96+ self , session : AsyncSession , sort_columns : str | list [str ], sort_orders : str | list [str ] | None = None , ** kwargs
13697 ) -> Sequence [Row | RowMapping | Any ] | None :
13798 """
138- Query all rows asc or desc
99+ Query all rows and sort by columns
139100
140- :param session:
141- :param columns:
142- :param model_sort:
101+ :param session: The SQLAlchemy async session.
102+ :param sort_columns: more details see apply_sorting
103+ :param sort_orders: more details see apply_sorting
143104 :return:
144105 """
145- sort_list = []
146- for column in columns :
147- if hasattr (self .model , column ):
148- model_column = getattr (self .model , column )
149- sort_list .append (model_column )
150- else :
151- raise ModelColumnError (f'Column { column } is not found in { self .model } ' )
152- match model_sort :
153- case 'asc' :
154- query = await session .execute (select (self .model ).order_by (asc (* sort_list )))
155- case 'desc' :
156- query = await session .execute (select (self .model ).order_by (desc (* sort_list )))
157- case _:
158- raise SelectExpressionError (
159- f'Select sort expression { model_sort } is not supported, only supports `asc`, `desc`'
160- )
106+ filters = await parse_filters (self .model , ** kwargs )
107+ stmt = select (self .model ).where (* filters )
108+ stmt_sort = await apply_sorting (self .model , stmt , sort_columns , sort_orders )
109+ query = await session .execute (stmt_sort )
161110 return query .scalars ().all ()
162111
163112 async def update_model (
164- self , session : AsyncSession , pk : int , obj : _UpdateSchema | dict [str , Any ], commit : bool = False , ** kwargs
113+ self , session : AsyncSession , pk : int , obj : UpdateSchema | dict [str , Any ], commit : bool = False
165114 ) -> int :
166115 """
167- Update an instance of model's primary key
116+ Update an instance by model's primary key
168117
169- :param session:
170- :param pk:
171- :param obj:
172- :param commit:
173- :param kwargs:
118+ :param session: The SQLAlchemy async session.
119+ :param pk: The database primary key value.
120+ :param obj: A pydantic schema or dictionary containing the update data
121+ :param commit: If `True`, commits the transaction immediately. Default is `False`.
174122 :return:
175123 """
176124 if isinstance (obj , dict ):
177125 instance_data = obj
178126 else :
179127 instance_data = obj .model_dump (exclude_unset = True )
180- if kwargs :
181- instance_data .update (kwargs )
182128 stmt = sa_update (self .model ).where (self .model .id == pk ).values (** instance_data )
183129 result = await session .execute (stmt )
184130 if commit :
@@ -188,55 +134,80 @@ async def update_model(
188134 async def update_model_by_column (
189135 self ,
190136 session : AsyncSession ,
191- column : str ,
192- column_value : Any ,
193- obj : _UpdateSchema | dict [str , Any ],
137+ obj : UpdateSchema | dict [str , Any ],
138+ allow_multiple : bool = False ,
194139 commit : bool = False ,
195140 ** kwargs ,
196141 ) -> int :
197142 """
198- Update an instance of model column
143+ Update an instance by model column
199144
200- :param session:
201- :param column:
202- :param column_value:
203- :param obj:
204- :param commit:
205- :param kwargs:
145+ :param session: The SQLAlchemy async session.
146+ :param obj: A pydantic schema or dictionary containing the update data
147+ :param allow_multiple: If `True`, allows updating multiple records that match the filters.
148+ :param commit: If `True`, commits the transaction immediately. Default is `False`.
149+ :param kwargs: Query expressions.
206150 :return:
207151 """
152+ filters = await parse_filters (self .model , ** kwargs )
153+ total_count = await count (session , self .model , filters )
154+ if not allow_multiple and total_count > 1 :
155+ raise MultipleResultsError (f'Only one record is expected to be update, found { total_count } records.' )
208156 if isinstance (obj , dict ):
209157 instance_data = obj
210158 else :
211159 instance_data = obj .model_dump (exclude_unset = True )
212- if kwargs :
213- instance_data .update (kwargs )
214- if hasattr (self .model , column ):
215- model_column = getattr (self .model , column )
216- else :
217- raise ModelColumnError (f'Column { column } is not found in { self .model } ' )
218- stmt = sa_update (self .model ).where (model_column == column_value ).values (** instance_data ) # type: ignore
160+ stmt = sa_update (self .model ).where (* filters ).values (** instance_data ) # type: ignore
219161 result = await session .execute (stmt )
220162 if commit :
221163 await session .commit ()
222164 return result .rowcount # type: ignore
223165
224- async def delete_model (self , session : AsyncSession , pk : int , commit : bool = False , ** kwargs ) -> int :
166+ async def delete_model (self , session : AsyncSession , pk : int , commit : bool = False ) -> int :
225167 """
226- Delete an instance of a model
168+ Delete an instance by model's primary key
227169
228- :param session:
229- :param pk:
230- :param commit:
231- :param kwargs: for soft deletion only
170+ :param session: The SQLAlchemy async session.
171+ :param pk: The database primary key value.
172+ :param commit: If `True`, commits the transaction immediately. Default is `False`.
232173 :return:
233174 """
234- if not kwargs :
235- stmt = sa_delete (self .model ).where (self .model .id == pk )
236- result = await session .execute (stmt )
237- else :
238- stmt = sa_update (self .model ).where (self .model .id == pk ).values (** kwargs )
239- result = await session .execute (stmt )
175+ stmt = sa_delete (self .model ).where (self .model .id == pk )
176+ result = await session .execute (stmt )
240177 if commit :
241178 await session .commit ()
242179 return result .rowcount # type: ignore
180+
181+ async def delete_model_by_column (
182+ self ,
183+ session : AsyncSession ,
184+ allow_multiple : bool = False ,
185+ logical_deletion : bool = False ,
186+ deleted_flag_column : str = 'del_flag' ,
187+ commit : bool = False ,
188+ ** kwargs ,
189+ ) -> int :
190+ """
191+ Delete
192+
193+ :param session: The SQLAlchemy async session.
194+ :param commit: If `True`, commits the transaction immediately. Default is `False`.
195+ :param kwargs: Query expressions.
196+ :param allow_multiple: If `True`, allows deleting multiple records that match the filters.
197+ :param logical_deletion: If `True`, enable logical deletion instead of physical deletion
198+ :param deleted_flag_column: Specify the flag column for logical deletion
199+ :return:
200+ """
201+ filters = await parse_filters (self .model , ** kwargs )
202+ total_count = await count (session , self .model , filters )
203+ if not allow_multiple and total_count > 1 :
204+ raise MultipleResultsError (f'Only one record is expected to be delete, found { total_count } records.' )
205+ if logical_deletion :
206+ deleted_flag = {deleted_flag_column : True }
207+ stmt = sa_update (self .model ).where (* filters ).values (** deleted_flag )
208+ else :
209+ stmt = sa_delete (self .model ).where (* filters )
210+ await session .execute (stmt )
211+ if commit :
212+ await session .commit ()
213+ return total_count
0 commit comments