|
16 | 16 |
|
17 | 17 | package com.mongodb.client.model.search; |
18 | 18 |
|
19 | | -import com.mongodb.MongoInterruptedException; |
20 | 19 | import com.mongodb.MongoNamespace; |
21 | 20 | import com.mongodb.client.model.Aggregates; |
22 | | -import com.mongodb.client.model.SearchIndexType; |
23 | 21 | import com.mongodb.client.test.CollectionHelper; |
24 | | -import com.mongodb.internal.operation.SearchIndexRequest; |
25 | 22 | import org.bson.BinaryVector; |
26 | | -import org.bson.BsonDocument; |
27 | 23 | import org.bson.Document; |
28 | 24 | import org.bson.codecs.DocumentCodec; |
29 | 25 | import org.bson.conversions.Bson; |
30 | | -import org.junit.jupiter.api.AfterAll; |
31 | 26 | import org.junit.jupiter.api.Assertions; |
32 | 27 | import org.junit.jupiter.api.BeforeAll; |
33 | 28 | import org.junit.jupiter.params.ParameterizedTest; |
34 | 29 | import org.junit.jupiter.params.provider.Arguments; |
35 | 30 | import org.junit.jupiter.params.provider.MethodSource; |
36 | 31 |
|
37 | 32 | import java.util.List; |
38 | | -import java.util.Optional; |
39 | | -import java.util.concurrent.TimeUnit; |
40 | 33 | import java.util.function.Consumer; |
41 | 34 | import java.util.stream.Stream; |
42 | 35 |
|
|
57 | 50 | import static com.mongodb.client.model.search.SearchPath.fieldPath; |
58 | 51 | import static com.mongodb.client.model.search.VectorSearchOptions.approximateVectorSearchOptions; |
59 | 52 | import static com.mongodb.client.model.search.VectorSearchOptions.exactVectorSearchOptions; |
60 | | -import static java.lang.String.format; |
61 | 53 | import static java.util.Arrays.asList; |
62 | 54 | import static java.util.Collections.singletonList; |
63 | 55 | import static org.junit.jupiter.api.Assertions.assertAll; |
|
67 | 59 | import static org.junit.jupiter.api.Assumptions.assumeTrue; |
68 | 60 | import static org.junit.jupiter.params.provider.Arguments.arguments; |
69 | 61 |
|
70 | | -class AggregatesBinaryVectorSearchIntegrationTest { |
71 | | - private static final String EXCEED_WAIT_ATTEMPTS_ERROR_MESSAGE = |
72 | | - "Exceeded maximum attempts waiting for Search Index creation in Atlas cluster. Index document: %s"; |
73 | 62 |
|
| 63 | +/** |
| 64 | + * This test runs against an Atlas QA cluster, using the {@code javaExtraTests.binaryVectorTests} namespace. |
| 65 | + * The test user has read-only permissions, so the index and test data below must already exist on the cluster. |
| 66 | + * <p> |
| 67 | + * With the following index: |
| 68 | + * <code> |
| 69 | + * { |
| 70 | + * "name": "vector_search_index", "type": "vectorSearch", |
| 71 | + * "definition": {"fields": [ |
| 72 | + * {"path": "int8Vector", "numDimensions": 5, "similarity": "cosine", "type": "vector"}, |
| 73 | + * {"path": "float32Vector", "numDimensions": 5, "similarity": "cosine", "type": "vector"}, |
| 74 | + * {"path": "legacyDoubleVector", "numDimensions": 5, "similarity": "cosine", "type": "vector"}, |
| 75 | + * {"path": "year", "type": "filter"}]} |
| 76 | + * } |
| 77 | + * </code> |
| 78 | + * <p> |
| 79 | + * And the following test data: |
| 80 | + * <code> |
| 81 | + * [{"_id":0, "int8Vector":{"$binary":{"base64":"AwAAAQIDBA==", "subType":"09"}}, |
| 82 | + * "float32Vector":{"$binary":{"base64":"JwAXt9E4Ns2PPwgDD0B1H1ZA8Z2OQA==", "subType":"09"}}, |
| 83 | + * "legacyDoubleVector":[0.0001,1.12345,2.23456,3.34567,4.45678], "year":2016}, |
| 84 | + * {"_id":1, "int8Vector":{"$binary":{"base64":"AwABAgMEBQ==", "subType":"09"}}, |
| 85 | + * "float32Vector":{"$binary":{"base64":"JwBHA4A/m+YHQAgDT0C7D4tA8Z2uQA==", "subType":"09"}}, |
| 86 | + * "legacyDoubleVector":[1.0001,2.12345,3.23456,4.34567,5.45678], "year":2017}, |
| 87 | + * {"_id":2, "int8Vector":{"$binary":{"base64":"AwACAwQFBg==", "subType":"09"}}, |
| 88 | + * "float32Vector":{"$binary":{"base64":"JwBHAwBAm+ZHQISBh0C7D6tA8Z3OQA==", "subType":"09"}}, |
| 89 | + * "legacyDoubleVector":[2.0002,3.12345,4.23456,5.34567,6.45678], "year":2018}, |
| 90 | + * {"_id":3, "int8Vector":{"$binary":{"base64":"AwADBAUGBw==", "subType":"09"}}, |
| 91 | + * "float32Vector":{"$binary":{"base64":"JwDqBEBATfODQISBp0C7D8tA8Z3uQA==", "subType":"09"}}, |
| 92 | + * "legacyDoubleVector":[3.0003,4.12345,5.23456,6.34567,7.45678], "year":2019}, |
| 93 | + * {"_id":4, "int8Vector":{"$binary":{"base64":"AwAEBQYHCA==", "subType":"09"}}, |
| 94 | + * "float32Vector":{"$binary":{"base64":"JwBHA4BATfOjQISBx0C7D+tA+U4HQQ==", "subType":"09"}}, |
| 95 | + * "legacyDoubleVector":[4.0004,5.12345,6.23456,7.34567,8.45678], "year":2020}, |
| 96 | + * {"_id":5, "int8Vector":{"$binary":{"base64":"AwAFBgcICQ==", "subType":"09"}}, |
| 97 | + * "float32Vector":{"$binary":{"base64":"JwAZBKBATfPDQISB50DdhwVB+U4XQQ==", "subType":"09"}}, |
| 98 | + * "legacyDoubleVector":[5.0005,6.12345,7.23456,8.34567,9.45678], "year":2021}, |
| 99 | + * {"_id":6, "int8Vector":{"$binary":{"base64":"AwAGBwgJCg==", "subType":"09"}}, |
| 100 | + * "float32Vector":{"$binary":{"base64":"JwDqBMBATfPjQMLAA0HdhxVB+U4nQQ==", "subType":"09"}}, |
| 101 | + * "legacyDoubleVector":[6.0006,7.12345,8.23456,9.34567,10.45678], "year":2022}, |
| 102 | + * {"_id":7, "int8Vector":{"$binary":{"base64":"AwAHCAkKCw==", "subType":"09"}}, |
| 103 | + * "float32Vector":{"$binary":{"base64":"JwC8BeBAp/kBQcLAE0HdhyVB+U43QQ==", "subType":"09"}}, |
| 104 | + * "legacyDoubleVector":[7.0007,8.12345,9.23456,10.34567,11.45678], "year":2023}, |
| 105 | + * {"_id":8, "int8Vector":{"$binary":{"base64":"AwAICQoLDA==", "subType":"09"}}, |
| 106 | + * "float32Vector":{"$binary":{"base64":"JwBHAwBBp/kRQcLAI0HdhzVB+U5HQQ==", "subType":"09"}}, |
| 107 | + * "legacyDoubleVector":[8.0008,9.12345,10.23456,11.34567,12.45678], "year":2024}, |
| 108 | + * {"_id":9, "int8Vector":{"$binary":{"base64":"AwAJCgsMDQ==", "subType":"09"}}, |
| 109 | + * "float32Vector":{"$binary":{"base64":"JwCwAxBBp/khQcLAM0Hdh0VB+U5XQQ==", "subType":"09"}}, |
| 110 | + * "legacyDoubleVector":[9.0009,10.12345,11.23456,12.34567,13.45678], "year":2025}] |
| 111 | + * </code> |
| 112 | + */ |
| 113 | +class AggregatesBinaryVectorSearchIntegrationTest { |
| 114 | + private static final MongoNamespace BINARY_VECTOR_NAMESPACE = new MongoNamespace("javaExtraTests", "binaryVectorTests"); |
74 | 115 | private static final String VECTOR_INDEX = "vector_search_index"; |
75 | 116 | private static final String VECTOR_FIELD_INT_8 = "int8Vector"; |
76 | 117 | private static final String VECTOR_FIELD_FLOAT_32 = "float32Vector"; |
77 | 118 | private static final String VECTOR_FIELD_LEGACY_DOUBLE_LIST = "legacyDoubleVector"; |
78 | 119 | private static final int LIMIT = 5; |
79 | | - private static final String FIELD_YEAR = "year"; |
80 | 120 | private static CollectionHelper<Document> collectionHelper; |
81 | | - private static final BsonDocument VECTOR_SEARCH_INDEX_DEFINITION = BsonDocument.parse( |
82 | | - "{" |
83 | | - + " fields: [" |
84 | | - + " {" |
85 | | - + " path: '" + VECTOR_FIELD_INT_8 + "'," |
86 | | - + " numDimensions: 5," |
87 | | - + " similarity: 'cosine'," |
88 | | - + " type: 'vector'," |
89 | | - + " }," |
90 | | - + " {" |
91 | | - + " path: '" + VECTOR_FIELD_FLOAT_32 + "'," |
92 | | - + " numDimensions: 5," |
93 | | - + " similarity: 'cosine'," |
94 | | - + " type: 'vector'," |
95 | | - + " }," |
96 | | - + " {" |
97 | | - + " path: '" + VECTOR_FIELD_LEGACY_DOUBLE_LIST + "'," |
98 | | - + " numDimensions: 5," |
99 | | - + " similarity: 'cosine'," |
100 | | - + " type: 'vector'," |
101 | | - + " }," |
102 | | - + " {" |
103 | | - + " path: '" + FIELD_YEAR + "'," |
104 | | - + " type: 'filter'," |
105 | | - + " }," |
106 | | - + " ]" |
107 | | - + "}"); |
| 121 | + |
108 | 122 |
|
109 | 123 | @BeforeAll |
110 | 124 | static void beforeAll() { |
111 | 125 | assumeTrue(isAtlasSearchTest()); |
112 | 126 | assumeTrue(serverVersionAtLeast(6, 0)); |
113 | 127 |
|
114 | | - collectionHelper = |
115 | | - new CollectionHelper<>(new DocumentCodec(), new MongoNamespace("javaVectorSearchTest", AggregatesBinaryVectorSearchIntegrationTest.class.getSimpleName())); |
116 | | - collectionHelper.drop(); |
117 | | - collectionHelper.insertDocuments( |
118 | | - new Document() |
119 | | - .append("_id", 0) |
120 | | - .append(VECTOR_FIELD_INT_8, BinaryVector.int8Vector(new byte[]{0, 1, 2, 3, 4})) |
121 | | - .append(VECTOR_FIELD_FLOAT_32, BinaryVector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f})) |
122 | | - .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{0.0001, 1.12345, 2.23456, 3.34567, 4.45678}) |
123 | | - .append(FIELD_YEAR, 2016), |
124 | | - new Document() |
125 | | - .append("_id", 1) |
126 | | - .append(VECTOR_FIELD_INT_8, BinaryVector.int8Vector(new byte[]{1, 2, 3, 4, 5})) |
127 | | - .append(VECTOR_FIELD_FLOAT_32, BinaryVector.floatVector(new float[]{1.0001f, 2.12345f, 3.23456f, 4.34567f, 5.45678f})) |
128 | | - .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{1.0001, 2.12345, 3.23456, 4.34567, 5.45678}) |
129 | | - .append(FIELD_YEAR, 2017), |
130 | | - new Document() |
131 | | - .append("_id", 2) |
132 | | - .append(VECTOR_FIELD_INT_8, BinaryVector.int8Vector(new byte[]{2, 3, 4, 5, 6})) |
133 | | - .append(VECTOR_FIELD_FLOAT_32, BinaryVector.floatVector(new float[]{2.0002f, 3.12345f, 4.23456f, 5.34567f, 6.45678f})) |
134 | | - .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{2.0002, 3.12345, 4.23456, 5.34567, 6.45678}) |
135 | | - .append(FIELD_YEAR, 2018), |
136 | | - new Document() |
137 | | - .append("_id", 3) |
138 | | - .append(VECTOR_FIELD_INT_8, BinaryVector.int8Vector(new byte[]{3, 4, 5, 6, 7})) |
139 | | - .append(VECTOR_FIELD_FLOAT_32, BinaryVector.floatVector(new float[]{3.0003f, 4.12345f, 5.23456f, 6.34567f, 7.45678f})) |
140 | | - .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{3.0003, 4.12345, 5.23456, 6.34567, 7.45678}) |
141 | | - .append(FIELD_YEAR, 2019), |
142 | | - new Document() |
143 | | - .append("_id", 4) |
144 | | - .append(VECTOR_FIELD_INT_8, BinaryVector.int8Vector(new byte[]{4, 5, 6, 7, 8})) |
145 | | - .append(VECTOR_FIELD_FLOAT_32, BinaryVector.floatVector(new float[]{4.0004f, 5.12345f, 6.23456f, 7.34567f, 8.45678f})) |
146 | | - .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{4.0004, 5.12345, 6.23456, 7.34567, 8.45678}) |
147 | | - .append(FIELD_YEAR, 2020), |
148 | | - new Document() |
149 | | - .append("_id", 5) |
150 | | - .append(VECTOR_FIELD_INT_8, BinaryVector.int8Vector(new byte[]{5, 6, 7, 8, 9})) |
151 | | - .append(VECTOR_FIELD_FLOAT_32, BinaryVector.floatVector(new float[]{5.0005f, 6.12345f, 7.23456f, 8.34567f, 9.45678f})) |
152 | | - .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{5.0005, 6.12345, 7.23456, 8.34567, 9.45678}) |
153 | | - .append(FIELD_YEAR, 2021), |
154 | | - new Document() |
155 | | - .append("_id", 6) |
156 | | - .append(VECTOR_FIELD_INT_8, BinaryVector.int8Vector(new byte[]{6, 7, 8, 9, 10})) |
157 | | - .append(VECTOR_FIELD_FLOAT_32, BinaryVector.floatVector(new float[]{6.0006f, 7.12345f, 8.23456f, 9.34567f, 10.45678f})) |
158 | | - .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{6.0006, 7.12345, 8.23456, 9.34567, 10.45678}) |
159 | | - .append(FIELD_YEAR, 2022), |
160 | | - new Document() |
161 | | - .append("_id", 7) |
162 | | - .append(VECTOR_FIELD_INT_8, BinaryVector.int8Vector(new byte[]{7, 8, 9, 10, 11})) |
163 | | - .append(VECTOR_FIELD_FLOAT_32, BinaryVector.floatVector(new float[]{7.0007f, 8.12345f, 9.23456f, 10.34567f, 11.45678f})) |
164 | | - .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{7.0007, 8.12345, 9.23456, 10.34567, 11.45678}) |
165 | | - .append(FIELD_YEAR, 2023), |
166 | | - new Document() |
167 | | - .append("_id", 8) |
168 | | - .append(VECTOR_FIELD_INT_8, BinaryVector.int8Vector(new byte[]{8, 9, 10, 11, 12})) |
169 | | - .append(VECTOR_FIELD_FLOAT_32, BinaryVector.floatVector(new float[]{8.0008f, 9.12345f, 10.23456f, 11.34567f, 12.45678f})) |
170 | | - .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{8.0008, 9.12345, 10.23456, 11.34567, 12.45678}) |
171 | | - .append(FIELD_YEAR, 2024), |
172 | | - new Document() |
173 | | - .append("_id", 9) |
174 | | - .append(VECTOR_FIELD_INT_8, BinaryVector.int8Vector(new byte[]{9, 10, 11, 12, 13})) |
175 | | - .append(VECTOR_FIELD_FLOAT_32, BinaryVector.floatVector(new float[]{9.0009f, 10.12345f, 11.23456f, 12.34567f, 13.45678f})) |
176 | | - .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{9.0009, 10.12345, 11.23456, 12.34567, 13.45678}) |
177 | | - .append(FIELD_YEAR, 2025) |
178 | | - ); |
179 | | - |
180 | | - collectionHelper.createSearchIndex( |
181 | | - new SearchIndexRequest(VECTOR_SEARCH_INDEX_DEFINITION, VECTOR_INDEX, |
182 | | - SearchIndexType.vectorSearch())); |
183 | | - awaitIndexCreation(); |
184 | | - } |
185 | | - |
186 | | - @AfterAll |
187 | | - static void afterAll() { |
188 | | - if (collectionHelper != null) { |
189 | | - collectionHelper.drop(); |
190 | | - } |
| 128 | + collectionHelper = new CollectionHelper<>(new DocumentCodec(), BINARY_VECTOR_NAMESPACE); |
191 | 129 | } |
192 | 130 |
|
193 | 131 | private static Stream<Arguments> provideSupportedVectors() { |
@@ -268,7 +206,7 @@ void shouldSearchByVector(final BinaryVector vector, |
268 | 206 | final FieldSearchPath fieldSearchPath, |
269 | 207 | final VectorSearchOptions vectorSearchOptions) { |
270 | 208 | //given |
271 | | - List<Bson> pipeline = asList( |
| 209 | + List<Bson> pipeline = singletonList( |
272 | 210 | Aggregates.vectorSearch( |
273 | 211 | fieldSearchPath, |
274 | 212 | vector, |
@@ -327,27 +265,4 @@ private static void assertScoreIsDecreasing(final List<Document> aggregate) { |
327 | 265 | } |
328 | 266 | } |
329 | 267 |
|
330 | | - private static void awaitIndexCreation() { |
331 | | - int attempts = 10; |
332 | | - Optional<Document> searchIndex = Optional.empty(); |
333 | | - |
334 | | - while (attempts-- > 0) { |
335 | | - searchIndex = collectionHelper.listSearchIndex(VECTOR_INDEX); |
336 | | - if (searchIndex.filter(document -> document.getBoolean("queryable")) |
337 | | - .isPresent()) { |
338 | | - return; |
339 | | - } |
340 | | - |
341 | | - try { |
342 | | - TimeUnit.SECONDS.sleep(5); |
343 | | - } catch (InterruptedException e) { |
344 | | - Thread.currentThread().interrupt(); |
345 | | - throw new MongoInterruptedException(null, e); |
346 | | - } |
347 | | - } |
348 | | - |
349 | | - searchIndex.ifPresent(document -> |
350 | | - Assertions.fail(format(EXCEED_WAIT_ATTEMPTS_ERROR_MESSAGE, document.toJson()))); |
351 | | - Assertions.fail(format(EXCEED_WAIT_ATTEMPTS_ERROR_MESSAGE, "null")); |
352 | | - } |
353 | 268 | } |
0 commit comments