@@ -564,7 +564,7 @@ def select(query):
564564 except GeneratorExit :
565565 pass
566566
567- cursor .release ()
567+ yield from cursor .release
568568 return result
569569
570570
@@ -584,7 +584,7 @@ def insert(query):
584584 result = yield from query .database .last_insert_id_async (
585585 cursor , query .model_class )
586586
587- cursor .release ()
587+ yield from cursor .release
588588 return result
589589
590590
@@ -599,7 +599,7 @@ def update(query):
599599 cursor = yield from _execute_query_async (query )
600600 rowcount = cursor .rowcount
601601
602- cursor .release ()
602+ yield from cursor .release
603603 return rowcount
604604
605605
@@ -614,7 +614,7 @@ def delete(query):
614614 cursor = yield from _execute_query_async (query )
615615 rowcount = cursor .rowcount
616616
617- cursor .release ()
617+ yield from cursor .release
618618 return rowcount
619619
620620
@@ -650,7 +650,7 @@ def scalar(query, as_tuple=False):
650650 cursor = yield from _execute_query_async (query )
651651 row = yield from cursor .fetchone ()
652652
653- cursor .release ()
653+ yield from cursor .release
654654 if row and not as_tuple :
655655 return row [0 ]
656656 else :
@@ -672,7 +672,7 @@ def raw_query(query):
672672 except GeneratorExit :
673673 pass
674674
675- cursor .release ()
675+ yield from cursor .release
676676 return result
677677
678678
@@ -983,36 +983,35 @@ def connect(self):
983983 ** self .connect_kwargs )
984984
985985 @asyncio .coroutine
986- def cursor (self , conn = None , * args , ** kwargs ):
987- """Get cursor for connection from pool.
986+ def close (self ):
987+ """Terminate all pool connections .
988988 """
989- if conn is None :
990- # Acquire connection with cursor, once cursor is released
991- # connection is also released to pool:
989+ self .pool .terminate ()
990+ yield from self .pool .wait_closed ()
992991
992+ @asyncio .coroutine
993+ def cursor (self , conn = None , * args , ** kwargs ):
994+ """Get a cursor for the specified transaction connection
995+ or acquire from the pool.
996+ """
997+ in_transaction = conn is not None
998+ if not conn :
993999 conn = yield from self .acquire ()
994- cursor = yield from conn .cursor (* args , ** kwargs )
995-
996- def release ():
997- cursor .close ()
998- self .pool .release (conn )
999- cursor .release = release
1000- else :
1001- # Acquire cursor from provided connection, after cursor is
1002- # released connection is NOT released to pool, i.e.
1003- # for handling transactions:
1004-
1005- cursor = yield from conn .cursor (* args , ** kwargs )
1006- cursor .release = lambda : cursor .close ()
1007-
1000+ cursor = yield from conn .cursor (* args , ** kwargs )
1001+ # NOTE: `cursor.release` is an awaitable object!
1002+ cursor .release = self .release_cursor (
1003+ cursor , in_transaction = in_transaction )
10081004 return cursor
10091005
10101006 @asyncio .coroutine
1011- def close (self ):
1012- """Terminate all pool connections.
1007+ def release_cursor (self , cursor , in_transaction = False ):
1008+ """Release cursor coroutine. Unless in transaction,
1009+ the connection is also released back to the pool.
10131010 """
1014- self .pool .terminate ()
1015- yield from self .pool .wait_closed ()
1011+ conn = cursor .connection
1012+ cursor .close ()
1013+ if not in_transaction :
1014+ self .pool .release (conn )
10161015
10171016
10181017class AsyncPostgresqlMixin (AsyncDatabase ):
@@ -1143,37 +1142,35 @@ def connect(self):
11431142 connect_timeout = self .timeout ,
11441143 ** self .connect_kwargs )
11451144
1145+ @asyncio .coroutine
1146+ def close (self ):
1147+ """Terminate all pool connections.
1148+ """
1149+ self .pool .terminate ()
1150+ yield from self .pool .wait_closed ()
1151+
11461152 @asyncio .coroutine
11471153 def cursor (self , conn = None , * args , ** kwargs ):
11481154 """Get cursor for connection from pool.
11491155 """
1150- if conn is None :
1151- # Acquire connection with cursor, once cursor is released
1152- # connection is also released to pool:
1153-
1156+ in_transaction = conn is not None
1157+ if not conn :
11541158 conn = yield from self .acquire ()
1155- cursor = yield from conn .cursor (* args , ** kwargs )
1156-
1157- def release ():
1158- cursor .close ()
1159- self .pool .release (conn )
1160- cursor .release = release
1161- else :
1162- # Acquire cursor from provided connection, after cursor is
1163- # released connection is NOT released to pool, i.e.
1164- # for handling transactions:
1165-
1166- cursor = yield from conn .cursor (* args , ** kwargs )
1167- cursor .release = lambda : cursor .close ()
1168-
1159+ cursor = yield from conn .cursor (* args , ** kwargs )
1160+ # NOTE: `cursor.release` is an awaitable object!
1161+ cursor .release = self .release_cursor (
1162+ cursor , in_transaction = in_transaction )
11691163 return cursor
11701164
11711165 @asyncio .coroutine
1172- def close (self ):
1173- """Terminate all pool connections.
1166+ def release_cursor (self , cursor , in_transaction = False ):
1167+ """Release cursor coroutine. Unless in transaction,
1168+ the connection is also released back to the pool.
11741169 """
1175- self .pool .terminate ()
1176- yield from self .pool .wait_closed ()
1170+ conn = cursor .connection
1171+ yield from cursor .close ()
1172+ if not in_transaction :
1173+ self .pool .release (conn )
11771174
11781175
11791176class MySQLDatabase (AsyncDatabase , peewee .MySQLDatabase ):
@@ -1395,7 +1392,7 @@ def _run_sql(database, operation, *args, **kwargs):
13951392 try :
13961393 yield from cursor .execute (operation , * args , ** kwargs )
13971394 except :
1398- cursor .release ()
1395+ yield from cursor .release
13991396 raise
14001397
14011398 return cursor
0 commit comments