@@ -19,33 +19,43 @@ class CRUDPlus(Generic[_Model]):
1919 def __init__ (self , model : Type [_Model ]):
2020 self .model = model
2121
22- async def create_model (self , session : AsyncSession , obj : _CreateSchema , ** kwargs ) -> None :
22+ async def create_model (self , session : AsyncSession , obj : _CreateSchema , commit : bool = False , ** kwargs ) -> _Model :
2323 """
2424 Create a new instance of a model
2525
2626 :param session:
2727 :param obj:
28+ :param commit:
2829 :param kwargs:
2930 :return:
3031 """
3132 if kwargs :
32- instance = self .model (** obj .model_dump (), ** kwargs )
33+ ins = self .model (** obj .model_dump (), ** kwargs )
3334 else :
34- instance = self .model (** obj .model_dump ())
35- session .add (instance )
35+ ins = self .model (** obj .model_dump ())
36+ session .add (ins )
37+ if commit :
38+ await session .commit ()
39+ return ins
3640
37- async def create_models (self , session : AsyncSession , obj : Iterable [_CreateSchema ]) -> None :
41+ async def create_models (
42+ self , session : AsyncSession , obj : Iterable [_CreateSchema ], commit : bool = False
43+ ) -> list [_Model ]:
3844 """
3945 Create new instances of a model
4046
4147 :param session:
4248 :param obj:
49+ :param commit:
4350 :return:
4451 """
45- instance_list = []
52+ ins_list = []
4653 for i in obj :
47- instance_list .append (self .model (** i .model_dump ()))
48- session .add_all (instance_list )
54+ ins_list .append (self .model (** i .model_dump ()))
55+ session .add_all (ins_list )
56+ if commit :
57+ await session .commit ()
58+ return ins_list
4959
5060 async def select_model_by_id (self , session : AsyncSession , pk : int ) -> _Model | None :
5161 """
@@ -55,7 +65,8 @@ async def select_model_by_id(self, session: AsyncSession, pk: int) -> _Model | N
5565 :param pk:
5666 :return:
5767 """
58- query = await session .execute (select (self .model ).where (self .model .id == pk ))
68+ stmt = select (self .model ).where (self .model .id == pk )
69+ query = await session .execute (stmt )
5970 return query .scalars ().first ()
6071
6172 async def select_model_by_column (self , session : AsyncSession , column : str , column_value : Any ) -> _Model | None :
@@ -69,10 +80,11 @@ async def select_model_by_column(self, session: AsyncSession, column: str, colum
6980 """
7081 if hasattr (self .model , column ):
7182 model_column = getattr (self .model , column )
72- query = await session .execute (select (self .model ).where (model_column == column_value )) # type: ignore
83+ stmt = select (self .model ).where (model_column == column_value ) # type: ignore
84+ query = await session .execute (stmt )
7385 return query .scalars ().first ()
7486 else :
75- raise ModelColumnError (f'Model column { column } is not found' )
87+ raise ModelColumnError (f'Column { column } is not found in { self . model } ' )
7688
7789 async def select_model_by_columns (
7890 self , session : AsyncSession , expression : Literal ['and' , 'or' ] = 'and' , ** conditions
@@ -91,31 +103,36 @@ async def select_model_by_columns(
91103 model_column = getattr (self .model , column )
92104 where_list .append (model_column == value )
93105 else :
94- raise ModelColumnError (f'Model column { column } is not found' )
106+ raise ModelColumnError (f'Column { column } is not found in { self . model } ' )
95107 match expression :
96108 case 'and' :
97- query = await session .execute (select (self .model ).where (and_ (* where_list )))
109+ stmt = select (self .model ).where (and_ (* where_list ))
110+ query = await session .execute (stmt )
98111 case 'or' :
99- query = await session .execute (select (self .model ).where (or_ (* where_list )))
112+ stmt = select (self .model ).where (or_ (* where_list ))
113+ query = await session .execute (stmt )
100114 case _:
101- raise SelectExpressionError (f'select expression { expression } is not supported' )
115+ raise SelectExpressionError (
116+ f'Select expression { expression } is not supported, only supports `and`, `or`'
117+ )
102118 return query .scalars ().first ()
103119
104- async def select_models (self , session : AsyncSession ) -> Sequence [Row | RowMapping | Any ] | None :
120+ async def select_models (self , session : AsyncSession ) -> Sequence [Row [ Any ] | RowMapping | Any ]:
105121 """
106122 Query all rows
107123
108124 :param session:
109125 :return:
110126 """
111- query = await session .execute (select (self .model ))
127+ stmt = select (self .model )
128+ query = await session .execute (stmt )
112129 return query .scalars ().all ()
113130
114131 async def select_models_order (
115132 self ,
116133 session : AsyncSession ,
117134 * columns ,
118- model_sort : Literal ['default' , ' asc' , 'desc' ] = 'default ' ,
135+ model_sort : Literal ['asc' , 'desc' ] = 'desc ' ,
119136 ) -> Sequence [Row | RowMapping | Any ] | None :
120137 """
121138 Query all rows asc or desc
@@ -131,25 +148,28 @@ async def select_models_order(
131148 model_column = getattr (self .model , column )
132149 sort_list .append (model_column )
133150 else :
134- raise ModelColumnError (f'Model column { column } is not found' )
151+ raise ModelColumnError (f'Column { column } is not found in { self . model } ' )
135152 match model_sort :
136- case 'default' :
137- query = await session .execute (select (self .model ).order_by (* sort_list ))
138153 case 'asc' :
139154 query = await session .execute (select (self .model ).order_by (asc (* sort_list )))
140155 case 'desc' :
141156 query = await session .execute (select (self .model ).order_by (desc (* sort_list )))
142157 case _:
143- raise SelectExpressionError (f'select sort expression { model_sort } is not supported' )
158+ raise SelectExpressionError (
159+ f'Select sort expression { model_sort } is not supported, only supports `asc`, `desc`'
160+ )
144161 return query .scalars ().all ()
145162
146- async def update_model (self , session : AsyncSession , pk : int , obj : _UpdateSchema | dict [str , Any ], ** kwargs ) -> int :
163+ async def update_model (
164+ self , session : AsyncSession , pk : int , obj : _UpdateSchema | dict [str , Any ], commit : bool = False , ** kwargs
165+ ) -> int :
147166 """
148167 Update an instance of model's primary key
149168
150169 :param session:
151170 :param pk:
152171 :param obj:
172+ :param commit:
153173 :param kwargs:
154174 :return:
155175 """
@@ -159,11 +179,20 @@ async def update_model(self, session: AsyncSession, pk: int, obj: _UpdateSchema
159179 instance_data = obj .model_dump (exclude_unset = True )
160180 if kwargs :
161181 instance_data .update (kwargs )
162- result = await session .execute (sa_update (self .model ).where (self .model .id == pk ).values (** instance_data ))
182+ stmt = sa_update (self .model ).where (self .model .id == pk ).values (** instance_data )
183+ result = await session .execute (stmt )
184+ if commit :
185+ await session .commit ()
163186 return result .rowcount # type: ignore
164187
165188 async def update_model_by_column (
166- self , session : AsyncSession , column : str , column_value : Any , obj : _UpdateSchema | dict [str , Any ], ** kwargs
189+ self ,
190+ session : AsyncSession ,
191+ column : str ,
192+ column_value : Any ,
193+ obj : _UpdateSchema | dict [str , Any ],
194+ commit : bool = False ,
195+ ** kwargs ,
167196 ) -> int :
168197 """
169198 Update an instance of model column
@@ -172,6 +201,7 @@ async def update_model_by_column(
172201 :param column:
173202 :param column_value:
174203 :param obj:
204+ :param commit:
175205 :param kwargs:
176206 :return:
177207 """
@@ -184,23 +214,29 @@ async def update_model_by_column(
184214 if hasattr (self .model , column ):
185215 model_column = getattr (self .model , column )
186216 else :
187- raise ModelColumnError (f'Model column { column } is not found' )
188- result = await session .execute (
189- sa_update (self .model ).where (model_column == column_value ).values (** instance_data )
190- )
217+ raise ModelColumnError (f'Column { column } is not found in { self .model } ' )
218+ stmt = sa_update (self .model ).where (model_column == column_value ).values (** instance_data ) # type: ignore
219+ result = await session .execute (stmt )
220+ if commit :
221+ await session .commit ()
191222 return result .rowcount # type: ignore
192223
193- async def delete_model (self , session : AsyncSession , pk : int , ** kwargs ) -> int :
224+ async def delete_model (self , session : AsyncSession , pk : int , commit : bool = False , ** kwargs ) -> int :
194225 """
195226 Delete an instance of a model
196227
197228 :param session:
198229 :param pk:
230+ :param commit:
199231 :param kwargs: for soft deletion only
200232 :return:
201233 """
202234 if not kwargs :
203- result = await session .execute (sa_delete (self .model ).where (self .model .id == pk ))
235+ stmt = sa_delete (self .model ).where (self .model .id == pk )
236+ result = await session .execute (stmt )
204237 else :
205- result = await session .execute (sa_update (self .model ).where (self .model .id == pk ).values (** kwargs ))
238+ stmt = sa_update (self .model ).where (self .model .id == pk ).values (** kwargs )
239+ result = await session .execute (stmt )
240+ if commit :
241+ await session .commit ()
206242 return result .rowcount # type: ignore
0 commit comments