33
44from django .core .exceptions import EmptyResultSet , FullResultSet
55from django .db import DatabaseError , IntegrityError , NotSupportedError
6- from django .db .models .expressions import Case , When
6+ from django .db .models .expressions import Case , Col , When
77from django .db .models .functions import Mod
88from django .db .models .lookups import Exact
99from django .db .models .sql .constants import INNER
@@ -105,6 +105,7 @@ def join(self, compiler, connection):
105105 lhs_fields = []
106106 rhs_fields = []
107107 # Add a join condition for each pair of joining fields.
108+ parent_template = "parent__field__"
108109 for lhs , rhs in self .join_fields :
109110 lhs , rhs = connection .ops .prepare_join_on_clause (
110111 self .parent_alias , lhs , compiler .collection_name , rhs
@@ -113,8 +114,41 @@ def join(self, compiler, connection):
113114 # In the lookup stage, the reference to this column doesn't include
114115 # the collection name.
115116 rhs_fields .append (rhs .as_mql (compiler , connection ))
117+ # Handle any join conditions besides matching field pairs.
118+ extra = self .join_field .get_extra_restriction (self .table_alias , self .parent_alias )
119+ if extra :
120+ columns = []
121+ for expr in extra .leaves ():
122+ # Determine whether the column needs to be transformed or rerouted
123+ # as part of the subquery.
124+ for hand_side in ["lhs" , "rhs" ]:
125+ hand_side_value = getattr (expr , hand_side , None )
126+ if isinstance (hand_side_value , Col ):
127+ # If the column is not part of the joined table, add it to
128+ # lhs_fields.
129+ if hand_side_value .alias != self .table_name :
130+ pos = len (lhs_fields )
131+ lhs_fields .append (expr .lhs .as_mql (compiler , connection ))
132+ else :
133+ pos = None
134+ columns .append ((hand_side_value , pos ))
135+ # Replace columns in the extra conditions with new column references
136+ # based on their rerouted positions in the join pipeline.
137+ replacements = {}
138+ for col , parent_pos in columns :
139+ column_target = Col (compiler .collection_name , expr .output_field .__class__ ())
140+ if parent_pos is not None :
141+ target_col = f"${ parent_template } { parent_pos } "
142+ column_target .target .db_column = target_col
143+ column_target .target .set_attributes_from_name (target_col )
144+ else :
145+ column_target .target = col .target
146+ replacements [col ] = column_target
147+ # Apply the transformed expressions in the extra condition.
148+ extra_condition = [extra .replace_expressions (replacements ).as_mql (compiler , connection )]
149+ else :
150+ extra_condition = []
116151
117- parent_template = "parent__field__"
118152 lookup_pipeline = [
119153 {
120154 "$lookup" : {
@@ -140,6 +174,7 @@ def join(self, compiler, connection):
140174 {"$eq" : [f"$${ parent_template } { i } " , field ]}
141175 for i , field in enumerate (rhs_fields )
142176 ]
177+ + extra_condition
143178 }
144179 }
145180 }
0 commit comments