@@ -213,7 +213,56 @@ def value(self, compiler, connection): # noqa: ARG001
213213 return value
214214
215215
216- class SearchExpression (Expression ):
216+ class Operator :
217+ AND = "AND"
218+ OR = "OR"
219+ NOT = "NOT"
220+
221+ def __init__ (self , operator ):
222+ self .operator = operator
223+
224+ def __eq__ (self , other ):
225+ if isinstance (other , str ):
226+ return self .operator == other
227+ return self .operator == other .operator
228+
229+ def negate (self ):
230+ if self .operator == self .AND :
231+ return Operator (self .OR )
232+ if self .operator == self .OR :
233+ return Operator (self .AND )
234+ return Operator (self .operator )
235+
236+
237+ class SearchCombinable :
238+ def _combine (self , other , connector , reversed ):
239+ if not isinstance (self , CompoundExpression | CombinedSearchExpression ):
240+ lhs = CompoundExpression (must = [self ])
241+ else :
242+ lhs = self
243+ if not isinstance (other , CompoundExpression | CombinedSearchExpression ):
244+ rhs = CompoundExpression (must = [other ])
245+ else :
246+ rhs = other
247+ return CombinedSearchExpression (lhs , connector , rhs )
248+
249+ def __invert__ (self ):
250+ return CombinedSearchExpression (self , Operator (Operator .NOT ), None )
251+
252+ def __and__ (self , other ):
253+ return CombinedSearchExpression (self , Operator (Operator .AND ), other )
254+
255+ def __rand__ (self , other ):
256+ return CombinedSearchExpression (self , Operator (Operator .AND ), other )
257+
258+ def __or__ (self , other ):
259+ return CombinedSearchExpression (self , Operator (Operator .OR ), other )
260+
261+ def __ror__ (self , other ):
262+ return CombinedSearchExpression (self , Operator (Operator .OR ), other )
263+
264+
265+ class SearchExpression (SearchCombinable , Expression ):
217266 output_field = FloatField ()
218267
219268 def get_source_expressions (self ):
@@ -530,6 +579,21 @@ def __init__(
530579 self .filter = filter
531580 super ().__init__ ()
532581
582+ def __invert__ (self ):
583+ return ValueError ("SearchVector cannot be negated" )
584+
585+ def __and__ (self , other ):
586+ raise NotSupportedError ("SearchVector cannot be combined" )
587+
588+ def __rand__ (self , other ):
589+ raise NotSupportedError ("SearchVector cannot be combined" )
590+
591+ def __or__ (self , other ):
592+ raise NotSupportedError ("SearchVector cannot be combined" )
593+
594+ def __ror__ (self , other ):
595+ raise NotSupportedError ("SearchVector cannot be combined" )
596+
533597 def as_mql (self , compiler , connection ):
534598 params = {
535599 "index" : self .index ,
@@ -546,15 +610,16 @@ def as_mql(self, compiler, connection):
546610 return {"$vectorSearch" : params }
547611
548612
549- class SearchScoreOption :
550- """Class to mutate scoring on a search operation"""
551-
552- def __init__ (self , definitions = None ):
553- self .definitions = definitions
554-
555-
556613class CompoundExpression (SearchExpression ):
557- def __init__ (self , must = None , must_not = None , should = None , filter = None , score = None ):
614+ def __init__ (
615+ self ,
616+ must = None ,
617+ must_not = None ,
618+ should = None ,
619+ filter = None ,
620+ score = None ,
621+ minimum_should_match = None ,
622+ ):
558623 self .must = must or []
559624 self .must_not = must_not or []
560625 self .should = should or []
@@ -563,13 +628,67 @@ def __init__(self, must=None, must_not=None, should=None, filter=None, score=Non
563628
564629 def as_mql (self , compiler , connection ):
565630 params = {}
566- for param in ["must" , "must_not" , "should" , "filter" ]:
567- clauses = getattr (self , param )
568- if clauses :
569- params [param ] = [clause .as_mql (compiler , connection ) for clause in clauses ]
631+ if self .must :
632+ params ["must" ] = [clause .as_mql (compiler , connection ) for clause in self .must ]
633+ if self .must_not :
634+ params ["mustNot" ] = [clause .as_mql (compiler , connection ) for clause in self .must_not ]
635+ if self .should :
636+ params ["should" ] = [clause .as_mql (compiler , connection ) for clause in self .should ]
637+ if self .filter :
638+ params ["filter" ] = [clause .as_mql (compiler , connection ) for clause in self .filter ]
639+ if self .minimum_should_match is not None :
640+ params ["minimumShouldMatch" ] = self .minimum_should_match
570641
571642 return {"$compound" : params }
572643
644+ def negate (self ):
645+ return CompoundExpression (must = self .must_not , must_not = self .must + self .filter )
646+
647+
648+ class CombinedSearchExpression (SearchExpression ):
649+ def __init__ (self , lhs , operator , rhs ):
650+ self .lhs = lhs
651+ self .operator = operator
652+ self .rhs = rhs
653+
654+ @staticmethod
655+ def _flatten (node , negated = False ):
656+ if node is None :
657+ return None
658+ # Leaf, resolve the compoundExpression
659+ if isinstance (node , CompoundExpression ):
660+ return node .negate () if negated else node
661+ # Apply De Morgan's Laws.
662+ operator = node .operator .negate () if negated else node .operator
663+ negated = negated != (node .operator == Operator .NOT )
664+ lhs_compound = node ._flatten (node .lhs , negated )
665+ rhs_compound = node ._flatten (node .rhs , negated )
666+ if operator == Operator .OR :
667+ return CompoundExpression (should = [lhs_compound , rhs_compound ], minimum_should_match = 1 )
668+ if node .operator == Operator .AND :
669+ return CompoundExpression (
670+ must = lhs_compound .must + rhs_compound .must ,
671+ must_not = lhs_compound .must_not + rhs_compound .must_not ,
672+ should = lhs_compound .should + rhs_compound .should ,
673+ filter = lhs_compound .filter + rhs_compound .filter ,
674+ )
675+ # it also can be written as:
676+ # this way is more consistent with OR, but the above is shorter in the debug query.
677+ # return CompoundExpression(must=[lhs_compound, rhs_compound])
678+ # not operator
679+ return lhs_compound
680+
681+ def as_mql (self , compiler , connection ):
682+ expression = self ._flatten (self )
683+ return expression .as_mql (compiler , connection )
684+
685+
686+ class SearchScoreOption :
687+ """Class to mutate scoring on a search operation"""
688+
689+ def __init__ (self , definitions = None ):
690+ self .definitions = definitions
691+
573692
574693def register_expressions ():
575694 Case .as_mql = case
0 commit comments