diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 888362c6b..cb867221e 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -883,7 +883,9 @@ def execute_sql(self, result_type): f"{field.__class__.__name__}." ) prepared = field.get_db_prep_save(value, connection=self.connection) - if hasattr(value, "as_mql"): + if is_direct_value(value): + prepared = {"$literal": prepared} + else: prepared = prepared.as_mql(self, self.connection, as_expr=True) values[field.column] = prepared try: diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 60400d940..09d106b8d 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -211,9 +211,9 @@ def when(self, compiler, connection): def value(self, compiler, connection, as_expr=False): # noqa: ARG001 value = self.value - if isinstance(value, (list, int)) and as_expr: - # Wrap lists & numbers in $literal to prevent ambiguity when Value - # appears in $project. + if isinstance(value, (list, int, str, dict, tuple)) and as_expr: + # Wrap lists, numbers, strings, dict and tuple in $literal to avoid ambiguity when Value + # is used in queries' aggregate or update_many's $set. return {"$literal": value} if isinstance(value, Decimal): return Decimal128(value) diff --git a/django_mongodb_backend/test.py b/django_mongodb_backend/test.py index ee35b4e21..c1d3cffdc 100644 --- a/django_mongodb_backend/test.py +++ b/django_mongodb_backend/test.py @@ -19,3 +19,18 @@ def assertAggregateQuery(self, query, expected_collection, expected_pipeline): eval(pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}), # noqa: S307 expected_pipeline, ) + + def assertUpdateQuery(self, query, expected_collection, expected_condition, expected_values): + """ + Assert that the logged query is equal to: + db.{expected_collection}.update_many({expected_condition}, {expected_values}) + """ + prefix, pipeline = query.split("(", 1) + _, collection, operator = prefix.split(".") + self.assertEqual(operator, "update_many") + self.assertEqual(collection, expected_collection) + condition, values = eval( # noqa: S307 + pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {} + ) + self.assertEqual(condition, expected_condition) + self.assertEqual(values, expected_values) diff --git a/tests/basic_/__init__.py b/tests/basic_/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/basic_/models.py b/tests/basic_/models.py new file mode 100644 index 000000000..bd5f5241e --- /dev/null +++ b/tests/basic_/models.py @@ -0,0 +1,13 @@ +from django.db import models + + +class Author(models.Model): + name = models.CharField(max_length=10) + + def __str__(self): + return self.name + + +class Book(models.Model): + name = models.CharField(max_length=10) + data = models.JSONField(null=True) diff --git a/tests/basic_/tests.py b/tests/basic_/tests.py new file mode 100644 index 000000000..5fcbeca9d --- /dev/null +++ b/tests/basic_/tests.py @@ -0,0 +1,95 @@ +from operator import attrgetter + +from django.db.models import Value +from django.test import TestCase + +from django_mongodb_backend.test import MongoTestCaseMixin + +from .models import Author, Book + + +class SaveDollarPrefixTests(MongoTestCaseMixin, TestCase): + def test_insert_dollar_prefix(self): + """$-prefixed values are correctly saved on insert.""" + obj = Author.objects.create(name="$foobar") + obj.refresh_from_db() + self.assertEqual(obj.name, "$foobar") + + +class UpdateDollarPrefixTests(MongoTestCaseMixin, TestCase): + def test_update_dollar_prefix(self): + """$-prefixed values are correctly saved on update.""" + obj = Author.objects.create(name="foobar") + obj.name = "$updated" + with self.assertNumQueries(1) as ctx: + obj.save() + obj.refresh_from_db() + self.assertEqual(obj.name, "$updated") + self.assertUpdateQuery( + ctx.captured_queries[0]["sql"], + "basic__author", + {"_id": obj.id}, + [{"$set": {"name": {"$literal": "$updated"}}}], + ) + + def test_update_dollar_prefix_in_value_expression(self): + """$-prefixed Value() expressions are correctly handled on update.""" + obj = Author.objects.create(name="foobar") + obj.name = Value("$updated") + with self.assertNumQueries(1) as ctx: + obj.save() + obj.refresh_from_db() + self.assertEqual(obj.name, "$updated") + self.assertUpdateQuery( + ctx.captured_queries[0]["sql"], + "basic__author", + {"_id": obj.id}, + [{"$set": {"name": {"$literal": "$updated"}}}], + ) + + def test_update_dict_value(self): + """MQL-like dict Value() expressions are correctly handled on update.""" + obj = Book.objects.create(name="foobar", data={}) + obj.data = Value({"$concat": ["$name", "-", "$name"]}) + obj.save() + obj.refresh_from_db() + self.assertEqual(obj.data, {"$concat": ["$name", "-", "$name"]}) + + def test_update_dict(self): + """MQL-like dict updates are correctly handled on update.""" + obj = Book.objects.create(name="foobar") + obj.data = {"$concat": ["$name", "-", "$name"]} + obj.save() + obj.refresh_from_db() + self.assertEqual(obj.data, {"$concat": ["$name", "-", "$name"]}) + + def test_update_tuple(self): + """MQL-like tuple updates are correctly handled on update.""" + obj = Book.objects.create(name="foobar") + obj.data = ("$name", "-", "$name") + obj.save() + obj.refresh_from_db() + self.assertEqual(obj.data, ["$name", "-", "$name"]) + + def test_update_tuple_value(self): + """MQL-like tuple Value() expressions are correctly handled on update.""" + obj = Book.objects.create(name="foobar") + obj.data = Value(("$name", "-", "$name")) + obj.save() + obj.refresh_from_db() + self.assertEqual(obj.data, ["$name", "-", "$name"]) + + +class QueryDollarPrefixTests(MongoTestCaseMixin, TestCase): + def test_query_injection(self): + """$-prefixed Value() expressions are correctly handled on query.""" + Author.objects.create(name="Gustavo") + Author.objects.create(name="Walter") + with self.assertNumQueries(1) as ctx: + qs = list(Author.objects.annotate(a_value=Value("$name"))) + self.assertQuerySetEqual(qs, ["$name"] * 2, attrgetter("a_value")) + self.assertAggregateQuery( + ctx.captured_queries[0]["sql"], + "basic__author", + [{"$project": {"a_value": {"$literal": "$name"}, "_id": 1, "name": 1}}], + ) diff --git a/tests/expressions_/test_value.py b/tests/expressions_/test_value.py index 4adbc1d3e..a96be0c2d 100644 --- a/tests/expressions_/test_value.py +++ b/tests/expressions_/test_value.py @@ -41,6 +41,24 @@ def test_int(self): def test_str(self): self.assertEqual(Value("foo").as_mql(None, None), "foo") + def test_array_expr(self): + self.assertEqual( + Value(["$foo", "$bar"]).as_mql(None, None, True), {"$literal": ["$foo", "$bar"]} + ) + + def test_dict_expr(self): + self.assertEqual( + Value({"$foo": "$bar"}).as_mql(None, None, True), {"$literal": {"$foo": "$bar"}} + ) + + def test_str_expr(self): + self.assertEqual(Value("$foo").as_mql(None, None, True), {"$literal": "$foo"}) + + def test_tuple_expr(self): + self.assertEqual( + Value(("$foo", "$bar")).as_mql(None, None, True), {"$literal": ("$foo", "$bar")} + ) + def test_uuid(self): value = uuid.UUID(int=1) self.assertEqual(Value(value).as_mql(None, None), "00000000000000000000000000000001") diff --git a/tests/model_fields_/test_embedded_model_array.py b/tests/model_fields_/test_embedded_model_array.py index 7b3d72ec8..8453f6379 100644 --- a/tests/model_fields_/test_embedded_model_array.py +++ b/tests/model_fields_/test_embedded_model_array.py @@ -327,8 +327,18 @@ def test_nested_array_index_expr(self): }, { "$concat": [ - {"$ifNull": ["Z", ""]}, - {"$ifNull": ["acarias", ""]}, + { + "$ifNull": [ + {"$literal": "Z"}, + {"$literal": ""}, + ] + }, + { + "$ifNull": [ + {"$literal": "acarias"}, + {"$literal": ""}, + ] + }, ] }, ]