@@ -272,11 +272,14 @@ private static CentroidIterator getCentroidIteratorNoParent(
272272 FixedBitSet acceptCentroids
273273 ) throws IOException {
274274 final NeighborQueue neighborQueue = new NeighborQueue (numCentroids , true );
275+ final long centroidQuantizeSize = fieldInfo .getVectorDimension () + 3 * Float .BYTES + Integer .BYTES ;
275276 score (
276277 neighborQueue ,
277278 numCentroids ,
278279 0 ,
279280 scorer ,
281+ centroids ,
282+ centroidQuantizeSize ,
280283 quantizeQuery ,
281284 queryParams ,
282285 globalCentroidDp ,
@@ -315,26 +318,41 @@ private static CentroidIterator getCentroidIteratorWithParents(
315318 FixedBitSet acceptCentroids
316319 ) throws IOException {
317320 // build the three queues we are going to use
321+ final long centroidQuantizeSize = fieldInfo .getVectorDimension () + 3 * Float .BYTES + Integer .BYTES ;
318322 final NeighborQueue parentsQueue = new NeighborQueue (numParents , true );
319323 final int maxChildrenSize = centroids .readVInt ();
320324 final NeighborQueue currentParentQueue = new NeighborQueue (maxChildrenSize , true );
321325 final int bufferSize = (int ) Math .min (Math .max (centroidRatio * numCentroids , 1 ), numCentroids );
322- final NeighborQueue neighborQueue = new NeighborQueue (bufferSize , true );
323- // score the parents
326+ final int numCentroidsFiltered = acceptCentroids == null ? numCentroids : acceptCentroids .cardinality ();
324327 final float [] scores = new float [ES92Int7VectorsScorer .BULK_SIZE ];
325- score (
326- parentsQueue ,
327- numParents ,
328- 0 ,
329- scorer ,
330- quantizeQuery ,
331- queryParams ,
332- globalCentroidDp ,
333- fieldInfo .getVectorSimilarityFunction (),
334- scores ,
335- null
336- );
337- final long centroidQuantizeSize = fieldInfo .getVectorDimension () + 3 * Float .BYTES + Integer .BYTES ;
328+ final NeighborQueue neighborQueue ;
329+ if (acceptCentroids != null && numCentroidsFiltered <= bufferSize ) {
330+ // we are collecting every non-filter centroid, therefore we do not need to score the
331+ // parents. We give each of them the same score.
332+ neighborQueue = new NeighborQueue (numCentroidsFiltered , true );
333+ for (int i = 0 ; i < numParents ; i ++) {
334+ parentsQueue .add (i , 0.5f );
335+ }
336+ centroids .skipBytes (centroidQuantizeSize * numParents );
337+ } else {
338+ neighborQueue = new NeighborQueue (bufferSize , true );
339+ // score the parents
340+ score (
341+ parentsQueue ,
342+ numParents ,
343+ 0 ,
344+ scorer ,
345+ centroids ,
346+ centroidQuantizeSize ,
347+ quantizeQuery ,
348+ queryParams ,
349+ globalCentroidDp ,
350+ fieldInfo .getVectorSimilarityFunction (),
351+ scores ,
352+ null
353+ );
354+ }
355+
338356 final long offset = centroids .getFilePointer ();
339357 final long childrenOffset = offset + (long ) Long .BYTES * numParents ;
340358 // populate the children's queue by reading parents one by one
@@ -429,6 +447,8 @@ private static void populateOneChildrenGroup(
429447 numChildren ,
430448 childrenOrdinal ,
431449 scorer ,
450+ centroids ,
451+ centroidQuantizeSize ,
432452 quantizeQuery ,
433453 queryParams ,
434454 globalCentroidDp ,
@@ -443,48 +463,56 @@ private static void score(
443463 int size ,
444464 int scoresOffset ,
445465 ES92Int7VectorsScorer scorer ,
466+ IndexInput centroids ,
467+ long centroidQuantizeSize ,
446468 byte [] quantizeQuery ,
447469 OptimizedScalarQuantizer .QuantizationResult queryCorrections ,
448470 float centroidDp ,
449471 VectorSimilarityFunction similarityFunction ,
450472 float [] scores ,
451473 FixedBitSet acceptCentroids
452474 ) throws IOException {
453- // TODO: if accept centroids is not null, we can save some vector ops here
454475 int limit = size - ES92Int7VectorsScorer .BULK_SIZE + 1 ;
455476 int i = 0 ;
456477 for (; i < limit ; i += ES92Int7VectorsScorer .BULK_SIZE ) {
457- scorer .scoreBulk (
458- quantizeQuery ,
459- queryCorrections .lowerInterval (),
460- queryCorrections .upperInterval (),
461- queryCorrections .quantizedComponentSum (),
462- queryCorrections .additionalCorrection (),
463- similarityFunction ,
464- centroidDp ,
465- scores
466- );
467- for (int j = 0 ; j < ES92Int7VectorsScorer .BULK_SIZE ; j ++) {
468- int centroidOrd = scoresOffset + i + j ;
469- if (acceptCentroids == null || acceptCentroids .get (centroidOrd )) {
470- neighborQueue .add (centroidOrd , scores [j ]);
478+ if (acceptCentroids == null
479+ || acceptCentroids .cardinality (scoresOffset + i , scoresOffset + i + ES92Int7VectorsScorer .BULK_SIZE ) > 0 ) {
480+ scorer .scoreBulk (
481+ quantizeQuery ,
482+ queryCorrections .lowerInterval (),
483+ queryCorrections .upperInterval (),
484+ queryCorrections .quantizedComponentSum (),
485+ queryCorrections .additionalCorrection (),
486+ similarityFunction ,
487+ centroidDp ,
488+ scores
489+ );
490+ for (int j = 0 ; j < ES92Int7VectorsScorer .BULK_SIZE ; j ++) {
491+ int centroidOrd = scoresOffset + i + j ;
492+ if (acceptCentroids == null || acceptCentroids .get (centroidOrd )) {
493+ neighborQueue .add (centroidOrd , scores [j ]);
494+ }
471495 }
496+ } else {
497+ centroids .skipBytes (ES92Int7VectorsScorer .BULK_SIZE * centroidQuantizeSize );
472498 }
473499 }
474500
475501 for (; i < size ; i ++) {
476- float score = scorer .score (
477- quantizeQuery ,
478- queryCorrections .lowerInterval (),
479- queryCorrections .upperInterval (),
480- queryCorrections .quantizedComponentSum (),
481- queryCorrections .additionalCorrection (),
482- similarityFunction ,
483- centroidDp
484- );
485502 int centroidOrd = scoresOffset + i ;
486503 if (acceptCentroids == null || acceptCentroids .get (centroidOrd )) {
504+ float score = scorer .score (
505+ quantizeQuery ,
506+ queryCorrections .lowerInterval (),
507+ queryCorrections .upperInterval (),
508+ queryCorrections .quantizedComponentSum (),
509+ queryCorrections .additionalCorrection (),
510+ similarityFunction ,
511+ centroidDp
512+ );
487513 neighborQueue .add (centroidOrd , score );
514+ } else {
515+ centroids .skipBytes (centroidQuantizeSize );
488516 }
489517 }
490518 }
0 commit comments