@@ -44,7 +44,7 @@ class Article(models.Model):
4444 vector = ArrayField (models .FloatField (), size = 10 )
4545
4646 class Meta :
47- indexes = [VectorSearchIndex (fields = ["title" , "vector" ])]
47+ indexes = [VectorSearchIndex (fields = ["title" , "vector" ], similarities = "cosine" )]
4848
4949 errors = checks .run_checks (app_configs = self .apps .get_app_configs (), databases = {"default" })
5050 self .assertEqual (
@@ -72,7 +72,7 @@ class Article(models.Model):
7272 title_embedded = ArrayField (models .FloatField ())
7373
7474 class Meta :
75- indexes = [VectorSearchIndex (fields = ["title_embedded" ])]
75+ indexes = [VectorSearchIndex (fields = ["title_embedded" ], similarities = "cosine" )]
7676
7777 errors = checks .run_checks (app_configs = self .apps .get_app_configs (), databases = {"default" })
7878 self .assertEqual (
@@ -91,7 +91,7 @@ class Article(models.Model):
9191 title_embedded = ArrayField (models .CharField (), size = 30 )
9292
9393 class Meta :
94- indexes = [VectorSearchIndex (fields = ["title_embedded" ])]
94+ indexes = [VectorSearchIndex (fields = ["title_embedded" ], similarities = "cosine" )]
9595
9696 errors = checks .run_checks (app_configs = self .apps .get_app_configs (), databases = {"default" })
9797 self .assertEqual (
@@ -112,7 +112,7 @@ class Article(models.Model):
112112 vector = ArrayField (models .FloatField (), size = 10 )
113113
114114 class Meta :
115- indexes = [VectorSearchIndex (fields = ["data" , "vector" ])]
115+ indexes = [VectorSearchIndex (fields = ["data" , "vector" ], similarities = "cosine" )]
116116
117117 errors = checks .run_checks (app_configs = self .apps .get_app_configs (), databases = {"default" })
118118 self .assertEqual (
@@ -127,7 +127,7 @@ class Meta:
127127 ],
128128 )
129129
130- def test_invalid_number_similarity_function_singular (self ):
130+ def test_fields_and_similarities_mismatch (self ):
131131 class Article (models .Model ):
132132 vector = ArrayField (models .FloatField (), size = 10 )
133133
@@ -153,44 +153,17 @@ class Meta:
153153 ],
154154 )
155155
156- def test_invalid_number_similarity_function_plural (self ):
157- class Article (models .Model ):
158- vector1 = ArrayField (models .FloatField (), size = 10 )
159- vector2 = ArrayField (models .FloatField (), size = 10 )
160-
161- class Meta :
162- indexes = [
163- VectorSearchIndex (
164- fields = ["vector1" , "vector2" ],
165- similarities = ["dotProduct" ],
166- )
167- ]
168-
169- errors = checks .run_checks (app_configs = self .apps .get_app_configs (), databases = {"default" })
170- self .assertEqual (
171- errors ,
172- [
173- checks .Error (
174- "VectorSearchIndex requires the same number of similarities "
175- "and vector fields; Article has 2 ArrayField(s) but similarities "
176- "has 1 element(s)." ,
177- id = "django_mongodb_backend.indexes.VectorSearchIndex.E005" ,
178- obj = Article ,
179- ),
180- ],
181- )
182-
183156 def test_simple (self ):
184157 class Article (models .Model ):
185158 vector = ArrayField (models .FloatField (), size = 10 )
186159
187160 class Meta :
188- indexes = [VectorSearchIndex (fields = ["vector" ])]
161+ indexes = [VectorSearchIndex (fields = ["vector" ], similarities = "cosine" )]
189162
190163 errors = checks .run_checks (app_configs = self .apps .get_app_configs (), databases = {"default" })
191164 self .assertEqual (errors , [])
192165
193- def test_all_valid_fields (self ):
166+ def test_valid_fields (self ):
194167 class Data (EmbeddedModel ):
195168 integer = models .IntegerField ()
196169
@@ -216,6 +189,7 @@ class Meta:
216189 "boolean" ,
217190 "date" ,
218191 ],
192+ similarities = "cosine" ,
219193 )
220194 ]
221195
@@ -227,7 +201,11 @@ class NoSearchVectorModel(models.Model):
227201 text = models .CharField (max_length = 100 )
228202
229203 class Meta :
230- indexes = [VectorSearchIndex (name = "recent_test_idx" , fields = ["text" ])]
204+ indexes = [
205+ VectorSearchIndex (
206+ name = "recent_test_idx" , fields = ["text" ], similarities = "cosine"
207+ )
208+ ]
231209
232210 errors = checks .run_checks (app_configs = self .apps .get_app_configs (), databases = {"default" })
233211 self .assertEqual (
0 commit comments