Skip to content

Commit 9338093

Browse files
committed
Refactor.
1 parent 64d4a03 commit 9338093

File tree

5 files changed

+52
-56
lines changed

5 files changed

+52
-56
lines changed

django_mongodb_backend/expressions/builtins.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
from django.core.exceptions import EmptyResultSet, FullResultSet
77
from django.db import NotSupportedError
88
from django.db.models.expressions import (
9+
BaseExpression,
910
Case,
1011
Col,
1112
ColPairs,
1213
CombinedExpression,
1314
Exists,
1415
ExpressionList,
1516
ExpressionWrapper,
17+
Func,
1618
NegatedExpression,
1719
OrderBy,
1820
RawSQL,
@@ -23,9 +25,12 @@
2325
Value,
2426
When,
2527
)
28+
from django.db.models.fields.json import KeyTransform
2629
from django.db.models.sql import Query
2730

28-
from ..query_utils import process_lhs
31+
from django_mongodb_backend.fields.array import Array
32+
33+
from ..query_utils import is_direct_value, process_lhs
2934

3035

3136
# EXTRA IS TOTALLY IGNORED
@@ -234,6 +239,36 @@ def value(self, compiler, connection, as_path=False): # noqa: ARG001
234239
return value
235240

236241

242+
@staticmethod
243+
def _is_constant_value(value):
244+
if isinstance(value, Array):
245+
return all(_is_constant_value(e) for e in value.get_source_expressions())
246+
if isinstance(value, Value) or is_direct_value(value):
247+
v = value.value if isinstance(value, Value) else value
248+
return not isinstance(v, str) or "." not in v
249+
return isinstance(value, Func | Value) and not (
250+
value.contains_aggregate
251+
or value.contains_over_clause
252+
or value.contains_column_references
253+
or value.contains_subquery
254+
)
255+
256+
257+
@staticmethod
258+
def _is_simple_column(lhs):
259+
while isinstance(lhs, KeyTransform):
260+
if "." in getattr(lhs, "key_name", ""):
261+
return False
262+
lhs = lhs.lhs
263+
col = lhs.source if isinstance(lhs, Ref) else lhs
264+
# Foreign columns from parent cannot be addressed as single match
265+
return isinstance(col, Col) and col.alias is not None
266+
267+
268+
def is_simple_expression(self):
269+
return self.is_simple_column(self.lhs) and self.is_constant_value(self.rhs)
270+
271+
237272
def register_expressions():
238273
Case.as_mql = case
239274
Col.as_mql = col
@@ -252,3 +287,6 @@ def register_expressions():
252287
Subquery.as_mql = subquery
253288
When.as_mql = when
254289
Value.as_mql = value
290+
BaseExpression.is_simple_expression = is_simple_expression
291+
BaseExpression.is_simple_column = _is_simple_column
292+
BaseExpression.is_constant_value = _is_constant_value

django_mongodb_backend/fields/array.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from django.utils.translation import gettext_lazy as _
88

99
from ..forms import SimpleArrayField
10-
from ..lookups import is_constant_value, is_simple_column
1110
from ..query_utils import process_lhs, process_rhs
1211
from ..utils import prefix_validation_error
1312
from ..validators import ArrayMaxLengthValidator, LengthValidator
@@ -256,7 +255,7 @@ class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
256255
lookup_name = "contains"
257256

258257
def as_mql(self, compiler, connection, as_path=False):
259-
if as_path and is_simple_column(self.lhs) and is_constant_value(self.rhs):
258+
if as_path and self.is_simple_expression():
260259
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
261260
value = process_rhs(self, compiler, connection, as_path=as_path)
262261
if value is None:
@@ -346,7 +345,7 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr)
346345
]
347346

348347
def as_mql(self, compiler, connection, as_path=False):
349-
if as_path and is_simple_column(self.lhs) and is_constant_value(self.rhs):
348+
if as_path and self.is_simple_expression():
350349
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
351350
value = process_rhs(self, compiler, connection, as_path=True)
352351
return {lhs_mql: {"$in": value}}

django_mongodb_backend/fields/json.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from itertools import chain
22

33
from django.db import NotSupportedError
4-
from django.db.models.expressions import Value
54
from django.db.models.fields.json import (
65
ContainedBy,
76
DataContains,
@@ -17,7 +16,7 @@
1716
KeyTransformNumericLookupMixin,
1817
)
1918

20-
from ..lookups import builtin_lookup, is_constant_value, is_simple_column
19+
from ..lookups import builtin_lookup
2120
from ..query_utils import process_lhs, process_rhs
2221

2322

@@ -72,23 +71,13 @@ def _has_key_predicate(path, root_column=None, negated=False, as_path=False):
7271
return result
7372

7473

