1717from django .utils .functional import cached_property
1818from pymongo import ASCENDING , DESCENDING
1919
20+ from .functions import SearchScore
2021from .query import MongoQuery , wrap_database_errors
2122
2223
@@ -34,6 +35,8 @@ def __init__(self, *args, **kwargs):
3435 # A list of OrderBy objects for this query.
3536 self .order_by_objs = None
3637 self .subqueries = []
38+ # Atlas search calls
39+ self .search_pipeline = []
3740
3841 def _get_group_alias_column (self , expr , annotation_group_idx ):
3942 """Generate a dummy field for use in the ids fields in $group."""
@@ -57,6 +60,29 @@ def _get_column_from_expression(self, expr, alias):
5760 column_target .set_attributes_from_name (alias )
5861 return Col (self .collection_name , column_target )
5962
63+ def _get_replace_expr (self , sub_expr , group , alias ):
64+ column_target = sub_expr .output_field .clone ()
65+ column_target .db_column = alias
66+ column_target .set_attributes_from_name (alias )
67+ inner_column = Col (self .collection_name , column_target )
68+ if getattr (sub_expr , "distinct" , False ):
69+ # If the expression should return distinct values, use
70+ # $addToSet to deduplicate.
71+ rhs = sub_expr .as_mql (self , self .connection , resolve_inner_expression = True )
72+ group [alias ] = {"$addToSet" : rhs }
73+ replacing_expr = sub_expr .copy ()
74+ replacing_expr .set_source_expressions ([inner_column , None ])
75+ else :
76+ group [alias ] = sub_expr .as_mql (self , self .connection )
77+ replacing_expr = inner_column
78+ # Count must return 0 rather than null.
79+ if isinstance (sub_expr , Count ):
80+ replacing_expr = Coalesce (replacing_expr , 0 )
81+ # Variance = StdDev^2
82+ if isinstance (sub_expr , Variance ):
83+ replacing_expr = Power (replacing_expr , 2 )
84+ return replacing_expr
85+
6086 def _prepare_expressions_for_pipeline (self , expression , target , annotation_group_idx ):
6187 """
6288 Prepare expressions for the aggregation pipeline.
@@ -80,29 +106,42 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
80106 alias = (
81107 f"__aggregation{ next (annotation_group_idx )} " if sub_expr != expression else target
82108 )
83- column_target = sub_expr .output_field .clone ()
84- column_target .db_column = alias
85- column_target .set_attributes_from_name (alias )
86- inner_column = Col (self .collection_name , column_target )
87- if sub_expr .distinct :
88- # If the expression should return distinct values, use
89- # $addToSet to deduplicate.
90- rhs = sub_expr .as_mql (self , self .connection , resolve_inner_expression = True )
91- group [alias ] = {"$addToSet" : rhs }
92- replacing_expr = sub_expr .copy ()
93- replacing_expr .set_source_expressions ([inner_column , None ])
94- else :
95- group [alias ] = sub_expr .as_mql (self , self .connection )
96- replacing_expr = inner_column
97- # Count must return 0 rather than null.
98- if isinstance (sub_expr , Count ):
99- replacing_expr = Coalesce (replacing_expr , 0 )
100- # Variance = StdDev^2
101- if isinstance (sub_expr , Variance ):
102- replacing_expr = Power (replacing_expr , 2 )
103- replacements [sub_expr ] = replacing_expr
109+ replacements [sub_expr ] = self ._get_replace_expr (sub_expr , group , alias )
104110 return replacements , group
105111
112+ def _prepare_search_expressions_for_pipeline (self , expression , target , search_idx ):
113+ searches = {}
114+ replacements = {}
115+ for sub_expr in self ._get_search_expressions (expression ):
116+ alias = f"__search_expr.search{ next (search_idx )} "
117+ replacements [sub_expr ] = self ._get_replace_expr (sub_expr , searches , alias )
118+ return replacements , searches
119+
120+ def _prepare_search_query_for_aggregation_pipeline (self , order_by ):
121+ replacements = {}
122+ searches = {}
123+ search_idx = itertools .count (start = 1 )
124+ for target , expr in self .query .annotation_select .items ():
125+ new_replacements , expr_searches = self ._prepare_search_expressions_for_pipeline (
126+ expr , target , search_idx
127+ )
128+ replacements .update (new_replacements )
129+ searches .update (expr_searches )
130+
131+ for expr , _ in order_by :
132+ new_replacements , expr_searches = self ._prepare_search_expressions_for_pipeline (
133+ expr , None , search_idx
134+ )
135+ replacements .update (new_replacements )
136+ searches .update (expr_searches )
137+
138+ having_replacements , having_group = self ._prepare_search_expressions_for_pipeline (
139+ self .having , None , search_idx
140+ )
141+ replacements .update (having_replacements )
142+ searches .update (having_group )
143+ return searches , replacements
144+
106145 def _prepare_annotations_for_aggregation_pipeline (self , order_by ):
107146 """Prepare annotations for the aggregation pipeline."""
108147 replacements = {}
@@ -179,6 +218,9 @@ def _get_group_id_expressions(self, order_by):
179218 ids = self .get_project_fields (tuple (columns ), force_expression = True )
180219 return ids , replacements
181220
221+ def _build_search_pipeline (self , search_queries ):
222+ pass
223+
182224 def _build_aggregation_pipeline (self , ids , group ):
183225 """Build the aggregation pipeline for grouping."""
184226 pipeline = []
@@ -209,7 +251,12 @@ def _build_aggregation_pipeline(self, ids, group):
209251
210252 def pre_sql_setup (self , with_col_aliases = False ):
211253 extra_select , order_by , group_by = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
212- group , all_replacements = self ._prepare_annotations_for_aggregation_pipeline (order_by )
254+ searches , search_replacements = self ._prepare_search_query_for_aggregation_pipeline (
255+ order_by
256+ )
257+ group , group_replacements = self ._prepare_annotations_for_aggregation_pipeline (order_by )
258+ all_replacements = {** search_replacements , ** group_replacements }
259+ self .search_pipeline = searches
213260 # query.group_by is either:
214261 # - None: no GROUP BY
215262 # - True: group by select fields
@@ -557,10 +604,16 @@ def get_lookup_pipeline(self):
557604 return result
558605
559606 def _get_aggregate_expressions (self , expr ):
607+ return self ._get_all_expressions_of_type (expr , Aggregate )
608+
609+ def _get_search_expressions (self , expr ):
610+ return self ._get_all_expressions_of_type (expr , SearchScore )
611+
612+ def _get_all_expressions_of_type (self , expr , target_type ):
560613 stack = [expr ]
561614 while stack :
562615 expr = stack .pop ()
563- if isinstance (expr , Aggregate ):
616+ if isinstance (expr , target_type ):
564617 yield expr
565618 elif hasattr (expr , "get_source_expressions" ):
566619 stack .extend (expr .get_source_expressions ())
0 commit comments