|
14 | 14 | from databases.interfaces import DatabaseBackend, Record |
15 | 15 |
|
16 | 16 | if sys.version_info >= (3, 7): # pragma: no cover |
17 | | - import contextvars as contextvars |
| 17 | + from contextvars import ContextVar |
18 | 18 | else: # pragma: no cover |
19 | | - import aiocontextvars as contextvars |
| 19 | + from aiocontextvars import ContextVar |
20 | 20 |
|
21 | 21 | try: # pragma: no cover |
22 | 22 | import click |
@@ -69,9 +69,7 @@ def __init__( |
69 | 69 | self._backend = backend_cls(self.url, **self.options) |
70 | 70 |
|
71 | 71 | # Connections are stored as task-local state. |
72 | | - self._connection_context = contextvars.ContextVar( |
73 | | - "connection_context" |
74 | | - ) # type: contextvars.ContextVar |
| 72 | + self._connection_context = ContextVar("connection_context") # type: ContextVar |
75 | 73 |
|
76 | 74 | # When `force_rollback=True` is used, we use a single global |
77 | 75 | # connection, within a transaction that always rolls back. |
@@ -120,7 +118,7 @@ async def disconnect(self) -> None: |
120 | 118 | self._global_transaction = None |
121 | 119 | self._global_connection = None |
122 | 120 | else: |
123 | | - self._connection_context = contextvars.ContextVar("connection_context") |
| 121 | + self._connection_context = ContextVar("connection_context") |
124 | 122 |
|
125 | 123 | await self._backend.disconnect() |
126 | 124 | logger.info( |
@@ -182,35 +180,21 @@ async def iterate( |
182 | 180 | async for record in connection.iterate(query, values): |
183 | 181 | yield record |
184 | 182 |
|
185 | | - def _new_connection(self) -> "Connection": |
186 | | - connection = Connection(self._backend) |
187 | | - self._connection_context.set(connection) |
188 | | - return connection |
189 | | - |
190 | 183 | def connection(self) -> "Connection": |
191 | 184 | if self._global_connection is not None: |
192 | 185 | return self._global_connection |
193 | 186 |
|
194 | 187 | try: |
195 | 188 | return self._connection_context.get() |
196 | 189 | except LookupError: |
197 | | - return self._new_connection() |
| 190 | + connection = Connection(self._backend) |
| 191 | + self._connection_context.set(connection) |
| 192 | + return connection |
198 | 193 |
|
199 | 194 | def transaction( |
200 | 195 | self, *, force_rollback: bool = False, **kwargs: typing.Any |
201 | 196 | ) -> "Transaction": |
202 | | - try: |
203 | | - connection = self._connection_context.get() |
204 | | - is_root = not connection._transaction_stack |
205 | | - if is_root: |
206 | | - newcontext = contextvars.copy_context() |
207 | | - get_conn = lambda: newcontext.run(self._new_connection) |
208 | | - else: |
209 | | - get_conn = self.connection |
210 | | - except LookupError: |
211 | | - get_conn = self.connection |
212 | | - |
213 | | - return Transaction(get_conn, force_rollback=force_rollback, **kwargs) |
| 197 | + return Transaction(self.connection, force_rollback=force_rollback, **kwargs) |
214 | 198 |
|
215 | 199 | @contextlib.contextmanager |
216 | 200 | def force_rollback(self) -> typing.Iterator[None]: |
|
0 commit comments