Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions src/helpers/assertVectorSearchFilterFieldsAreIndexed.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import { type CompositeLogger, LogId } from "../common/logger.js";
// https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#mongodb-vector-search-pre-filter
const ALLOWED_LOGICAL_OPERATORS = ["$not", "$nor", "$and", "$or"];

export type VectorSearchIndex = {
export type SearchIndex = VectorSearchIndex | AtlasSearchIndex;

type VectorSearchIndex = {
name: string;
latestDefinition: {
fields: Array<
Expand All @@ -21,19 +23,28 @@ export type VectorSearchIndex = {
}
>;
};
type: "vectorSearch";
};

type AtlasSearchIndex = {
name: string;
latestDefinition: unknown;
type: "search";
};

export function assertVectorSearchFilterFieldsAreIndexed({
searchIndexes,
pipeline,
logger,
}: {
searchIndexes: VectorSearchIndex[];
searchIndexes: SearchIndex[];
pipeline: Record<string, unknown>[];
logger: CompositeLogger;
}): void {
const searchIndexesWithFilterFields = searchIndexes.reduce<Record<string, string[]>>(
(indexFieldMap, searchIndex) => {
const searchIndexesWithFilterFields = searchIndexes
// Ensure we only process vector search indexes and not lexical search ones
.filter((index) => index.type === "vectorSearch")
.reduce<Record<string, string[]>>((indexFieldMap, searchIndex) => {
const filterFields = searchIndex.latestDefinition.fields
.map<string | undefined>((field) => {
return field.type === "filter" ? field.path : undefined;
Expand All @@ -42,9 +53,7 @@ export function assertVectorSearchFilterFieldsAreIndexed({

indexFieldMap[searchIndex.name] = filterFields;
return indexFieldMap;
},
{}
);
}, {});
for (const stage of pipeline) {
if ("$vectorSearch" in stage) {
const { $vectorSearch: vectorSearchStage } = stage as z.infer<typeof VectorSearchStage>;
Expand Down
4 changes: 2 additions & 2 deletions src/tools/mongodb/read/aggregate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import { LogId } from "../../../common/logger.js";
import { AnyAggregateStage, VectorSearchStage } from "../mongodbSchemas.js";
import {
assertVectorSearchFilterFieldsAreIndexed,
type VectorSearchIndex,
type SearchIndex,
} from "../../../helpers/assertVectorSearchFilterFieldsAreIndexed.js";

const pipelineDescriptionWithVectorSearch = `\
Expand Down Expand Up @@ -66,7 +66,7 @@ export class AggregateTool extends MongoDBToolBase {
await this.assertOnlyUsesPermittedStages(pipeline);
if (await this.session.isSearchSupported()) {
assertVectorSearchFilterFieldsAreIndexed({
searchIndexes: (await provider.getSearchIndexes(database, collection)) as VectorSearchIndex[],
searchIndexes: (await provider.getSearchIndexes(database, collection)) as SearchIndex[],
pipeline,
logger: this.session.logger,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { describe, expect, it, vi } from "vitest";
import {
assertVectorSearchFilterFieldsAreIndexed,
collectFieldsFromVectorSearchFilter,
type VectorSearchIndex,
type SearchIndex,
} from "../../../src/helpers/assertVectorSearchFilterFieldsAreIndexed.js";
import { ErrorCodes, MongoDBError } from "../../../src/common/errors.js";
import { type CompositeLogger, LogId } from "../../../src/common/logger.js";
Expand Down Expand Up @@ -184,7 +184,7 @@ describe("#assertVectorSearchFilterFieldsAreIndexed", () => {
error: vi.fn(),
} as unknown as CompositeLogger;

const createMockSearchIndexes = (indexName: string, filterFields: string[]): VectorSearchIndex[] => [
const createMockSearchIndexes = (indexName: string, filterFields: string[]): SearchIndex[] => [
{
name: indexName,
latestDefinition: {
Expand All @@ -196,6 +196,7 @@ describe("#assertVectorSearchFilterFieldsAreIndexed", () => {
})),
],
},
type: "vectorSearch",
},
];

Expand Down Expand Up @@ -547,12 +548,13 @@ describe("#assertVectorSearchFilterFieldsAreIndexed", () => {
});

it("should handle search index with no filter fields", () => {
const searchIndexes: VectorSearchIndex[] = [
const searchIndexes: SearchIndex[] = [
{
name: "myIndex",
latestDefinition: {
fields: [{ type: "vector" }],
},
type: "vectorSearch",
},
];
const pipeline = [
Expand Down Expand Up @@ -583,4 +585,45 @@ describe("#assertVectorSearchFilterFieldsAreIndexed", () => {
)
);
});

it("should ignore atlas search indexes", () => {
const searchIndexes: SearchIndex[] = [
...createMockSearchIndexes("index1", ["field1", "field2"]),
// Atlas search index - it should be ignored by the validation
// and not cause any errors
{
name: "atlasSearchIndex",
latestDefinition: {
analyzer: "lucene.standard",
mappings: {
dynamic: false,
},
},
type: "search",
},
];

const pipeline = [
{
$vectorSearch: {
index: "index1",
path: "embedding",
queryVector: [1, 2, 3],
numCandidates: 100,
limit: 10,
filter: {
field1: "value",
},
},
},
];

expect(() =>
assertVectorSearchFilterFieldsAreIndexed({
searchIndexes: searchIndexes,
pipeline,
logger: mockLogger,
})
).not.toThrow();
});
});
Loading