44import functools
55import os
66import re
7- from unittest .mock import patch , MagicMock
7+ from unittest .mock import MagicMock , patch
88
99import pytest
1010import sqlalchemy
@@ -78,14 +78,18 @@ def process_result_value(self, value, dialect):
7878)
7979
8080
81- @pytest .fixture (autouse = True , scope = "module " )
81+ @pytest .fixture (autouse = True , scope = "function " )
8282def create_test_database ():
8383 # Create test databases with tables creation
8484 for url in DATABASE_URLS :
8585 database_url = DatabaseURL (url )
86- if database_url .scheme == "mysql" :
86+ if database_url .scheme in [ "mysql" , "mysql+aiomysql" ] :
8787 url = str (database_url .replace (driver = "pymysql" ))
88- elif database_url .scheme == "postgresql+aiopg" :
88+ elif database_url .scheme in [
89+ "postgresql+aiopg" ,
90+ "sqlite+aiosqlite" ,
91+ "postgresql+asyncpg" ,
92+ ]:
8993 url = str (database_url .replace (driver = None ))
9094 engine = sqlalchemy .create_engine (url )
9195 metadata .create_all (engine )
@@ -96,9 +100,13 @@ def create_test_database():
96100 # Drop test databases
97101 for url in DATABASE_URLS :
98102 database_url = DatabaseURL (url )
99- if database_url .scheme == "mysql" :
103+ if database_url .scheme in [ "mysql" , "mysql+aiomysql" ] :
100104 url = str (database_url .replace (driver = "pymysql" ))
101- elif database_url .scheme == "postgresql+aiopg" :
105+ elif database_url .scheme in [
106+ "postgresql+aiopg" ,
107+ "sqlite+aiosqlite" ,
108+ "postgresql+asyncpg" ,
109+ ]:
102110 url = str (database_url .replace (driver = None ))
103111 engine = sqlalchemy .create_engine (url )
104112 metadata .drop_all (engine )
@@ -478,9 +486,12 @@ async def test_transaction_commit_serializable(database_url):
478486
479487 database_url = DatabaseURL (database_url )
480488
481- if database_url .scheme != "postgresql" :
489+ if database_url .scheme not in [ "postgresql" , "postgresql+asyncpg" ] :
482490 pytest .skip ("Test (currently) only supports asyncpg" )
483491
492+ if database_url .scheme == "postgresql+asyncpg" :
493+ database_url = database_url .replace (driver = None )
494+
484495 def insert_independently ():
485496 engine = sqlalchemy .create_engine (str (database_url ))
486497 conn = engine .connect ()
@@ -844,26 +855,34 @@ async def test_queries_with_expose_backend_connection(database_url):
844855 raw_connection = connection .raw_connection
845856
846857 # Insert query
847- if database .url .scheme in ["mysql" , "postgresql+aiopg" ]:
858+ if database .url .scheme in [
859+ "mysql" ,
860+ "mysql+aiomysql" ,
861+ "postgresql+aiopg" ,
862+ ]:
848863 insert_query = "INSERT INTO notes (text, completed) VALUES (%s, %s)"
849864 else :
850865 insert_query = "INSERT INTO notes (text, completed) VALUES ($1, $2)"
851866
852867 # execute()
853868 values = ("example1" , True )
854869
855- if database .url .scheme in ["mysql" , "postgresql+aiopg" ]:
870+ if database .url .scheme in [
871+ "mysql" ,
872+ "mysql+aiomysql" ,
873+ "postgresql+aiopg" ,
874+ ]:
856875 cursor = await raw_connection .cursor ()
857876 await cursor .execute (insert_query , values )
858- elif database .url .scheme == "postgresql" :
877+ elif database .url .scheme in [ "postgresql" , "postgresql+asyncpg" ] :
859878 await raw_connection .execute (insert_query , * values )
860- elif database .url .scheme == "sqlite" :
879+ elif database .url .scheme in [ "sqlite" , "sqlite+aiosqlite" ] :
861880 await raw_connection .execute (insert_query , values )
862881
863882 # execute_many()
864883 values = [("example2" , False ), ("example3" , True )]
865884
866- if database .url .scheme == "mysql" :
885+ if database .url .scheme in [ "mysql" , "mysql+aiomysql" ] :
867886 cursor = await raw_connection .cursor ()
868887 await cursor .executemany (insert_query , values )
869888 elif database .url .scheme == "postgresql+aiopg" :
@@ -878,13 +897,17 @@ async def test_queries_with_expose_backend_connection(database_url):
878897 select_query = "SELECT notes.id, notes.text, notes.completed FROM notes"
879898
880899 # fetch_all()
881- if database .url .scheme in ["mysql" , "postgresql+aiopg" ]:
900+ if database .url .scheme in [
901+ "mysql" ,
902+ "mysql+aiomysql" ,
903+ "postgresql+aiopg" ,
904+ ]:
882905 cursor = await raw_connection .cursor ()
883906 await cursor .execute (select_query )
884907 results = await cursor .fetchall ()
885- elif database .url .scheme == "postgresql" :
908+ elif database .url .scheme in [ "postgresql" , "postgresql+asyncpg" ] :
886909 results = await raw_connection .fetch (select_query )
887- elif database .url .scheme == "sqlite" :
910+ elif database .url .scheme in [ "sqlite" , "sqlite+aiosqlite" ] :
888911 results = await raw_connection .execute_fetchall (select_query )
889912
890913 assert len (results ) == 3
@@ -897,7 +920,7 @@ async def test_queries_with_expose_backend_connection(database_url):
897920 assert results [2 ][2 ] == True
898921
899922 # fetch_one()
900- if database .url .scheme == "postgresql" :
923+ if database .url .scheme in [ "postgresql" , "postgresql+asyncpg" ] :
901924 result = await raw_connection .fetchrow (select_query )
902925 else :
903926 cursor = await raw_connection .cursor ()
@@ -1065,8 +1088,8 @@ async def test_posgres_interface(database_url):
10651088 """
10661089 database_url = DatabaseURL (database_url )
10671090
1068- if database_url .scheme != "postgresql" :
1069- pytest .skip ("Test is only for postgresql " )
1091+ if database_url .scheme not in [ "postgresql" , "postgresql+asyncpg" ] :
1092+ pytest .skip ("Test is only for asyncpg " )
10701093
10711094 async with Database (database_url ) as database :
10721095 async with database .transaction (force_rollback = True ):
0 commit comments