Skip to content

Commit 6d68f0b

Browse files
POC-2 - Auto prefiltering for knn vector query
This POC implements automatic prefiltering for `semantic_text` queries. We achieve this by adding an `AutoPrefilteringScope` object to the `SearchExecutionContext`. When we convert a query to a lucene query, queries may push prefilters to the `AutoPrefilteringScope`. At that stage, queries have already been rewritten. Semantic queries using `text_embedding` inference endpoints are rewritten to knn vector queries that are auto-prefiltering enabled. Then, when an auto-prefiltering enabled knn vector query is converted to its lucene equivalent, we fetch the prefilters from the `SearchExecutionContext` and we apply them to the knn vector query - which supports prefiltering already.
1 parent f08e731 commit 6d68f0b

File tree

6 files changed

+193
-11
lines changed

6 files changed

+193
-11
lines changed

server/src/main/java/org/elasticsearch/index/query/BoolQueryBuilder.java

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,21 @@
1919
import org.elasticsearch.common.io.stream.StreamInput;
2020
import org.elasticsearch.common.io.stream.StreamOutput;
2121
import org.elasticsearch.common.lucene.search.Queries;
22+
import org.elasticsearch.index.query.support.AutoPrefilteringScope;
2223
import org.elasticsearch.xcontent.ObjectParser;
2324
import org.elasticsearch.xcontent.ParseField;
2425
import org.elasticsearch.xcontent.XContentBuilder;
2526
import org.elasticsearch.xcontent.XContentParser;
2627

2728
import java.io.IOException;
2829
import java.util.ArrayList;
30+
import java.util.Collection;
2931
import java.util.List;
3032
import java.util.Map;
3133
import java.util.Objects;
3234
import java.util.function.Consumer;
35+
import java.util.stream.Collectors;
36+
import java.util.stream.Stream;
3337

3438
import static org.elasticsearch.common.lucene.search.Queries.fixNegativeQueryIfNeeded;
3539

