Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,9 @@ def execute_sql(self, result_type):
f"{field.__class__.__name__}."
)
prepared = field.get_db_prep_save(value, connection=self.connection)
if hasattr(value, "as_mql"):
if is_direct_value(value):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still question whether the $literal escaping is overly broad. Better safe than sorry, sure, but it decreases readability a bit and makes queries larger. But I'm not sure how to compare the tradeoffs between some more complicated logic (isinstance() CPU time) and leaving it as is. It seems if we did wrapping of only the types that Value() wraps, it should be safe, unless the logic in Value() is deficient. And are there any string values besides those that start with $ that could be problematic? Perhaps to be discussed in chat tomorrow.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 I agree. Seeing $literal everywhere is annoying, but Value's resolver is probably covering the most common types used in queries. On the other hand, implementing the same escaping logic that Value uses is not a big deal.

prepared = {"$literal": prepared}
else:
prepared = prepared.as_mql(self, self.connection, as_expr=True)
values[field.column] = prepared
try:
Expand Down
6 changes: 3 additions & 3 deletions django_mongodb_backend/expressions/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ def when(self, compiler, connection):

def value(self, compiler, connection, as_expr=False): # noqa: ARG001
value = self.value
if isinstance(value, (list, int)) and as_expr:
# Wrap lists & numbers in $literal to prevent ambiguity when Value
# appears in $project.
if isinstance(value, (list, int, str, dict, tuple)) and as_expr:
# Wrap lists, numbers, strings, dict and tuple in $literal to avoid ambiguity when Value
# is used in queries' aggregate or update_many's $set.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, you replaced $project with aggregate. Is there another place in aggregate besides $project where this could be an issue?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, in place like $group or $match (for having or filter). Do we need a test for that?

return {"$literal": value}
if isinstance(value, Decimal):
return Decimal128(value)
Expand Down
15 changes: 15 additions & 0 deletions django_mongodb_backend/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,18 @@ def assertAggregateQuery(self, query, expected_collection, expected_pipeline):
eval(pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}), # noqa: S307
expected_pipeline,
)

def assertUpdateQuery(self, query, expected_collection, expected_condition, expected_values):
"""
Assert that the logged query is equal to:
db.{expected_collection}.update_many({expected_condition}, {expected_values})
"""
prefix, pipeline = query.split("(", 1)
_, collection, operator = prefix.split(".")
self.assertEqual(operator, "update_many")
self.assertEqual(collection, expected_collection)
condition, values = eval( # noqa: S307
pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}
)
self.assertEqual(condition, expected_condition)
self.assertEqual(values, expected_values)
Empty file added tests/basic_/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions tests/basic_/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from django.db import models


class Author(models.Model):
name = models.CharField(max_length=10)

def __str__(self):
return self.name


class Book(models.Model):
name = models.CharField(max_length=10)
data = models.JSONField(null=True)
95 changes: 95 additions & 0 deletions tests/basic_/tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from operator import attrgetter

from django.db.models import Value
from django.test import TestCase

from django_mongodb_backend.test import MongoTestCaseMixin

from .models import Author, Book


class SaveDollarPrefixTests(MongoTestCaseMixin, TestCase):
def test_insert_dollar_prefix(self):
"""$-prefixed values are correctly saved on insert."""
obj = Author.objects.create(name="$foobar")
obj.refresh_from_db()
self.assertEqual(obj.name, "$foobar")


class UpdateDollarPrefixTests(MongoTestCaseMixin, TestCase):
def test_update_dollar_prefix(self):
"""$-prefixed values are correctly saved on update."""
obj = Author.objects.create(name="foobar")
obj.name = "$updated"
with self.assertNumQueries(1) as ctx:
obj.save()
obj.refresh_from_db()
self.assertEqual(obj.name, "$updated")
self.assertUpdateQuery(
ctx.captured_queries[0]["sql"],
"basic__author",
{"_id": obj.id},
[{"$set": {"name": {"$literal": "$updated"}}}],
)

