Skip to content

Commit 1967753

Browse files
authored
[DiskBBQ] save same vector operations when performing centroid filtering (#138312)
1 parent 740ef90 commit 1967753

File tree

3 files changed

+85
-44
lines changed

3 files changed

+85
-44
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java

Lines changed: 67 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -272,11 +272,14 @@ private static CentroidIterator getCentroidIteratorNoParent(
272272
FixedBitSet acceptCentroids
273273
) throws IOException {
274274
final NeighborQueue neighborQueue = new NeighborQueue(numCentroids, true);
275+
final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES;
275276
score(
276277
neighborQueue,
277278
numCentroids,
278279
0,
279280
scorer,
281+
centroids,
282+
centroidQuantizeSize,
280283
quantizeQuery,
281284
queryParams,
282285
globalCentroidDp,
@@ -315,26 +318,41 @@ private static CentroidIterator getCentroidIteratorWithParents(
315318
FixedBitSet acceptCentroids
316319
) throws IOException {
317320
// build the three queues we are going to use
321+
final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES;
318322
final NeighborQueue parentsQueue = new NeighborQueue(numParents, true);
319323
final int maxChildrenSize = centroids.readVInt();
320324
final NeighborQueue currentParentQueue = new NeighborQueue(maxChildrenSize, true);
321325
final int bufferSize = (int) Math.min(Math.max(centroidRatio * numCentroids, 1), numCentroids);
322-
final NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true);
323-
// score the parents
326+
final int numCentroidsFiltered = acceptCentroids == null ? numCentroids : acceptCentroids.cardinality();
324327
final float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE];
325-
score(
326-
parentsQueue,
327-
numParents,
328-
0,
329-
scorer,
330-
quantizeQuery,
331-
queryParams,
332-
globalCentroidDp,
333-
fieldInfo.getVectorSimilarityFunction(),
334-
scores,
335-
null
336-
);
337-
final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES;
328+
final NeighborQueue neighborQueue;
329+
if (acceptCentroids != null && numCentroidsFiltered <= bufferSize) {
330+
// we are collecting every non-filter centroid, therefore we do not need to score the
331+
// parents. We give each of them the same score.
332+
neighborQueue = new NeighborQueue(numCentroidsFiltered, true);
333+
for (int i = 0; i < numParents; i++) {
334+
parentsQueue.add(i, 0.5f);
335+
}
336+
centroids.skipBytes(centroidQuantizeSize * numParents);
337+
} else {
338+
neighborQueue = new NeighborQueue(bufferSize, true);
339+
// score the parents
340+
score(
341+
parentsQueue,
342+
numParents,
343+
0,
344+
scorer,
345+
centroids,
346+
centroidQuantizeSize,
347+
quantizeQuery,
348+
queryParams,
349+
globalCentroidDp,
350+
fieldInfo.getVectorSimilarityFunction(),
351+
scores,
352+
null
353+
);
354+
}
355+
338356
final long offset = centroids.getFilePointer();
339357
final long childrenOffset = offset + (long) Long.BYTES * numParents;
340358
// populate the children's queue by reading parents one by one
@@ -429,6 +447,8 @@ private static void populateOneChildrenGroup(
429447
numChildren,
430448
childrenOrdinal,
431449
scorer,
450+
centroids,
451+
centroidQuantizeSize,
432452
quantizeQuery,
433453
queryParams,
434454
globalCentroidDp,
@@ -443,48 +463,56 @@ private static void score(
443463
int size,
444464
int scoresOffset,
445465
ES92Int7VectorsScorer scorer,
466+
IndexInput centroids,
467+
long centroidQuantizeSize,
446468
byte[] quantizeQuery,
447469
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
448470
float centroidDp,
449471
VectorSimilarityFunction similarityFunction,
450472
float[] scores,
451473
FixedBitSet acceptCentroids
452474
) throws IOException {
453-
// TODO: if accept centroids is not null, we can save some vector ops here
454475
int limit = size - ES92Int7VectorsScorer.BULK_SIZE + 1;
455476
int i = 0;
456477
for (; i < limit; i += ES92Int7VectorsScorer.BULK_SIZE) {
457-
scorer.scoreBulk(
458-
quantizeQuery,
459-
queryCorrections.lowerInterval(),
460-
queryCorrections.upperInterval(),
461-
queryCorrections.quantizedComponentSum(),
462-
queryCorrections.additionalCorrection(),
463-
similarityFunction,
464-
centroidDp,
465-
scores
466-
);
467-
for (int j = 0; j < ES92Int7VectorsScorer.BULK_SIZE; j++) {
468-
int centroidOrd = scoresOffset + i + j;
469-
if (acceptCentroids == null || acceptCentroids.get(centroidOrd)) {
470-
neighborQueue.add(centroidOrd, scores[j]);
478+
if (acceptCentroids == null
479+
|| acceptCentroids.cardinality(scoresOffset + i, scoresOffset + i + ES92Int7VectorsScorer.BULK_SIZE) > 0) {
480+
scorer.scoreBulk(
481+
quantizeQuery,
482+
queryCorrections.lowerInterval(),
483+
queryCorrections.upperInterval(),
484+
queryCorrections.quantizedComponentSum(),
485+
queryCorrections.additionalCorrection(),
486+
similarityFunction,
487+
centroidDp,
488+
scores
489+
);
490+
for (int j = 0; j < ES92Int7VectorsScorer.BULK_SIZE; j++) {
491+
int centroidOrd = scoresOffset + i + j;
492+
if (acceptCentroids == null || acceptCentroids.get(centroidOrd)) {
493+
neighborQueue.add(centroidOrd, scores[j]);
494+
}
471495
}
496+
} else {
497+
centroids.skipBytes(ES92Int7VectorsScorer.BULK_SIZE * centroidQuantizeSize);
472498
}
473499
}
474500

475501
for (; i < size; i++) {
476-
float score = scorer.score(
477-
quantizeQuery,
478-
queryCorrections.lowerInterval(),
479-
queryCorrections.upperInterval(),
480-
queryCorrections.quantizedComponentSum(),
481-
queryCorrections.additionalCorrection(),
482-
similarityFunction,
483-
centroidDp
484-
);
485502
int centroidOrd = scoresOffset + i;
486503
if (acceptCentroids == null || acceptCentroids.get(centroidOrd)) {
504+
float score = scorer.score(
505+
quantizeQuery,
506+
queryCorrections.lowerInterval(),
507+
queryCorrections.upperInterval(),
508+
queryCorrections.quantizedComponentSum(),
509+
queryCorrections.additionalCorrection(),
510+
similarityFunction,
511+
centroidDp
512+
);
487513
neighborQueue.add(centroidOrd, score);
514+
} else {
515+
centroids.skipBytes(centroidQuantizeSize);
488516
}
489517
}
490518
}

server/src/main/java/org/elasticsearch/search/vectors/IVFKnnSearchStrategy.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public class IVFKnnSearchStrategy extends KnnSearchStrategy {
1919
private final SetOnce<AbstractMaxScoreKnnCollector> collector = new SetOnce<>();
2020
private final LongAccumulator accumulator;
2121

22-
IVFKnnSearchStrategy(float visitRatio, LongAccumulator accumulator) {
22+
public IVFKnnSearchStrategy(float visitRatio, LongAccumulator accumulator) {
2323
this.visitRatio = visitRatio;
2424
this.accumulator = accumulator;
2525
}

server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,15 @@
2929
import org.apache.lucene.index.VectorEncoding;
3030
import org.apache.lucene.index.VectorSimilarityFunction;
3131
import org.apache.lucene.search.AcceptDocs;
32+
import org.apache.lucene.search.KnnCollector;
3233
import org.apache.lucene.search.TopDocs;
34+
import org.apache.lucene.search.TopKnnCollector;
3335
import org.apache.lucene.store.Directory;
3436
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
3537
import org.apache.lucene.tests.util.TestUtil;
3638
import org.apache.lucene.util.BytesRef;
3739
import org.elasticsearch.common.logging.LogConfigurator;
40+
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
3841
import org.junit.Before;
3942

4043
import java.io.IOException;
@@ -353,17 +356,27 @@ private void doRestrictiveFilter(boolean dense) throws IOException {
353356
LeafReader leafReader = getOnlyLeafReader(reader);
354357
float[] vector = randomVector(dimensions);
355358
// we might collect the same document twice because of soar assignments
356-
TopDocs topDocs = leafReader.searchNearestVectors(
359+
KnnCollector collector;
360+
if (random().nextBoolean()) {
361+
collector = new TopKnnCollector(random().nextInt(2 * matchingDocs, 3 * matchingDocs), Integer.MAX_VALUE);
362+
} else {
363+
collector = new TopKnnCollector(
364+
random().nextInt(2 * matchingDocs, 3 * matchingDocs),
365+
Integer.MAX_VALUE,
366+
new IVFKnnSearchStrategy(0.25f, null)
367+
);
368+
}
369+
leafReader.searchNearestVectors(
357370
"f",
358371
vector,
359-
random().nextInt(2 * matchingDocs, 3 * matchingDocs),
372+
collector,
360373
AcceptDocs.fromIteratorSupplier(
361374
() -> leafReader.postings(new Term("k", new BytesRef("A"))),
362375
leafReader.getLiveDocs(),
363376
leafReader.maxDoc()
364-
),
365-
Integer.MAX_VALUE
377+
)
366378
);
379+
TopDocs topDocs = collector.topDocs();
367380
Set<Integer> uniqueDocIds = new HashSet<>();
368381
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
369382
uniqueDocIds.add(topDocs.scoreDocs[i].doc);

0 commit comments

Comments
 (0)