@@ -299,16 +303,17 @@ public String getWriteableName() {
299303
@Override
300304
protected Query doToQuery(SearchExecutionContext context) throws IOException {
301305
BooleanQuery.Builder booleanQueryBuilder = new BooleanQuery.Builder();
302-
addBooleanClauses(context, booleanQueryBuilder, mustClauses, BooleanClause.Occur.MUST);
306+
final List<QueryBuilder> prefilters = collectPrefilters();
307+
addBooleanClauses(context, booleanQueryBuilder, mustClauses, BooleanClause.Occur.MUST, prefilters);
303308
try {
304309
// disable tracking of the @timestamp range for must_not and should clauses
305310
context.setTrackTimeRangeFilterFrom(false);
306-
addBooleanClauses(context, booleanQueryBuilder, mustNotClauses, BooleanClause.Occur.MUST_NOT);
307-
addBooleanClauses(context, booleanQueryBuilder, shouldClauses, BooleanClause.Occur.SHOULD);
311+
addBooleanClauses(context, booleanQueryBuilder, mustNotClauses, BooleanClause.Occur.MUST_NOT, List.of());
312+
addBooleanClauses(context, booleanQueryBuilder, shouldClauses, BooleanClause.Occur.SHOULD, prefilters);
308313
} finally {
309314
context.setTrackTimeRangeFilterFrom(true);
310315
}
311-
addBooleanClauses(context, booleanQueryBuilder, filterClauses, BooleanClause.Occur.FILTER);
316+
addBooleanClauses(context, booleanQueryBuilder, filterClauses, BooleanClause.Occur.FILTER, List.of());
312317
BooleanQuery booleanQuery = booleanQueryBuilder.build();
313318
if (booleanQuery.clauses().isEmpty()) {
314319
return new MatchAllDocsQuery();
@@ -318,15 +323,25 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
318323
return adjustPureNegative ? fixNegativeQueryIfNeeded(query) : query;
319324
}
320325

326+
private List<QueryBuilder> collectPrefilters() {
327+
return Stream.of(mustClauses, mustNotClauses.stream().map(c -> QueryBuilders.boolQuery().mustNot(c)).toList(), filterClauses)
328+
.flatMap(Collection::stream)
329+
.collect(Collectors.toList());
330+
}
331+
321332
private static void addBooleanClauses(
322333
SearchExecutionContext context,
323334
BooleanQuery.Builder booleanQueryBuilder,
324335
List<QueryBuilder> clauses,
325-
Occur occurs
336+
Occur occurs,
337+
List<QueryBuilder> prefilters
326338
) throws IOException {
327339
for (QueryBuilder query : clauses) {
328-
Query luceneQuery = query.toQuery(context);
329-
booleanQueryBuilder.add(new BooleanClause(luceneQuery, occurs));
340+
try (AutoPrefilteringScope autoPrefilteringScope = context.autoPrefilteringScope()) {
341+
autoPrefilteringScope.push(prefilters.stream().filter(c -> c != query).toList());
342+
Query luceneQuery = query.toQuery(context);
343+
booleanQueryBuilder.add(new BooleanClause(luceneQuery, occurs));
344+
}
330345
}
331346
}
332347

server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import org.elasticsearch.index.mapper.ParsedDocument;
5050
import org.elasticsearch.index.mapper.SourceLoader;
5151
import org.elasticsearch.index.mapper.SourceToParse;
52+
import org.elasticsearch.index.query.support.AutoPrefilteringScope;
5253
import org.elasticsearch.index.query.support.NestedScope;
5354
import org.elasticsearch.index.similarity.SimilarityService;
5455
import org.elasticsearch.script.Script;
@@ -103,6 +104,7 @@ public class SearchExecutionContext extends QueryRewriteContext {
103104

104105
private final Map<String, Query> namedQueries = new HashMap<>();
105106
private NestedScope nestedScope;
107+
private AutoPrefilteringScope autoPrefilteringScope;
106108
private QueryBuilder aliasFilter;
107109
private boolean rewriteToNamedQueries = false;
108110

@@ -291,6 +293,7 @@ private SearchExecutionContext(
291293
this.bitsetFilterCache = bitsetFilterCache;
292294
this.indexFieldDataLookup = indexFieldDataLookup;
293295
this.nestedScope = new NestedScope();
296+
this.autoPrefilteringScope = new AutoPrefilteringScope();
294297
this.searcher = searcher;
295298
this.requestSize = requestSize;
296299
this.mapperMetrics = mapperMetrics;
@@ -301,7 +304,7 @@ private void reset() {
301304
this.lookup = null;
302305
this.namedQueries.clear();
303306
this.nestedScope = new NestedScope();
304-
307+
this.autoPrefilteringScope = new AutoPrefilteringScope();
305308
}
306309

307310
// Set alias filter, so it can be applied for queries that need it (e.g. knn query)
@@ -556,6 +559,10 @@ public NestedScope nestedScope() {
556559
return nestedScope;
557560
}
558561

562+
public AutoPrefilteringScope autoPrefilteringScope() {
563+
return autoPrefilteringScope;
564+
}
565+
559566
public IndexVersion indexVersionCreated() {
560567
return indexSettings.getIndexVersionCreated();
561568
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.query.support;
11+
12+
import org.elasticsearch.index.query.QueryBuilder;
13+
14+
import java.util.Deque;
15+
import java.util.LinkedList;
16+
import java.util.List;
17+
18+
/**
19+
* During query parsing this keeps track of the current prefiltering level.
20+
*/
21+
public final class AutoPrefilteringScope implements AutoCloseable {
22+
23+
private final Deque<List<QueryBuilder>> prefiltersStack = new LinkedList<>();
24+
25+
public List<QueryBuilder> getPrefilters() {
26+
return prefiltersStack.stream().flatMap(List::stream).toList();
27+
}
28+
29+
public void push(List<QueryBuilder> prefilters) {
30+
prefiltersStack.push(prefilters);
31+
}
32+
33+
public void pop() {
34+
prefiltersStack.pop();
35+
}
36+
37+
@Override
38+
public void close() {
39+
pop();
40+
}
41+
}

server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) {
133133
private final QueryVectorBuilder queryVectorBuilder;
134134
private final Supplier<float[]> queryVectorSupplier;
135135
private final RescoreVectorBuilder rescoreVectorBuilder;
136+
private boolean isAutoPrefiltering = false;
136137

137138
public KnnVectorQueryBuilder(
138139
String fieldName,
@@ -579,8 +580,9 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
579580
}
580581
DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType;
581582

582-
List<Query> filtersInitial = new ArrayList<>(filterQueries.size());
583-
for (QueryBuilder query : this.filterQueries) {
583+
List<QueryBuilder> allApplicableFilters = getAllApplicableFilters(context);
584+
List<Query> filtersInitial = new ArrayList<>(allApplicableFilters.size());
585+
for (QueryBuilder query : allApplicableFilters) {
584586
filtersInitial.add(query.toQuery(context));
585587
}
586588
if (context.getAliasFilter() != null) {
@@ -650,6 +652,14 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
650652
);
651653
}
652654

655+
private List<QueryBuilder> getAllApplicableFilters(SearchExecutionContext context) {
656+
List<QueryBuilder> applicableFilters = new ArrayList<>(filterQueries);
657+
if (isAutoPrefiltering) {
658+
applicableFilters.addAll(context.autoPrefilteringScope().getPrefilters());
659+
}
660+
return applicableFilters;
661+
}
662+
653663
private static Query buildFilterQuery(List<Query> filters) {
654664
BooleanQuery.Builder builder = new BooleanQuery.Builder();
655665
for (Query f : filters) {
@@ -692,4 +702,9 @@ protected boolean doEquals(KnnVectorQueryBuilder other) {
692702
public TransportVersion getMinimalSupportedVersion() {
693703
return TransportVersions.V_8_0_0;
694704
}
705+
706+
public KnnVectorQueryBuilder enableAutoPrefiltering() {
707+
isAutoPrefiltering = true;
708+
return this;
709+
}
695710
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.query.support;
11+
12+
import org.elasticsearch.index.query.QueryBuilder;
13+
import org.elasticsearch.index.query.RandomQueryBuilder;
14+
import org.elasticsearch.test.ESTestCase;
15+
16+
import java.util.Collection;
17+
import java.util.List;
18+
import java.util.stream.Stream;
19+
20+
import static org.hamcrest.Matchers.empty;
21+
import static org.hamcrest.Matchers.equalTo;
22+
import static org.hamcrest.Matchers.is;
23+
24+
public class AutoPrefilteringScopeTests extends ESTestCase {
25+
26+
public void testMultipleLevels() {
27+
AutoPrefilteringScope autoPrefilteringScope = new AutoPrefilteringScope();
28+
assertThat(autoPrefilteringScope.getPrefilters(), is(empty()));
29+
30+
List<QueryBuilder> prefilters_1_1 = randomList(0, 5, () -> RandomQueryBuilder.createQuery(random()));
31+
List<QueryBuilder> prefilters_1_2 = randomList(0, 5, () -> RandomQueryBuilder.createQuery(random()));
32+
List<QueryBuilder> prefilters_2_1 = randomList(0, 5, () -> RandomQueryBuilder.createQuery(random()));
33+
List<QueryBuilder> prefilters_2_2 = randomList(0, 5, () -> RandomQueryBuilder.createQuery(random()));
34+
List<QueryBuilder> prefilters_3_1 = randomList(0, 5, () -> RandomQueryBuilder.createQuery(random()));
35+
36+
// Given + increases level and - decreases level, we add scope as follows:
37+
// + 1_1 + 2_1 + 3_1
38+
// - 3_1 + 2_2
39+
// - 2_2 + 1_2
40+
// - 1_2 + 1_1
41+
// - 1_1
42+
// and we check current prefilters after each operation.
43+
44+
autoPrefilteringScope.push(prefilters_1_1);
45+
assertThat(autoPrefilteringScope.getPrefilters(), equalTo(prefilters_1_1));
46+
autoPrefilteringScope.push(prefilters_2_1);
47+
assertThat(
48+
autoPrefilteringScope.getPrefilters(),
49+
equalTo(Stream.of(prefilters_2_1, prefilters_1_1).flatMap(Collection::stream).toList())
50+
);
51+
autoPrefilteringScope.push(prefilters_3_1);
52+
assertThat(
53+
autoPrefilteringScope.getPrefilters(),
54+
equalTo(Stream.of(prefilters_3_1, prefilters_2_1, prefilters_1_1).flatMap(Collection::stream).toList())
55+
);
56+
autoPrefilteringScope.pop();
57+
assertThat(
58+
autoPrefilteringScope.getPrefilters(),
59+
equalTo(Stream.of(prefilters_2_1, prefilters_1_1).flatMap(Collection::stream).toList())
60+
);
61+
autoPrefilteringScope.push(prefilters_2_2);
62+
assertThat(
63+
autoPrefilteringScope.getPrefilters(),
64+
equalTo(Stream.of(prefilters_2_2, prefilters_2_1, prefilters_1_1).flatMap(Collection::stream).toList())
65+
);
66+
autoPrefilteringScope.pop();
67+
assertThat(
68+
autoPrefilteringScope.getPrefilters(),
69+
equalTo(Stream.of(prefilters_2_1, prefilters_1_1).flatMap(Collection::stream).toList())
70+
);
71+
autoPrefilteringScope.pop();
72+
assertThat(autoPrefilteringScope.getPrefilters(), equalTo(prefilters_1_1));
73+
autoPrefilteringScope.push(prefilters_1_2);
74+
assertThat(
75+
autoPrefilteringScope.getPrefilters(),
76+
equalTo(Stream.of(prefilters_1_2, prefilters_1_1).flatMap(Collection::stream).toList())
77+
);
78+
autoPrefilteringScope.pop();
79+
assertThat(autoPrefilteringScope.getPrefilters(), equalTo(prefilters_1_1));
80+
autoPrefilteringScope.pop();
81+
assertThat(autoPrefilteringScope.getPrefilters(), empty());
82+
}
83+
84+
public void testAutoCloseable() {
85+
AutoPrefilteringScope autoPrefilteringScope = new AutoPrefilteringScope();
86+
List<QueryBuilder> prefilters_1 = randomList(0, 5, () -> RandomQueryBuilder.createQuery(random()));
87+
List<QueryBuilder> prefilters_2 = randomList(0, 5, () -> RandomQueryBuilder.createQuery(random()));
88+
89+
try (autoPrefilteringScope) {
90+
autoPrefilteringScope.push(prefilters_1);
91+
92+
try (autoPrefilteringScope) {
93+
autoPrefilteringScope.push(prefilters_2);
94+
assertThat(
95+
autoPrefilteringScope.getPrefilters(),
96+
equalTo(Stream.of(prefilters_2, prefilters_1).flatMap(Collection::stream).toList())
97+
);
98+
}
99+
assertThat(autoPrefilteringScope.getPrefilters(), equalTo(prefilters_1));
100+
}
101+
assertThat(autoPrefilteringScope.getPrefilters(), is(empty()));
102+
}
103+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1055,7 +1055,8 @@ yield new SparseVectorQueryBuilder(
10551055
k = Math.max(k, DEFAULT_SIZE);
10561056
}
10571057

1058-
yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, k, null, null, null, null);
1058+
yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, k, null, null, null, null)
1059+
.enableAutoPrefiltering();
10591060
}
10601061
default -> throw new IllegalStateException(
10611062
"Field ["

0 commit comments

Comments
 (0)