def test_update_dollar_prefix_in_value_expression(self):
"""$-prefixed Value() expressions are correctly handled on update."""
obj = Author.objects.create(name="foobar")
obj.name = Value("$updated")
with self.assertNumQueries(1) as ctx:
obj.save()
obj.refresh_from_db()
self.assertEqual(obj.name, "$updated")
self.assertUpdateQuery(
ctx.captured_queries[0]["sql"],
"basic__author",
{"_id": obj.id},
[{"$set": {"name": {"$literal": "$updated"}}}],
)

def test_update_dict_value(self):
"""MQL-like dict Value() expressions are correctly handled on update."""
obj = Book.objects.create(name="foobar", data={})
obj.data = Value({"$concat": ["$name", "-", "$name"]})
obj.save()
obj.refresh_from_db()
self.assertEqual(obj.data, {"$concat": ["$name", "-", "$name"]})

def test_update_dict(self):
"""MQL-like dict updates are correctly handled on update."""
obj = Book.objects.create(name="foobar")
obj.data = {"$concat": ["$name", "-", "$name"]}
obj.save()
obj.refresh_from_db()
self.assertEqual(obj.data, {"$concat": ["$name", "-", "$name"]})

def test_update_tuple(self):
"""MQL-like tuple updates are correctly handled on update."""
obj = Book.objects.create(name="foobar")
obj.data = ("$name", "-", "$name")
obj.save()
obj.refresh_from_db()
self.assertEqual(obj.data, ["$name", "-", "$name"])

def test_update_tuple_value(self):
"""MQL-like tuple Value() expressions are correctly handled on update."""
obj = Book.objects.create(name="foobar")
obj.data = Value(("$name", "-", "$name"))
obj.save()
obj.refresh_from_db()
self.assertEqual(obj.data, ["$name", "-", "$name"])


class QueryDollarPrefixTests(MongoTestCaseMixin, TestCase):
def test_query_injection(self):
"""$-prefixed Value() expressions are correctly handled on query."""
Author.objects.create(name="Gustavo")
Author.objects.create(name="Walter")
with self.assertNumQueries(1) as ctx:
qs = list(Author.objects.annotate(a_value=Value("$name")))
self.assertQuerySetEqual(qs, ["$name"] * 2, attrgetter("a_value"))
self.assertAggregateQuery(
ctx.captured_queries[0]["sql"],
"basic__author",
[{"$project": {"a_value": {"$literal": "$name"}, "_id": 1, "name": 1}}],
)
18 changes: 18 additions & 0 deletions tests/expressions_/test_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,24 @@ def test_int(self):
def test_str(self):
self.assertEqual(Value("foo").as_mql(None, None), "foo")

def test_array_expr(self):
self.assertEqual(
Value(["$foo", "$bar"]).as_mql(None, None, True), {"$literal": ["$foo", "$bar"]}
)

def test_dict_expr(self):
self.assertEqual(
Value({"$foo": "$bar"}).as_mql(None, None, True), {"$literal": {"$foo": "$bar"}}
)

def test_str_expr(self):
self.assertEqual(Value("$foo").as_mql(None, None, True), {"$literal": "$foo"})

def test_tuple_expr(self):
self.assertEqual(
Value(("$foo", "$bar")).as_mql(None, None, True), {"$literal": ("$foo", "$bar")}
)

def test_uuid(self):
value = uuid.UUID(int=1)
self.assertEqual(Value(value).as_mql(None, None), "00000000000000000000000000000001")
14 changes: 12 additions & 2 deletions tests/model_fields_/test_embedded_model_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,18 @@ def test_nested_array_index_expr(self):
},
{
"$concat": [
{"$ifNull": ["Z", ""]},
{"$ifNull": ["acarias", ""]},
{
"$ifNull": [
{"$literal": "Z"},
{"$literal": ""},
]
},
{
"$ifNull": [
{"$literal": "acarias"},
{"$literal": ""},
]
},
]
},
]
Expand Down