Skip to content

Commit 671edeb

Browse files
delete global session
1 parent bcc5a9b commit 671edeb

File tree

5 files changed

+98
-78
lines changed

5 files changed

+98
-78
lines changed

src/quart_sqlalchemy/sim/handle.py

Lines changed: 71 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
from datetime import datetime
77

88
import sqlalchemy
9+
import sqlalchemy.orm
910
from dependency_injector.wiring import Provide
1011
from quart import Quart
1112

12-
from quart_sqlalchemy.session import SessionProxy
13+
from quart_sqlalchemy.session import provide_global_contextual_session
1314
from quart_sqlalchemy.sim import signals
1415
from quart_sqlalchemy.sim.logic import LogicComponent
1516
from quart_sqlalchemy.sim.model import AuthUser
@@ -44,14 +45,15 @@ class InvalidSubstringError(AuthUserBaseError):
4445

4546
class HandlerBase:
4647
logic: LogicComponent = Provide["logic"]
47-
session_factory = SessionProxy()
4848

4949

5050
class MagicClientHandler(HandlerBase):
5151
auth_user_handler: AuthUserHandler = Provide["AuthUserHandler"]
5252

53+
@provide_global_contextual_session
5354
def add(
5455
self,
56+
session: sa.orm.Session,
5557
app_name=None,
5658
rate_limit_tier=None,
5759
connect_interop=None,
@@ -68,21 +70,24 @@ def add(
6870
"""
6971

7072
return self.logic.MagicClient.add(
71-
self.session_factory(),
73+
session,
7274
app_name=app_name,
7375
rate_limit_tier=rate_limit_tier,
7476
connect_interop=connect_interop,
7577
is_signing_modal_enabled=is_signing_modal_enabled,
7678
global_audience_enabled=global_audience_enabled,
7779
)
7880

79-
def get_by_public_api_key(self, public_api_key):
80-
return self.logic.MagicClient.get_by_public_api_key(self.session_factory(), public_api_key)
81+
@provide_global_contextual_session
82+
def get_by_public_api_key(self, session: sa.orm.Session, public_api_key):
83+
return self.logic.MagicClient.get_by_public_api_key(session, public_api_key)
8184

82-
def get_by_id(self, magic_client_id):
83-
return self.logic.MagicClient.get_by_id(self.session_factory(), magic_client_id)
85+
@provide_global_contextual_session
86+
def get_by_id(self, session: sa.orm.Session, magic_client_id):
87+
return self.logic.MagicClient.get_by_id(session, magic_client_id)
8488

85-
def update_app_name_by_id(self, magic_client_id, app_name):
89+
@provide_global_contextual_session
90+
def update_app_name_by_id(self, session: sa.orm.Session, magic_client_id, app_name):
8691
"""
8792
Args:
8893
magic_client_id (ObjectID|int|str): self explanatory.
@@ -92,36 +97,34 @@ def update_app_name_by_id(self, magic_client_id, app_name):
9297
None if `magic_client_id` doesn't exist in the db
9398
app_name if update was successful
9499
"""
95-
client = self.logic.MagicClient.update_by_id(
96-
self.session_factory(), magic_client_id, app_name=app_name
97-
)
100+
client = self.logic.MagicClient.update_by_id(session, magic_client_id, app_name=app_name)
98101

99102
if not client:
100103
return None
101104

102105
return client.app_name
103106

104-
def update_by_id(self, magic_client_id, **kwargs):
105-
client = self.logic.MagicClient.update_by_id(
106-
self.session_factory(), magic_client_id, **kwargs
107-
)
107+
@provide_global_contextual_session
108+
def update_by_id(self, session: sa.orm.Session, magic_client_id, **kwargs):
109+
client = self.logic.MagicClient.update_by_id(session, magic_client_id, **kwargs)
108110

109111
return client
110112

111-
def set_inactive_by_id(self, magic_client_id):
113+
@provide_global_contextual_session
114+
def set_inactive_by_id(self, session: sa.orm.Session, magic_client_id):
112115
"""
113116
Args:
114117
magic_client_id (ObjectID|int|str): self explanatory.
115118
116119
Returns:
117120
None
118121
"""
119-
self.logic.MagicClient.update_by_id(
120-
self.session_factory(), magic_client_id, is_active=False
121-
)
122+
self.logic.MagicClient.update_by_id(session, magic_client_id, is_active=False)
122123

124+
@provide_global_contextual_session
123125
def get_users_for_client(
124126
self,
127+
session: sa.orm.Session,
125128
magic_client_id,
126129
offset=None,
127130
limit=None,
@@ -131,6 +134,7 @@ def get_users_for_client(
131134
"""
132135
product_type = get_product_type_by_client_id(magic_client_id)
133136
auth_users = self.auth_user_handler.get_by_client_id_and_user_type(
137+
session,
134138
magic_client_id,
135139
product_type,
136140
offset=offset,
@@ -146,16 +150,18 @@ def get_users_for_client(
146150

147151

148152
class AuthUserHandler(HandlerBase):
149-
def get_by_session_token(self, session_token):
150-
return self.logic.AuthUser.get_by_session_token(self.session_factory(), session_token)
153+
@provide_global_contextual_session
154+
def get_by_session_token(self, session: sa.orm.Session, session_token):
155+
return self.logic.AuthUser.get_by_session_token(session, session_token)
151156

157+
@provide_global_contextual_session
152158
def get_or_create_by_email_and_client_id(
153159
self,
160+
session: sa.orm.Session,
154161
email,
155162
client_id,
156163
user_type=EntityType.MAGIC.value,
157164
):
158-
session = self.session_factory()
159165
with session.begin_nested():
160166
auth_user = self.logic.AuthUser.get_by_email_and_client_id(
161167
session,
@@ -173,14 +179,15 @@ def get_or_create_by_email_and_client_id(
173179
)
174180
return auth_user
175181

182+
@provide_global_contextual_session
176183
def create_verified_user(
177184
self,
185+
session: sa.orm.Session,
178186
client_id,
179187
email,
180188
user_type=EntityType.FORTMATIC.value,
181189
**kwargs,
182190
):
183-
session = self.session_factory()
184191
with session.begin_nested():
185192
auid = self.logic.AuthUser.add_by_email_and_client_id(
186193
session,
@@ -201,52 +208,66 @@ def create_verified_user(
201208

202209
return auth_user
203210

204-
def get_by_id(self, auth_user_id) -> AuthUser:
205-
return self.logic.AuthUser.get_by_id(self.session_factory(), auth_user_id)
211+
@provide_global_contextual_session
212+
def get_by_id(self, session: sa.orm.Session, auth_user_id) -> AuthUser:
213+
return self.logic.AuthUser.get_by_id(session, auth_user_id)
206214

215+
@provide_global_contextual_session
207216
def get_by_client_id_and_user_type(
208217
self,
218+
session: sa.orm.Session,
209219
client_id,
210220
user_type,
211221
offset=None,
212222
limit=None,
213223
):
214224
return self.logic.AuthUser.get_by_client_id_and_user_type(
215-
self.session_factory(),
225+
session,
216226
client_id,
217227
user_type,
218228
offset=offset,
219229
limit=limit,
220230
)
221231

222-
def exist_by_email_client_id_and_user_type(self, email, client_id, user_type):
232+
@provide_global_contextual_session
233+
def exist_by_email_client_id_and_user_type(
234+
self, session: sa.orm.Session, email, client_id, user_type
235+
):
223236
return self.logic.AuthUser.exist_by_email_and_client_id(
224-
self.session_factory(),
237+
session,
225238
email,
226239
client_id,
227240
user_type=user_type,
228241
)
229242

230-
def update_email_by_id(self, model_id, email):
231-
return self.logic.AuthUser.update_by_id(self.session_factory(), model_id, email=email)
243+
@provide_global_contextual_session
244+
def update_email_by_id(self, session: sa.orm.Session, model_id, email):
245+
return self.logic.AuthUser.update_by_id(session, model_id, email=email)
232246

233-
def get_by_email_client_id_and_user_type(self, email, client_id, user_type):
247+
@provide_global_contextual_session
248+
def get_by_email_client_id_and_user_type(
249+
self, session: sa.orm.Session, email, client_id, user_type
250+
):
234251
return self.logic.AuthUser.get_by_email_and_client_id(
235-
self.session_factory(),
252+
session,
236253
email,
237254
client_id,
238255
user_type,
239256
)
240257

241-
def mark_date_verified_by_id(self, model_id):
258+
@provide_global_contextual_session
259+
def mark_date_verified_by_id(self, session: sa.orm.Session, model_id):
242260
return self.logic.AuthUser.update_by_id(
243-
self.session_factory(),
261+
session,
244262
model_id,
245263
date_verified=datetime.utcnow(),
246264
)
247265

248-
def set_role_by_email_magic_client_id(self, email, magic_client_id, role):
249-
session = self.session_factory()
266+
@provide_global_contextual_session
267+
def set_role_by_email_magic_client_id(
268+
self, session: sa.orm.Session, email, magic_client_id, role
269+
):
270+
session = session
250271
auth_user = self.logic.AuthUser.get_by_email_and_client_id(
251272
session,
252273
email,
@@ -267,8 +288,9 @@ def set_role_by_email_magic_client_id(self, email, magic_client_id, role):
267288

268289
return self.logic.AuthUser.update_by_id(session, auth_user.id, **{role: True})
269290

270-
def mark_as_inactive(self, auth_user_id):
271-
self.logic.AuthUser.update_by_id(self.session_factory(), auth_user_id, is_active=False)
291+
@provide_global_contextual_session
292+
def mark_as_inactive(self, session: sa.orm.Session, auth_user_id):
293+
self.logic.AuthUser.update_by_id(session, auth_user_id, is_active=False)
272294

273295

274296
@signals.auth_user_duplicate.connect
@@ -282,37 +304,42 @@ def handle_duplicate_auth_users(
282304

283305

284306
class AuthWalletHandler(HandlerBase):
285-
def get_by_id(self, model_id):
286-
return self.logic.AuthWallet.get_by_id(self.session_factory(), model_id)
307+
@provide_global_contextual_session
308+
def get_by_id(self, session: sa.orm.Session, model_id):
309+
return self.logic.AuthWallet.get_by_id(session, model_id)
287310

288-
def get_by_public_address(self, public_address):
289-
return self.logic.AuthWallet().get_by_public_address(self.session_factory(), public_address)
311+
@provide_global_contextual_session
312+
def get_by_public_address(self, session: sa.orm.Session, public_address):
313+
return self.logic.AuthWallet().get_by_public_address(session, public_address)
290314

315+
@provide_global_contextual_session
291316
def get_by_auth_user_id(
292317
self,
318+
session: sa.orm.Session,
293319
auth_user_id: ObjectID,
294320
network: t.Optional[str] = None,
295321
wallet_type: t.Optional[WalletType] = None,
296322
**kwargs,
297323
) -> t.List[AuthWallet]:
298324
return self.logic.AuthWallet.get_by_auth_user_id(
299-
self.session_factory(),
325+
session,
300326
auth_user_id,
301327
network=network,
302328
wallet_type=wallet_type,
303329
**kwargs,
304330
)
305331

332+
@provide_global_contextual_session
306333
def sync_auth_wallet(
307334
self,
335+
session: sa.orm.Session,
308336
auth_user_id,
309337
public_address,
310338
encrypted_private_address,
311339
wallet_management_type,
312340
network: t.Optional[str] = None,
313341
wallet_type: t.Optional[WalletType] = None,
314342
):
315-
session = self.session_factory()
316343
with session.begin_nested():
317344
existing_wallet = self.logic.AuthWallet.get_by_auth_user_id(
318345
session,

src/quart_sqlalchemy/sim/logic.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import inspect
21
import logging
32
import secrets
43
import typing as t
@@ -12,11 +11,8 @@
1211
from quart_sqlalchemy.sim import signals
1312
from quart_sqlalchemy.sim.model import AuthUser as auth_user_model
1413
from quart_sqlalchemy.sim.model import AuthWallet as auth_wallet_model
15-
from quart_sqlalchemy.sim.model import ConnectInteropStatus
1614
from quart_sqlalchemy.sim.model import EntityType
1715
from quart_sqlalchemy.sim.model import MagicClient as magic_client_model
18-
from quart_sqlalchemy.sim.model import Provenance
19-
from quart_sqlalchemy.sim.model import WalletType
2016
from quart_sqlalchemy.sim.repo_adapter import RepositoryLegacyAdapter
2117
from quart_sqlalchemy.sim.util import ObjectID
2218
from quart_sqlalchemy.sim.util import one

src/quart_sqlalchemy/sim/views/auth_user.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from dependency_injector.wiring import Provide
77

88
from quart_sqlalchemy.framework import QuartSQLAlchemy
9-
from quart_sqlalchemy.session import set_global_contextual_session
109

1110
from ..auth import authorized_request
1211
from ..auth import RequestCredentials
@@ -51,8 +50,7 @@ def get_auth_user(
5150
credentials: RequestCredentials = Provide["request_credentials"],
5251
) -> ResponseWrapper[AuthUserSchema]:
5352
with db.bind.Session() as session:
54-
with set_global_contextual_session(session):
55-
auth_user = auth_user_handler.get_by_session_token(credentials.current_user.value)
53+
auth_user = auth_user_handler.get_by_session_token(session, credentials.current_user.value)
5654

5755
return ResponseWrapper[AuthUserSchema](data=AuthUserSchema.from_orm(auth_user))
5856

@@ -76,12 +74,12 @@ def create_auth_user(
7674
) -> ResponseWrapper[CreateAuthUserResponse]:
7775
with db.bind.Session() as session:
7876
with session.begin():
79-
with set_global_contextual_session(session):
80-
client = auth_user_handler.create_verified_user(
81-
email=data.email,
82-
client_id=credentials.current_client.subject.id,
83-
user_type=EntityType.MAGIC.value,
84-
)
77+
client = auth_user_handler.create_verified_user(
78+
session,
79+
email=data.email,
80+
client_id=credentials.current_client.subject.id,
81+
user_type=EntityType.MAGIC.value,
82+
)
8583

8684
return ResponseWrapper[CreateAuthUserResponse](
8785
data=dict(auth_user=AuthUserSchema.from_orm(client)) # type: ignore

src/quart_sqlalchemy/sim/views/auth_wallet.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from quart import g
77

88
from quart_sqlalchemy.framework import QuartSQLAlchemy
9-
from quart_sqlalchemy.session import set_global_contextual_session
109

1110
from ..auth import authorized_request
1211
from ..auth import RequestCredentials
@@ -70,15 +69,15 @@ def sync(
7069
) -> ResponseWrapper[WalletSyncResponse]:
7170
with db.bind.Session() as session:
7271
with session.begin():
73-
with set_global_contextual_session(session):
74-
wallet = auth_wallet_handler.sync_auth_wallet(
75-
credentials.current_user.subject.id,
76-
data.public_address,
77-
data.encrypted_private_address,
78-
WalletManagementType.DELEGATED.value,
79-
network=web3.network,
80-
wallet_type=data.wallet_type,
81-
)
72+
wallet = auth_wallet_handler.sync_auth_wallet(
73+
session,
74+
credentials.current_user.subject.id,
75+
data.public_address,
76+
data.encrypted_private_address,
77+
WalletManagementType.DELEGATED.value,
78+
network=web3.network,
79+
wallet_type=data.wallet_type,
80+
)
8281

8382
return ResponseWrapper[WalletSyncResponse](
8483
data=dict(

0 commit comments

Comments
 (0)