75-
def has_key_lookup_rhs_check(rhs):
76-
for key in rhs:
77-
if not is_constant_value(key):
78-
return False
79-
value = key.value if isinstance(key, Value) else key
80-
if isinstance(value, str) and "." in value:
81-
return False
82-
return True
83-
84-
8574
def has_key_lookup(self, compiler, connection, as_path=False):
8675
"""Return MQL to check for the existence of a key."""
8776
rhs = self.rhs
8877
lhs = process_lhs(self, compiler, connection)
8978
if not isinstance(rhs, (list, tuple)):
9079
rhs = [rhs]
91-
as_path = as_path and is_simple_column(self.lhs) and has_key_lookup_rhs_check(rhs)
80+
as_path = as_path and self.is_simple_expression()
9281
paths = []
9382
# Transform any "raw" keys into KeyTransforms to allow consistent handling
9483
# in the code that follows.
@@ -132,7 +121,7 @@ def key_transform(self, compiler, connection, as_path=False):
132121
while isinstance(previous, KeyTransform):
133122
key_transforms.insert(0, previous.key_name)
134123
previous = previous.lhs
135-
if as_path and is_simple_column(self.lhs):
124+
if as_path and self.is_simple_column(self.lhs):
136125
lhs_mql = previous.as_mql(compiler, connection, as_path=True)
137126
return build_json_mql_path(lhs_mql, key_transforms, as_path=True)
138127
# Collect all key transforms in order.
@@ -147,7 +136,7 @@ def key_transform_in(self, compiler, connection, as_path=False):
147136
Return MQL to check if a JSON path exists and that its values are in the
148137
set of specified values (rhs).
149138
"""
150-
if as_path and is_simple_column(self.lhs) and is_constant_value(self.rhs):
139+
if as_path and self.is_simple_expression():
151140
return builtin_lookup(self, compiler, connection, as_path=True)
152141

153142
lhs_mql = process_lhs(self, compiler, connection)
@@ -175,7 +164,7 @@ def key_transform_is_null(self, compiler, connection, as_path=False):
175164
176165
Reference: https://code.djangoproject.com/ticket/32252
177166
"""
178-
if as_path and is_simple_column(self.lhs) and is_constant_value(self.rhs):
167+
if as_path and self.is_simple_expression():
179168
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
180169
rhs_mql = process_rhs(self, compiler, connection)
181170
return _has_key_predicate(lhs_mql, None, negated=rhs_mql, as_path=True)
@@ -195,7 +184,7 @@ def key_transform_numeric_lookup_mixin(self, compiler, connection, as_path=False
195184
Return MQL to check if the field exists (i.e., is not "missing" or "null")
196185
and that the field matches the given numeric lookup expression.
197186
"""
198-
if is_simple_column(self.lhs) and is_constant_value(self.rhs) and as_path:
187+
if as_path and self.is_simple_expression():
199188
return builtin_lookup(self, compiler, connection, as_path=True)
200189

201190
lhs = process_lhs(self, compiler, connection, as_path=False)
@@ -209,7 +198,7 @@ def key_transform_numeric_lookup_mixin(self, compiler, connection, as_path=False
209198

210199

211200
def key_transform_exact(self, compiler, connection, as_path=False):
212-
if is_simple_column(self.lhs) and is_constant_value(self.rhs) and as_path:
201+
if as_path and self.is_simple_expression():
213202
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
214203
return {
215204
"$and": [

django_mongodb_backend/functions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
Upper,
4040
)
4141

42-
from .lookups import is_constant_value
4342
from .query_utils import process_lhs
4443

4544
MONGO_OPERATORS = {
@@ -167,7 +166,7 @@ def preserve_null(operator):
167166
# If the argument is null, the function should return null, not
168167
# $toLower/Upper's behavior of returning an empty string.
169168
def wrapped(self, compiler, connection, as_path=False):
170-
if is_constant_value(self.lhs) and as_path:
169+
if as_path and self.is_constant_value(self.lhs):
171170
if self.lhs is None:
172171
return None
173172
lhs_mql = process_lhs(self, compiler, connection, as_path=True)

django_mongodb_backend/lookups.py

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
from django.db import NotSupportedError
2-
from django.db.models.expressions import Col, Func, Ref, Value
3-
from django.db.models.fields.json import KeyTransform
42
from django.db.models.fields.related_lookups import In, RelatedIn
53
from django.db.models.lookups import (
64
BuiltinLookup,
@@ -10,38 +8,11 @@
108
UUIDTextMixin,
119
)
1210

13-
from .query_utils import is_direct_value, process_lhs, process_rhs
14-
15-
16-
def is_constant_value(value):
17-
from django_mongodb_backend.fields.array import Array # noqa: PLC0415
18-
19-
if isinstance(value, Array):
20-
return all(is_constant_value(e) for e in value.get_source_expressions())
21-
22-
return is_direct_value(value) or (
23-
isinstance(value, Func | Value)
24-
and not (
25-
value.contains_aggregate
26-
or value.contains_over_clause
27-
or value.contains_column_references
28-
or value.contains_subquery
29-
)
30-
)
31-
32-
33-
def is_simple_column(lhs):
34-
while isinstance(lhs, KeyTransform):
35-
if "." in lhs.key_name:
36-
return False
37-
lhs = lhs.lhs
38-
col = lhs.source if isinstance(lhs, Ref) else lhs
39-
# Foreign columns from parent cannot be addressed as single match
40-
return isinstance(col, Col) and col.alias is not None
11+
from .query_utils import process_lhs, process_rhs
4112

4213

4314
def builtin_lookup(self, compiler, connection, as_path=False):
44-
if is_simple_column(self.lhs) and is_constant_value(self.rhs) and as_path:
15+
if as_path and self.is_simple_expression():
4516
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
4617
value = process_rhs(self, compiler, connection, as_path=True)
4718
return connection.mongo_operators_match[self.lookup_name](lhs_mql, value)
@@ -114,7 +85,7 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr)
11485
def is_null(self, compiler, connection, as_path=False):
11586
if not isinstance(self.rhs, bool):
11687
raise ValueError("The QuerySet value for an isnull lookup must be True or False.")
117-
if is_constant_value(self.rhs) and as_path and is_simple_column(self.lhs):
88+
if as_path and self.is_simple_expression():
11889
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
11990
return connection.mongo_operators_match["isnull"](lhs_mql, self.rhs)
12091
lhs_mql = process_lhs(self, compiler, connection, as_path=False)

0 commit comments

Comments
 (0)