Skip to content

Commit 155c639

Browse files
committed
parse db name out of connection string
1 parent ac62bcb commit 155c639

File tree

2 files changed

+50
-4
lines changed

2 files changed

+50
-4
lines changed

django_mongodb_backend/base.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pymongo.collection import Collection
1212
from pymongo.driver_info import DriverInfo
1313
from pymongo.mongo_client import MongoClient
14+
from pymongo.uri_parser import parse_uri
1415

1516
from . import __version__ as django_mongodb_backend_version
1617
from . import dbapi as Database
@@ -157,6 +158,18 @@ def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
157158
self.in_atomic_block_mongo = False
158159
# Current number of nested 'atomic' calls.
159160
self.nested_atomics = 0
161+
# If "NAME" isn't specified, try to get the database name from HOST,
162+
# if it's a connection string.
163+
if self.settings_dict["NAME"] == "": # Empty string = unspecified; None = _nodb_cursor()
164+
name_is_missing = True
165+
host = self.settings_dict["HOST"]
166+
if host.startswith(("mongodb://", "mongodb+srv://")):
167+
uri = parse_uri(host)
168+
if database := uri.get("database"):
169+
self.settings_dict["NAME"] = database
170+
name_is_missing = False
171+
if name_is_missing:
172+
raise ImproperlyConfigured('settings.DATABASES is missing the "NAME" value.')
160173

161174
def get_collection(self, name, **kwargs):
162175
collection = Collection(self.database, name, **kwargs)
@@ -183,8 +196,6 @@ def init_connection_state(self):
183196

184197
def get_connection_params(self):
185198
settings_dict = self.settings_dict
186-
if not settings_dict["NAME"]:
187-
raise ImproperlyConfigured('settings.DATABASES is missing the "NAME" value.')
188199
params = {
189200
"host": settings_dict["HOST"] or None,
190201
**settings_dict["OPTIONS"],

tests/backend_/test_base.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from unittest.mock import patch
2+
13
from django.core.exceptions import ImproperlyConfigured
24
from django.db import connection
35
from django.db.backends.signals import connection_created
@@ -6,14 +8,47 @@
68
from django_mongodb_backend.base import DatabaseWrapper
79

810

9-
class GetConnectionParamsTests(SimpleTestCase):
11+
class DatabaseWrapperTests(SimpleTestCase):
1012
def test_database_name_empty(self):
1113
settings = connection.settings_dict.copy()
1214
settings["NAME"] = ""
1315
msg = 'settings.DATABASES is missing the "NAME" value.'
1416
with self.assertRaisesMessage(ImproperlyConfigured, msg):
15-
DatabaseWrapper(settings).get_connection_params()
17+
DatabaseWrapper(settings)
18+
19+
def test_database_name_empty_and_host_does_not_contain_database(self):
20+
settings = connection.settings_dict.copy()
21+
settings["NAME"] = ""
22+
settings["HOST"] = "mongodb://localhost"
23+
msg = 'settings.DATABASES is missing the "NAME" value.'
24+
with self.assertRaisesMessage(ImproperlyConfigured, msg):
25+
DatabaseWrapper(settings)
26+
27+
def test_database_name_parsed_from_host(self):
28+
settings = connection.settings_dict.copy()
29+
settings["NAME"] = ""
30+
settings["HOST"] = "mongodb://localhost/db"
31+
self.assertEqual(DatabaseWrapper(settings).settings_dict["NAME"], "db")
1632

33+
def test_database_name_parsed_from_srv_host(self):
34+
settings = connection.settings_dict.copy()
35+
settings["NAME"] = ""
36+
settings["HOST"] = "mongodb+srv://localhost/db"
37+
# patch() prevents a crash when PyMongo attempts to resolve the
38+
# nonexistent SRV record.
39+
with patch("dns.resolver.resolve"):
40+
self.assertEqual(DatabaseWrapper(settings).settings_dict["NAME"], "db")
41+
42+
def test_database_name_not_overridden_by_host(self):
43+
settings = connection.settings_dict.copy()
44+
settings["NAME"] = "should not be overridden"
45+
settings["HOST"] = "mongodb://localhost/db"
46+
self.assertEqual(
47+
DatabaseWrapper(settings).settings_dict["NAME"], "should not be overridden"
48+
)
49+
50+
51+
class GetConnectionParamsTests(SimpleTestCase):
1752
def test_host(self):
1853
settings = connection.settings_dict.copy()
1954
settings["HOST"] = "host"

0 commit comments

Comments
 (0)