From c4ac3fb477a98c40c10cee834146729a35f4bd60 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Fri, 7 Nov 2025 11:19:54 -0500 Subject: [PATCH 1/2] Adds new ESAcceptDocs class and usage, allowing for future use in knn searching --- .../diskbbq/ES920DiskBBQVectorsReader.java | 2 + .../vectors/diskbbq/IVFVectorsReader.java | 14 +- .../next/ESNextDiskBBQVectorsReader.java | 2 + .../vectors/AbstractIVFKnnVectorQuery.java | 28 +- .../search/vectors/ESAcceptDocs.java | 243 ++++++++++++++++++ .../search/vectors/ESAcceptDocsTests.java | 139 ++++++++++ 6 files changed, 419 insertions(+), 9 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/ESAcceptDocs.java create mode 100644 server/src/test/java/org/elasticsearch/search/vectors/ESAcceptDocsTests.java diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java index 3f0c7cd6d64ec..338e3cdc46428 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java @@ -13,6 +13,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; @@ -88,6 +89,7 @@ public CentroidIterator getCentroidIterator( IndexInput centroids, float[] targetQuery, IndexInput postingListSlice, + AcceptDocs acceptDocs, float visitRatio ) throws IOException { final FieldEntry fieldEntry = fields.get(fieldInfo.number); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java index 3fccf9de8095a..5e6828582f75c 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java @@ -31,6 +31,7 @@ import org.apache.lucene.util.Bits; import org.elasticsearch.core.IOUtils; import org.elasticsearch.index.codec.vectors.GenericFlatVectorReaders; +import org.elasticsearch.search.vectors.ESAcceptDocs; import org.elasticsearch.search.vectors.IVFKnnSearchStrategy; import java.io.Closeable; @@ -114,6 +115,7 @@ public abstract CentroidIterator getCentroidIterator( IndexInput centroids, float[] target, IndexInput postingListSlice, + AcceptDocs acceptDocs, float visitRatio ) throws IOException; @@ -283,8 +285,17 @@ public final void search(String field, float[] target, KnnCollector knnCollector "vector query dimension: " + target.length + " differs from field dimension: " + fieldInfo.getVectorDimension() ); } + final ESAcceptDocs esAcceptDocs; + if (acceptDocs instanceof ESAcceptDocs) { + esAcceptDocs = (ESAcceptDocs) acceptDocs; + } else { + esAcceptDocs = null; + } int numVectors = getReaderForField(field).getFloatVectorValues(field).size(); - float percentFiltered = Math.max(0f, Math.min(1f, (float) acceptDocs.cost() / numVectors)); + float percentFiltered = Math.max( + 0f, + Math.min(1f, (float) (esAcceptDocs == null ? acceptDocs.cost() : esAcceptDocs.approximateCost()) / numVectors) + ); float visitRatio = DYNAMIC_VISIT_RATIO; // Search strategy may be null if this is being called from checkIndex (e.g. from a test) if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) { @@ -311,6 +322,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector entry.centroidSlice(ivfCentroids), target, postListSlice, + esAcceptDocs == null ? acceptDocs : esAcceptDocs, visitRatio ); Bits acceptDocsBits = acceptDocs.bits(); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java index c788f2264b000..64f099185eec5 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java @@ -13,6 +13,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; @@ -87,6 +88,7 @@ public CentroidIterator getCentroidIterator( IndexInput centroids, float[] targetQuery, IndexInput postingListSlice, + AcceptDocs acceptDocs, float visitRatio ) throws IOException { final FieldEntry fieldEntry = fields.get(fieldInfo.number); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java index 08825c55029fc..61ca0ed48fb93 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java @@ -25,13 +25,14 @@ import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; -import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.Weight; import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.KnnSearchStrategy; +import org.apache.lucene.util.Bits; import org.elasticsearch.search.profile.query.QueryProfiler; import java.io.IOException; @@ -182,20 +183,31 @@ private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, IVFCollec TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, IVFCollectorManager knnCollectorManager, float visitRatio) throws IOException { final LeafReader reader = ctx.reader(); + final Bits liveDocs = reader.getLiveDocs(); + final int maxDoc = reader.maxDoc(); if (filterWeight == null) { - AcceptDocs acceptDocs = AcceptDocs.fromLiveDocs(reader.getLiveDocs(), reader.maxDoc()); - return approximateSearch(ctx, acceptDocs, Integer.MAX_VALUE, knnCollectorManager, visitRatio); + return approximateSearch( + ctx, + liveDocs == null ? ESAcceptDocs.ESAcceptDocsAll.INSTANCE : new ESAcceptDocs.BitsAcceptDocs(liveDocs, maxDoc), + Integer.MAX_VALUE, + knnCollectorManager, + visitRatio + ); } - Scorer scorer = filterWeight.scorer(ctx); - if (scorer == null) { + ScorerSupplier supplier = filterWeight.scorerSupplier(ctx); + if (supplier == null) { return TopDocsCollector.EMPTY_TOPDOCS; } - AcceptDocs acceptDocs = AcceptDocs.fromIteratorSupplier(scorer::iterator, reader.getLiveDocs(), reader.maxDoc()); - final int cost = acceptDocs.cost(); - return approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager, visitRatio); + return approximateSearch( + ctx, + new ESAcceptDocs.ScorerSupplierAcceptDocs(supplier, liveDocs, maxDoc), + Integer.MAX_VALUE, + knnCollectorManager, + visitRatio + ); } abstract TopDocs approximateSearch( diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESAcceptDocs.java b/server/src/main/java/org/elasticsearch/search/vectors/ESAcceptDocs.java new file mode 100644 index 0000000000000..7658d1081f560 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESAcceptDocs.java @@ -0,0 +1,243 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ +package org.elasticsearch.search.vectors; + +import org.apache.lucene.search.AcceptDocs; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.FilteredDocIdSetIterator; +import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; + +import java.io.IOException; +import java.util.Objects; +import java.util.Optional; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * An extension of {@link AcceptDocs} that provides additional methods to get an approximate cost + * and a BitSet representation of the accepted documents. + */ +public abstract sealed class ESAcceptDocs extends AcceptDocs { + + /** Returns an approximate cost of the accepted documents. + * This is generally much cheaper than {@link #cost()}, as implementations may + * not fully evaluate filters to provide this estimate and may ignore deletions + * @return the approximate cost + * @throws IOException if an I/O error occurs + */ + public abstract int approximateCost() throws IOException; + + /** + * Returns an optional BitSet representing the accepted documents. + * If a BitSet representation is not available, returns an empty Optional. An empty optional indicates that + * there are some accepted documents, but they cannot be represented as a BitSet efficiently. + * Null implies that all documents are accepted. + * @return an Optional containing the BitSet of accepted documents, or empty if not available, or null if all documents are accepted + * @throws IOException if an I/O error occurs + */ + public abstract Optional getBitSet() throws IOException; + + private static BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) throws IOException { + if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) { + // If we already have a BitSet and no deletions, reuse the BitSet + return bitSetIterator.getBitSet(); + } else { + int threshold = maxDoc >> 7; // same as BitSet#of + if (iterator.cost() >= threshold) { + FixedBitSet bitSet = new FixedBitSet(maxDoc); + bitSet.or(iterator); + if (liveDocs != null) { + liveDocs.applyMask(bitSet, 0); + } + return bitSet; + } else { + return BitSet.of(liveDocs == null ? iterator : new FilteredDocIdSetIterator(iterator) { + @Override + protected boolean match(int doc) { + return liveDocs.get(doc); + } + }, maxDoc); // create a sparse bitset + } + } + } + + /** An AcceptDocs that accepts all documents. */ + public static final class ESAcceptDocsAll extends ESAcceptDocs { + public static final ESAcceptDocsAll INSTANCE = new ESAcceptDocsAll(); + + private ESAcceptDocsAll() {} + + @Override + public int approximateCost() throws IOException { + return 0; + } + + @Override + public Optional getBitSet() throws IOException { + return null; + } + + @Override + public Bits bits() throws IOException { + return null; + } + + @Override + public DocIdSetIterator iterator() throws IOException { + return null; + } + + @Override + public int cost() throws IOException { + return 0; + } + } + + /** An AcceptDocs that wraps a Bits instance. Generally indicates that no filter was provided, but there are deleted docs */ + public static final class BitsAcceptDocs extends ESAcceptDocs { + private final Bits bits; + private final BitSet bitSetRef; + private final int maxDoc; + private final int approximateCost; + + BitsAcceptDocs(Bits bits, int maxDoc) { + if (bits != null && bits.length() != maxDoc) { + throw new IllegalArgumentException("Bits length = " + bits.length() + " != maxDoc = " + maxDoc); + } + this.bits = bits; + if (bits instanceof BitSet bitSet) { + this.maxDoc = Objects.requireNonNull(bitSet).cardinality(); + this.approximateCost = Objects.requireNonNull(bitSet).approximateCardinality(); + this.bitSetRef = bitSet; + } else { + this.maxDoc = maxDoc; + this.approximateCost = maxDoc; + this.bitSetRef = null; + } + } + + @Override + public Bits bits() { + return bits; + } + + @Override + public DocIdSetIterator iterator() { + if (bits instanceof BitSet bitSet) { + return new BitSetIterator(bitSet, maxDoc); + } + return new FilteredDocIdSetIterator(DocIdSetIterator.all(maxDoc)) { + @Override + protected boolean match(int doc) { + return bits.get(doc); + } + }; + } + + @Override + public int cost() { + // We have no better estimate. This should be ok in practice since background merges should + // keep the number of deletes under control (< 20% by default). + return maxDoc; + } + + @Override + public int approximateCost() { + return approximateCost; + } + + @Override + public Optional getBitSet() { + if (bits == null) { + return null; + } + return Optional.ofNullable(bitSetRef); + } + } + + /** An AcceptDocs that wraps a ScorerSupplier. Indicates that a filter was provided. */ + public static final class ScorerSupplierAcceptDocs extends ESAcceptDocs { + private final ScorerSupplier scorerSupplier; + private BitSet acceptBitSet; + private final Bits liveDocs; + private final int maxDoc; + private int cardinality = -1; + + ScorerSupplierAcceptDocs(ScorerSupplier scorerSupplier, Bits liveDocs, int maxDoc) { + this.scorerSupplier = scorerSupplier; + this.liveDocs = liveDocs; + this.maxDoc = maxDoc; + } + + private void createBitSetIfNecessary() throws IOException { + if (acceptBitSet == null) { + acceptBitSet = createBitSet(scorerSupplier.get(NO_MORE_DOCS).iterator(), liveDocs, maxDoc); + } + } + + @Override + public Bits bits() throws IOException { + createBitSetIfNecessary(); + return acceptBitSet; + } + + @Override + public DocIdSetIterator iterator() throws IOException { + if (acceptBitSet != null) { + return new BitSetIterator(acceptBitSet, cardinality); + } + return liveDocs == null + ? scorerSupplier.get(NO_MORE_DOCS).iterator() + : new FilteredDocIdSetIterator(scorerSupplier.get(NO_MORE_DOCS).iterator()) { + @Override + protected boolean match(int doc) { + return liveDocs.get(doc); + } + }; + } + + @Override + public int cost() throws IOException { + createBitSetIfNecessary(); + if (cardinality == -1) { + cardinality = acceptBitSet.cardinality(); + } + return cardinality; + } + + @Override + public int approximateCost() throws IOException { + if (acceptBitSet != null) { + return cardinality != -1 ? cardinality : acceptBitSet.approximateCardinality(); + } + return Math.toIntExact(scorerSupplier.cost()); + } + + @Override + public Optional getBitSet() throws IOException { + createBitSetIfNecessary(); + return Optional.of(acceptBitSet); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/search/vectors/ESAcceptDocsTests.java b/server/src/test/java/org/elasticsearch/search/vectors/ESAcceptDocsTests.java new file mode 100644 index 0000000000000..b4ddf071afa2a --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/vectors/ESAcceptDocsTests.java @@ -0,0 +1,139 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +public class ESAcceptDocsTests extends ESTestCase { + + public void testAcceptAllDocs() throws IOException { + ESAcceptDocs acceptDocs = ESAcceptDocs.ESAcceptDocsAll.INSTANCE; + assertEquals(0L, acceptDocs.approximateCost()); + assertEquals(0L, acceptDocs.cost()); + assertNull(acceptDocs.iterator()); + assertNull(acceptDocs.bits()); + assertNull(acceptDocs.getBitSet()); + } + + public void testFromScorerSupplier() throws IOException { + int[] docIds = new int[] { 1, 3, 5, 7, 9 }; + BitSet bitSet = new FixedBitSet(10); + for (int docId : docIds) { + bitSet.set(docId); + } + { + DocIdSetIterator iterator = new BitSetIterator(bitSet, bitSet.cardinality()); + ESAcceptDocs acceptDocs = new ESAcceptDocs.ScorerSupplierAcceptDocs(new TestScorerSupplier(iterator), null, 10); + assertEquals(iterator.cost(), acceptDocs.approximateCost()); + assertEquals(iterator.cost(), acceptDocs.cost()); + // iterate the docs ensuring they match + DocIdSetIterator acceptDocsIterator = acceptDocs.iterator(); + for (int docId : docIds) { + assertEquals(docId, acceptDocsIterator.nextDoc()); + } + } + { + DocIdSetIterator iterator = new BitSetIterator(bitSet, bitSet.cardinality()); + ESAcceptDocs acceptDocs = new ESAcceptDocs.ScorerSupplierAcceptDocs(new TestScorerSupplier(iterator), null, 10); + Bits acceptDocsBits = acceptDocs.bits(); + for (int i = 0; i < 10; i++) { + assertEquals(bitSet.get(i), acceptDocsBits.get(i)); + } + } + { + DocIdSetIterator iterator = new BitSetIterator(bitSet, bitSet.cardinality()); + FixedBitSet liveDocs = new FixedBitSet(10); + liveDocs.set(0, 10); + // lets delete docs 1, 3, 9 + liveDocs.clear(1); + liveDocs.clear(3); + liveDocs.clear(9); + ESAcceptDocs acceptDocs = new ESAcceptDocs.ScorerSupplierAcceptDocs(new TestScorerSupplier(iterator), liveDocs, 10); + // verify approximate cost doesn't count deleted docs + assertEquals(5L, acceptDocs.approximateCost()); + // actual cost should count only live docs + assertEquals(2L, acceptDocs.cost()); + // iterate the docs ensuring they match + DocIdSetIterator acceptDocsIterator = acceptDocs.iterator(); + assertEquals(5, acceptDocsIterator.nextDoc()); + assertEquals(7, acceptDocsIterator.nextDoc()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, acceptDocsIterator.nextDoc()); + } + } + + public void testFromBits() throws IOException { + FixedBitSet acceptedDocs = new FixedBitSet(10); + acceptedDocs.set(1); + acceptedDocs.set(3); + acceptedDocs.set(5); + ESAcceptDocs acceptDocs = new ESAcceptDocs.BitsAcceptDocs(acceptedDocs, 10); + assertEquals(3L, acceptDocs.approximateCost()); + assertEquals(3L, acceptDocs.cost()); + // iterate the docs ensuring they match + DocIdSetIterator acceptDocsIterator = acceptDocs.iterator(); + assertEquals(1, acceptDocsIterator.nextDoc()); + assertEquals(3, acceptDocsIterator.nextDoc()); + assertEquals(5, acceptDocsIterator.nextDoc()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, acceptDocsIterator.nextDoc()); + // verify bits + Bits acceptDocsBits = acceptDocs.bits(); + for (int i = 0; i < 10; i++) { + assertEquals(acceptedDocs.get(i), acceptDocsBits.get(i)); + } + } + + private static class TestScorerSupplier extends ScorerSupplier { + private final DocIdSetIterator iterator; + + TestScorerSupplier(DocIdSetIterator iterator) { + this.iterator = iterator; + } + + @Override + public Scorer get(long leadCost) throws IOException { + return new Scorer() { + @Override + public int docID() { + return iterator.docID(); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return Float.MAX_VALUE; + } + + @Override + public float score() throws IOException { + return 0; + } + }; + } + + @Override + public long cost() { + return iterator.cost(); + } + } + +} From ede10928bd191686b322177f513939023bd6428d Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 11 Nov 2025 08:07:30 -0500 Subject: [PATCH 2/2] iter --- .../java/org/elasticsearch/search/vectors/ESAcceptDocs.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESAcceptDocs.java b/server/src/main/java/org/elasticsearch/search/vectors/ESAcceptDocs.java index 7658d1081f560..0eeaaa85df3e6 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESAcceptDocs.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESAcceptDocs.java @@ -125,7 +125,7 @@ public static final class BitsAcceptDocs extends ESAcceptDocs { if (bits != null && bits.length() != maxDoc) { throw new IllegalArgumentException("Bits length = " + bits.length() + " != maxDoc = " + maxDoc); } - this.bits = bits; + this.bits = Objects.requireNonNull(bits); if (bits instanceof BitSet bitSet) { this.maxDoc = Objects.requireNonNull(bitSet).cardinality(); this.approximateCost = Objects.requireNonNull(bitSet).approximateCardinality(); @@ -144,8 +144,8 @@ public Bits bits() { @Override public DocIdSetIterator iterator() { - if (bits instanceof BitSet bitSet) { - return new BitSetIterator(bitSet, maxDoc); + if (bitSetRef != null) { + return new BitSetIterator(bitSetRef, maxDoc); } return new FilteredDocIdSetIterator(DocIdSetIterator.all(maxDoc)) { @Override