66from django .core .exceptions import EmptyResultSet , FullResultSet
77from django .db import NotSupportedError
88from django .db .models .expressions import (
9+ BaseExpression ,
910 Case ,
1011 Col ,
1112 ColPairs ,
1213 CombinedExpression ,
1314 Exists ,
1415 ExpressionList ,
1516 ExpressionWrapper ,
17+ Func ,
1618 NegatedExpression ,
1719 OrderBy ,
1820 RawSQL ,
2325 Value ,
2426 When ,
2527)
28+ from django .db .models .fields .json import KeyTransform
2629from django .db .models .sql import Query
2730
28- from .. query_utils import process_lhs
31+ from django_mongodb_backend . fields . array import Array
2932
33+ from ..query_utils import is_direct_value , process_lhs
3034
31- def case (self , compiler , connection ):
35+
36+ def case (self , compiler , connection , as_path = False ):
3237 case_parts = []
3338 for case in self .cases :
3439 case_mql = {}
3540 try :
36- case_mql ["case" ] = case .as_mql (compiler , connection )
41+ case_mql ["case" ] = case .as_mql (compiler , connection , as_path = False )
3742 except EmptyResultSet :
3843 continue
3944 except FullResultSet :
@@ -45,12 +50,16 @@ def case(self, compiler, connection):
4550 default_mql = self .default .as_mql (compiler , connection )
4651 if not case_parts :
4752 return default_mql
48- return {
53+ expr = {
4954 "$switch" : {
5055 "branches" : case_parts ,
5156 "default" : default_mql ,
5257 }
5358 }
59+ if as_path :
60+ return {"$expr" : expr }
61+
62+ return expr
5463
5564
5665def col (self , compiler , connection , as_path = False ): # noqa: ARG001
@@ -76,34 +85,34 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001
7685 return f"{ prefix } { self .target .column } "
7786
7887
79- def col_pairs (self , compiler , connection ):
88+ def col_pairs (self , compiler , connection , as_path = False ):
8089 cols = self .get_cols ()
8190 if len (cols ) > 1 :
8291 raise NotSupportedError ("ColPairs is not supported." )
83- return cols [0 ].as_mql (compiler , connection )
92+ return cols [0 ].as_mql (compiler , connection , as_path = as_path )
8493
8594
86- def combined_expression (self , compiler , connection ):
95+ def combined_expression (self , compiler , connection , as_path = False ):
8796 expressions = [
88- self .lhs .as_mql (compiler , connection ),
89- self .rhs .as_mql (compiler , connection ),
97+ self .lhs .as_mql (compiler , connection , as_path = as_path ),
98+ self .rhs .as_mql (compiler , connection , as_path = as_path ),
9099 ]
91100 return connection .ops .combine_expression (self .connector , expressions )
92101
93102
94- def expression_wrapper (self , compiler , connection ):
95- return self .expression .as_mql (compiler , connection )
103+ def expression_wrapper (self , compiler , connection , as_path = False ):
104+ return self .expression .as_mql (compiler , connection , as_path = as_path )
96105
97106
98- def negated_expression (self , compiler , connection ):
99- return {"$not" : expression_wrapper (self , compiler , connection )}
107+ def negated_expression (self , compiler , connection , as_path = False ):
108+ return {"$not" : expression_wrapper (self , compiler , connection , as_path = as_path )}
100109
101110
102111def order_by (self , compiler , connection ):
103112 return self .expression .as_mql (compiler , connection )
104113
105114
106- def query (self , compiler , connection , get_wrapping_pipeline = None ):
115+ def query (self , compiler , connection , get_wrapping_pipeline = None , as_path = False ):
107116 subquery_compiler = self .get_compiler (connection = connection )
108117 subquery_compiler .pre_sql_setup (with_col_aliases = False )
109118 field_name , expr = subquery_compiler .columns [0 ]
@@ -145,14 +154,16 @@ def query(self, compiler, connection, get_wrapping_pipeline=None):
145154 # Erase project_fields since the required value is projected above.
146155 subquery .project_fields = None
147156 compiler .subqueries .append (subquery )
157+ if as_path :
158+ return f"{ table_output } .{ field_name } "
148159 return f"${ table_output } .{ field_name } "
149160
150161
151162def raw_sql (self , compiler , connection ): # noqa: ARG001
152163 raise NotSupportedError ("RawSQL is not supported on MongoDB." )
153164
154165
155- def ref (self , compiler , connection ): # noqa: ARG001
166+ def ref (self , compiler , connection , as_path = False ): # noqa: ARG001
156167 prefix = (
157168 f"{ self .source .alias } ."
158169 if isinstance (self .source , Col ) and self .source .alias != compiler .collection_name
@@ -162,32 +173,47 @@ def ref(self, compiler, connection): # noqa: ARG001
162173 refs , _ = compiler .columns [self .ordinal - 1 ]
163174 else :
164175 refs = self .refs
165- return f"${ prefix } { refs } "
176+ if not as_path :
177+ prefix = f"${ prefix } "
178+ return f"{ prefix } { refs } "
166179
167180
168- def star (self , compiler , connection ): # noqa: ARG001
181+ def star (self , compiler , connection , ** extra ): # noqa: ARG001
169182 return {"$literal" : True }
170183
171184
172- def subquery (self , compiler , connection , get_wrapping_pipeline = None ):
173- return self .query .as_mql (compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline )
185+ def subquery (self , compiler , connection , get_wrapping_pipeline = None , as_path = False ):
186+ expr = self .query .as_mql (
187+ compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline , as_path = False
188+ )
189+ if as_path :
190+ return {"$expr" : expr }
191+ return expr
174192
175193
176- def exists (self , compiler , connection , get_wrapping_pipeline = None ):
194+ def exists (self , compiler , connection , get_wrapping_pipeline = None , as_path = False ):
177195 try :
178- lhs_mql = subquery (self , compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline )
196+ lhs_mql = subquery (
197+ self ,
198+ compiler ,
199+ connection ,
200+ get_wrapping_pipeline = get_wrapping_pipeline ,
201+ as_path = as_path ,
202+ )
179203 except EmptyResultSet :
180204 return Value (False ).as_mql (compiler , connection )
181- return connection .mongo_operators ["isnull" ](lhs_mql , False )
205+ if as_path :
206+ return {"$expr" : connection .mongo_operators_match ["isnull" ](lhs_mql , False )}
207+ return connection .mongo_operators_expr ["isnull" ](lhs_mql , False )
182208
183209
184- def when (self , compiler , connection ):
185- return self .condition .as_mql (compiler , connection )
210+ def when (self , compiler , connection , as_path = False ):
211+ return self .condition .as_mql (compiler , connection , as_path = as_path )
186212
187213
188- def value (self , compiler , connection ): # noqa: ARG001
214+ def value (self , compiler , connection , as_path = False ): # noqa: ARG001
189215 value = self .value
190- if isinstance (value , (list , int )):
216+ if isinstance (value , (list , int )) and not as_path :
191217 # Wrap lists & numbers in $literal to prevent ambiguity when Value
192218 # appears in $project.
193219 return {"$literal" : value }
@@ -209,6 +235,36 @@ def value(self, compiler, connection): # noqa: ARG001
209235 return value
210236
211237
238+ @staticmethod
239+ def _is_constant_value (value ):
240+ if isinstance (value , list | Array ):
241+ iterable = value .get_source_expressions () if isinstance (value , Array ) else value
242+ return all (_is_constant_value (e ) for e in iterable )
243+ if is_direct_value (value ):
244+ return True
245+ return isinstance (value , Func | Value ) and not (
246+ value .contains_aggregate
247+ or value .contains_over_clause
248+ or value .contains_column_references
249+ or value .contains_subquery
250+ )
251+
252+
253+ @staticmethod
254+ def _is_simple_column (lhs ):
255+ while isinstance (lhs , KeyTransform ):
256+ if "." in getattr (lhs , "key_name" , "" ):
257+ return False
258+ lhs = lhs .lhs
259+ col = lhs .source if isinstance (lhs , Ref ) else lhs
260+ # Foreign columns from parent cannot be addressed as single match
261+ return isinstance (col , Col ) and col .alias is not None
262+
263+
264+ def _is_simple_expression (self ):
265+ return self .is_simple_column (self .lhs ) and self .is_constant_value (self .rhs )
266+
267+
212268def register_expressions ():
213269 Case .as_mql = case
214270 Col .as_mql = col
@@ -227,3 +283,6 @@ def register_expressions():
227283 Subquery .as_mql = subquery
228284 When .as_mql = when
229285 Value .as_mql = value
286+ BaseExpression .is_simple_expression = _is_simple_expression
287+ BaseExpression .is_simple_column = _is_simple_column
288+ BaseExpression .is_constant_value = _is_constant_value
0 commit comments