66from datetime import datetime
77
88import sqlalchemy
9+ import sqlalchemy .orm
910from dependency_injector .wiring import Provide
1011from quart import Quart
1112
12- from quart_sqlalchemy .session import SessionProxy
13+ from quart_sqlalchemy .session import provide_global_contextual_session
1314from quart_sqlalchemy .sim import signals
1415from quart_sqlalchemy .sim .logic import LogicComponent
1516from quart_sqlalchemy .sim .model import AuthUser
@@ -44,14 +45,15 @@ class InvalidSubstringError(AuthUserBaseError):
4445
4546class HandlerBase :
4647 logic : LogicComponent = Provide ["logic" ]
47- session_factory = SessionProxy ()
4848
4949
5050class MagicClientHandler (HandlerBase ):
5151 auth_user_handler : AuthUserHandler = Provide ["AuthUserHandler" ]
5252
53+ @provide_global_contextual_session
5354 def add (
5455 self ,
56+ session : sa .orm .Session ,
5557 app_name = None ,
5658 rate_limit_tier = None ,
5759 connect_interop = None ,
@@ -68,21 +70,24 @@ def add(
6870 """
6971
7072 return self .logic .MagicClient .add (
71- self . session_factory () ,
73+ session ,
7274 app_name = app_name ,
7375 rate_limit_tier = rate_limit_tier ,
7476 connect_interop = connect_interop ,
7577 is_signing_modal_enabled = is_signing_modal_enabled ,
7678 global_audience_enabled = global_audience_enabled ,
7779 )
7880
79- def get_by_public_api_key (self , public_api_key ):
80- return self .logic .MagicClient .get_by_public_api_key (self .session_factory (), public_api_key )
81+ @provide_global_contextual_session
82+ def get_by_public_api_key (self , session : sa .orm .Session , public_api_key ):
83+ return self .logic .MagicClient .get_by_public_api_key (session , public_api_key )
8184
82- def get_by_id (self , magic_client_id ):
83- return self .logic .MagicClient .get_by_id (self .session_factory (), magic_client_id )
85+ @provide_global_contextual_session
86+ def get_by_id (self , session : sa .orm .Session , magic_client_id ):
87+ return self .logic .MagicClient .get_by_id (session , magic_client_id )
8488
85- def update_app_name_by_id (self , magic_client_id , app_name ):
89+ @provide_global_contextual_session
90+ def update_app_name_by_id (self , session : sa .orm .Session , magic_client_id , app_name ):
8691 """
8792 Args:
8893 magic_client_id (ObjectID|int|str): self explanatory.
@@ -92,36 +97,34 @@ def update_app_name_by_id(self, magic_client_id, app_name):
9297 None if `magic_client_id` doesn't exist in the db
9398 app_name if update was successful
9499 """
95- client = self .logic .MagicClient .update_by_id (
96- self .session_factory (), magic_client_id , app_name = app_name
97- )
100+ client = self .logic .MagicClient .update_by_id (session , magic_client_id , app_name = app_name )
98101
99102 if not client :
100103 return None
101104
102105 return client .app_name
103106
104- def update_by_id (self , magic_client_id , ** kwargs ):
105- client = self .logic .MagicClient .update_by_id (
106- self .session_factory (), magic_client_id , ** kwargs
107- )
107+ @provide_global_contextual_session
108+ def update_by_id (self , session : sa .orm .Session , magic_client_id , ** kwargs ):
109+ client = self .logic .MagicClient .update_by_id (session , magic_client_id , ** kwargs )
108110
109111 return client
110112
111- def set_inactive_by_id (self , magic_client_id ):
113+ @provide_global_contextual_session
114+ def set_inactive_by_id (self , session : sa .orm .Session , magic_client_id ):
112115 """
113116 Args:
114117 magic_client_id (ObjectID|int|str): self explanatory.
115118
116119 Returns:
117120 None
118121 """
119- self .logic .MagicClient .update_by_id (
120- self .session_factory (), magic_client_id , is_active = False
121- )
122+ self .logic .MagicClient .update_by_id (session , magic_client_id , is_active = False )
122123
124+ @provide_global_contextual_session
123125 def get_users_for_client (
124126 self ,
127+ session : sa .orm .Session ,
125128 magic_client_id ,
126129 offset = None ,
127130 limit = None ,
@@ -131,6 +134,7 @@ def get_users_for_client(
131134 """
132135 product_type = get_product_type_by_client_id (magic_client_id )
133136 auth_users = self .auth_user_handler .get_by_client_id_and_user_type (
137+ session ,
134138 magic_client_id ,
135139 product_type ,
136140 offset = offset ,
@@ -146,16 +150,18 @@ def get_users_for_client(
146150
147151
148152class AuthUserHandler (HandlerBase ):
149- def get_by_session_token (self , session_token ):
150- return self .logic .AuthUser .get_by_session_token (self .session_factory (), session_token )
153+ @provide_global_contextual_session
154+ def get_by_session_token (self , session : sa .orm .Session , session_token ):
155+ return self .logic .AuthUser .get_by_session_token (session , session_token )
151156
157+ @provide_global_contextual_session
152158 def get_or_create_by_email_and_client_id (
153159 self ,
160+ session : sa .orm .Session ,
154161 email ,
155162 client_id ,
156163 user_type = EntityType .MAGIC .value ,
157164 ):
158- session = self .session_factory ()
159165 with session .begin_nested ():
160166 auth_user = self .logic .AuthUser .get_by_email_and_client_id (
161167 session ,
@@ -173,14 +179,15 @@ def get_or_create_by_email_and_client_id(
173179 )
174180 return auth_user
175181
182+ @provide_global_contextual_session
176183 def create_verified_user (
177184 self ,
185+ session : sa .orm .Session ,
178186 client_id ,
179187 email ,
180188 user_type = EntityType .FORTMATIC .value ,
181189 ** kwargs ,
182190 ):
183- session = self .session_factory ()
184191 with session .begin_nested ():
185192 auid = self .logic .AuthUser .add_by_email_and_client_id (
186193 session ,
@@ -201,52 +208,66 @@ def create_verified_user(
201208
202209 return auth_user
203210
204- def get_by_id (self , auth_user_id ) -> AuthUser :
205- return self .logic .AuthUser .get_by_id (self .session_factory (), auth_user_id )
211+ @provide_global_contextual_session
212+ def get_by_id (self , session : sa .orm .Session , auth_user_id ) -> AuthUser :
213+ return self .logic .AuthUser .get_by_id (session , auth_user_id )
206214
215+ @provide_global_contextual_session
207216 def get_by_client_id_and_user_type (
208217 self ,
218+ session : sa .orm .Session ,
209219 client_id ,
210220 user_type ,
211221 offset = None ,
212222 limit = None ,
213223 ):
214224 return self .logic .AuthUser .get_by_client_id_and_user_type (
215- self . session_factory () ,
225+ session ,
216226 client_id ,
217227 user_type ,
218228 offset = offset ,
219229 limit = limit ,
220230 )
221231
222- def exist_by_email_client_id_and_user_type (self , email , client_id , user_type ):
232+ @provide_global_contextual_session
233+ def exist_by_email_client_id_and_user_type (
234+ self , session : sa .orm .Session , email , client_id , user_type
235+ ):
223236 return self .logic .AuthUser .exist_by_email_and_client_id (
224- self . session_factory () ,
237+ session ,
225238 email ,
226239 client_id ,
227240 user_type = user_type ,
228241 )
229242
230- def update_email_by_id (self , model_id , email ):
231- return self .logic .AuthUser .update_by_id (self .session_factory (), model_id , email = email )
243+ @provide_global_contextual_session
244+ def update_email_by_id (self , session : sa .orm .Session , model_id , email ):
245+ return self .logic .AuthUser .update_by_id (session , model_id , email = email )
232246
233- def get_by_email_client_id_and_user_type (self , email , client_id , user_type ):
247+ @provide_global_contextual_session
248+ def get_by_email_client_id_and_user_type (
249+ self , session : sa .orm .Session , email , client_id , user_type
250+ ):
234251 return self .logic .AuthUser .get_by_email_and_client_id (
235- self . session_factory () ,
252+ session ,
236253 email ,
237254 client_id ,
238255 user_type ,
239256 )
240257
241- def mark_date_verified_by_id (self , model_id ):
258+ @provide_global_contextual_session
259+ def mark_date_verified_by_id (self , session : sa .orm .Session , model_id ):
242260 return self .logic .AuthUser .update_by_id (
243- self . session_factory () ,
261+ session ,
244262 model_id ,
245263 date_verified = datetime .utcnow (),
246264 )
247265
248- def set_role_by_email_magic_client_id (self , email , magic_client_id , role ):
249- session = self .session_factory ()
266+ @provide_global_contextual_session
267+ def set_role_by_email_magic_client_id (
268+ self , session : sa .orm .Session , email , magic_client_id , role
269+ ):
270+ session = session
250271 auth_user = self .logic .AuthUser .get_by_email_and_client_id (
251272 session ,
252273 email ,
@@ -267,8 +288,9 @@ def set_role_by_email_magic_client_id(self, email, magic_client_id, role):
267288
268289 return self .logic .AuthUser .update_by_id (session , auth_user .id , ** {role : True })
269290
270- def mark_as_inactive (self , auth_user_id ):
271- self .logic .AuthUser .update_by_id (self .session_factory (), auth_user_id , is_active = False )
291+ @provide_global_contextual_session
292+ def mark_as_inactive (self , session : sa .orm .Session , auth_user_id ):
293+ self .logic .AuthUser .update_by_id (session , auth_user_id , is_active = False )
272294
273295
274296@signals .auth_user_duplicate .connect
@@ -282,37 +304,42 @@ def handle_duplicate_auth_users(
282304
283305
284306class AuthWalletHandler (HandlerBase ):
285- def get_by_id (self , model_id ):
286- return self .logic .AuthWallet .get_by_id (self .session_factory (), model_id )
307+ @provide_global_contextual_session
308+ def get_by_id (self , session : sa .orm .Session , model_id ):
309+ return self .logic .AuthWallet .get_by_id (session , model_id )
287310
288- def get_by_public_address (self , public_address ):
289- return self .logic .AuthWallet ().get_by_public_address (self .session_factory (), public_address )
311+ @provide_global_contextual_session
312+ def get_by_public_address (self , session : sa .orm .Session , public_address ):
313+ return self .logic .AuthWallet ().get_by_public_address (session , public_address )
290314
315+ @provide_global_contextual_session
291316 def get_by_auth_user_id (
292317 self ,
318+ session : sa .orm .Session ,
293319 auth_user_id : ObjectID ,
294320 network : t .Optional [str ] = None ,
295321 wallet_type : t .Optional [WalletType ] = None ,
296322 ** kwargs ,
297323 ) -> t .List [AuthWallet ]:
298324 return self .logic .AuthWallet .get_by_auth_user_id (
299- self . session_factory () ,
325+ session ,
300326 auth_user_id ,
301327 network = network ,
302328 wallet_type = wallet_type ,
303329 ** kwargs ,
304330 )
305331
332+ @provide_global_contextual_session
306333 def sync_auth_wallet (
307334 self ,
335+ session : sa .orm .Session ,
308336 auth_user_id ,
309337 public_address ,
310338 encrypted_private_address ,
311339 wallet_management_type ,
312340 network : t .Optional [str ] = None ,
313341 wallet_type : t .Optional [WalletType ] = None ,
314342 ):
315- session = self .session_factory ()
316343 with session .begin_nested ():
317344 existing_wallet = self .logic .AuthWallet .get_by_auth_user_id (
318345 session ,
0 commit comments