@@ -132,11 +132,11 @@ def search_index_data_types(self, field, db_type):
132132 """
133133 if field .get_internal_type () == "UUIDField" :
134134 return "uuid"
135- if field .get_internal_type () in ( "ObjectIdAutoField" , "ObjectIdField" ) :
135+ if field .get_internal_type () in { "ObjectIdAutoField" , "ObjectIdField" } :
136136 return "ObjectId"
137137 if field .get_internal_type () == "EmbeddedModelField" :
138138 return "embeddedDocuments"
139- if db_type in ( "int" , "long" ) :
139+ if db_type in { "int" , "long" } :
140140 return "number"
141141 if db_type == "binData" :
142142 return "string"
@@ -164,26 +164,24 @@ def get_pymongo_index_model(
164164
165165class VectorSearchIndex (SearchIndex ):
166166 suffix = "vsi"
167- ALLOWED_SIMILARITY_FUNCTIONS = frozenset (("euclidean" , "cosine" , "dotProduct" ))
167+ VALID_SIMILARITIES = frozenset (("euclidean" , "cosine" , "dotProduct" ))
168168 _error_id_prefix = "django_mongodb_backend.indexes.VectorSearchIndex"
169169
170170 def __init__ (self , * expressions , fields = (), similarities = "cosine" , name = None , ** kwargs ):
171171 super ().__init__ (* expressions , fields = fields , name = name , ** kwargs )
172172 self .similarities = similarities
173- for func in similarities if isinstance (similarities , list ) else (similarities ,):
174- if func not in self .ALLOWED_SIMILARITY_FUNCTIONS :
173+ self ._multiple_similarities = isinstance (similarities , tuple | list )
174+ for func in similarities if self ._multiple_similarities else (similarities ,):
175+ if func not in self .VALID_SIMILARITIES :
175176 raise ValueError (
176- f"{ func } isn't a valid similarity function, options "
177- f"are { ', ' .join (sorted (self .ALLOWED_SIMILARITY_FUNCTIONS ))} "
177+ f"' { func } ' isn't a valid similarity function "
178+ f"( { ', ' .join (sorted (self .VALID_SIMILARITIES ))} ). "
178179 )
179- viewed = set ()
180+ seen_fields = set ()
180181 for field_name , _ in self .fields_orders :
181- if field_name in viewed :
182- raise ValueError (
183- f"Field '{ field_name } ' is defined more than once. Vector and filter "
184- "fields must use distinct field names." ,
185- )
186- viewed .add (field_name )
182+ if field_name in seen_fields :
183+ raise ValueError (f"Field '{ field_name } ' is duplicated in fields." )
184+ seen_fields .add (field_name )
187185
188186 def check (self , model , connection ):
189187 errors = super ().check (model , connection )
@@ -197,20 +195,20 @@ def check(self, model, connection):
197195 except (ValueError , TypeError ):
198196 errors .append (
199197 Error (
200- f"Atlas vector search requires size on { field_name } ." ,
198+ f"VectorSearchIndex requires ' size' on field ' { field_name } ' ." ,
201199 obj = model ,
202- id = f"{ self ._error_id_prefix } .E001 " ,
200+ id = f"{ self ._error_id_prefix } .E002 " ,
203201 )
204202 )
205203 if not isinstance (field_ .base_field , FloatField | DecimalField ):
206204 errors .append (
207205 Error (
208- "An Atlas vector search index requires the base "
209- "field of ArrayField Model.field_name "
210- "to be FloatField or DecimalField but "
211- f"is { field_ .base_field .get_internal_type ()} ." ,
206+ "VectorSearchIndex requires the base field of "
207+ f" ArrayField ' { field_ . name } ' to be FloatField or "
208+ "DecimalField but is "
209+ f"{ field_ .base_field .get_internal_type ()} ." ,
212210 obj = model ,
213- id = f"{ self ._error_id_prefix } .E002 " ,
211+ id = f"{ self ._error_id_prefix } .E003 " ,
214212 )
215213 )
216214 else :
@@ -223,24 +221,22 @@ def check(self, model, connection):
223221 errors .append (
224222 Error (
225223 "VectorSearchIndex does not support "
226- f"' { field_ .get_internal_type ()} ' { field_name } ." ,
224+ f"{ field_ .get_internal_type ()} ' { field_name } ' ." ,
227225 obj = model ,
228- id = f"{ self ._error_id_prefix } .E003 " ,
226+ id = f"{ self ._error_id_prefix } .E004 " ,
229227 )
230228 )
231- if isinstance ( self .similarities , list ) and expected_similarities != len (self .similarities ):
229+ if self ._multiple_similarities and expected_similarities != len (self .similarities ):
232230 similarity_function_text = (
233- "similarities functions " if expected_similarities != 1 else "similarity function "
231+ "similarity " if expected_similarities == 1 else "similarities "
234232 )
235233 errors .append (
236234 Error (
237- f"An Atlas vector search index requires the same number of similarities and "
238- f"vector fields, but { expected_similarities } "
239- f"{ similarity_function_text } were expected and "
240- f"{ len (self .similarities )} { 'were' if len (self .similarities ) != 1 else 'was' } "
241- "provided." ,
235+ f"VectorSearchIndex requires the same number of similarities and "
236+ f"vector fields; expected { expected_similarities } "
237+ f"{ similarity_function_text } but got { len (self .similarities )} ." ,
242238 obj = model ,
243- id = f"{ self ._error_id_prefix } .E004 " ,
239+ id = f"{ self ._error_id_prefix } .E005 " ,
244240 )
245241 )
246242 return errors
0 commit comments