Skip to content

Commit 68e4632

Browse files
committed
Fix Model.save() updates and QuerySet.update() for $-prefixed strings
1 parent 349dc1d commit 68e4632

File tree

5 files changed

+64
-1
lines changed

5 files changed

+64
-1
lines changed

django_mongodb_backend/compiler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,9 @@ def execute_sql(self, result_type):
883883
f"{field.__class__.__name__}."
884884
)
885885
prepared = field.get_db_prep_save(value, connection=self.connection)
886-
if hasattr(value, "as_mql"):
886+
if is_direct_value(value):
887+
prepared = {"$literal": prepared}
888+
else:
887889
prepared = prepared.as_mql(self, self.connection, as_expr=True)
888890
values[field.column] = prepared
889891
try:

django_mongodb_backend/test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,18 @@ def assertAggregateQuery(self, query, expected_collection, expected_pipeline):
1919
eval(pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}), # noqa: S307
2020
expected_pipeline,
2121
)
22+
23+
def assertUpdateQuery(self, query, expected_collection, expected_condition, expected_values):
24+
"""
25+
Assert that the logged query is equal to:
26+
db.{expected_collection}.update_many({expected_condition}, {expected_values})
27+
"""
28+
prefix, pipeline = query.split("(", 1)
29+
_, collection, operator = prefix.split(".")
30+
self.assertEqual(operator, "update_many")
31+
self.assertEqual(collection, expected_collection)
32+
condition, values = eval( # noqa: S307
33+
pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}
34+
)
35+
self.assertEqual(condition, expected_condition)
36+
self.assertEqual(values, expected_values)

tests/basic_/__init__.py

Whitespace-only changes.

tests/basic_/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from django.db import models
2+
3+
4+
class Author(models.Model):
5+
name = models.CharField(max_length=10)
6+
7+
def __str__(self):
8+
return self.name

tests/basic_/tests.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from django.db.models import Value
2+
from django.test import TestCase
3+
4+
from django_mongodb_backend.test import MongoTestCaseMixin
5+
6+
from .models import Author
7+
8+
9+
class SaveUpdateDollarPrefixTests(MongoTestCaseMixin, TestCase):
10+
def test_update_dollar_prefix(self):
11+
"""$-prefixed values are correctly saved on update."""
12+
obj = Author.objects.create(name="foobar")
13+
obj.name = "$updated"
14+
with self.assertNumQueries(1) as ctx:
15+
obj.save()
16+
obj.refresh_from_db()
17+
self.assertEqual(obj.name, "$updated")
18+
self.assertUpdateQuery(
19+
ctx.captured_queries[0]["sql"],
20+
"basic__author",
21+
{"_id": obj.id},
22+
[{"$set": {"name": {"$literal": "$updated"}}}],
23+
)
24+
25+
def test_update_dollar_prefix_in_value_expression(self):
26+
"""$-prefixed Value() expressions are correctly handled on update."""
27+
obj = Author.objects.create(name="foobar")
28+
obj.name = Value("$updated")
29+
with self.assertNumQueries(1) as ctx:
30+
obj.save()
31+
obj.refresh_from_db()
32+
self.assertEqual(obj.name, "$updated")
33+
self.assertUpdateQuery(
34+
ctx.captured_queries[0]["sql"],
35+
"basic__author",
36+
{"_id": obj.id},
37+
[{"$set": {"name": {"$literal": "$updated"}}}],
38+
)

0 commit comments

Comments
 (0)