Skip to content

Commit cb52116

Browse files
committed
chore: add unit tests to embedding validation
1 parent 8ac71ba commit cb52116

File tree

2 files changed

+284
-10
lines changed

2 files changed

+284
-10
lines changed

src/common/search/vectorSearchEmbeddings.ts

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
2-
import type { Document } from "bson";
2+
import { BSON, type Document } from "bson";
33

44
type VectorFieldIndexDefinition = {
55
type: "vector";
@@ -9,16 +9,12 @@ type VectorFieldIndexDefinition = {
99
similarity: "euclidean" | "cosine" | "dotProduct";
1010
};
1111

12-
type EmbeddingNamespace = "${string}.${string}";
12+
export type EmbeddingNamespace = `${string}.${string}`;
1313
export class VectorSearchEmbeddings {
14-
private embeddings: Map<EmbeddingNamespace, VectorFieldIndexDefinition[]>;
15-
16-
constructor() {
17-
this.embeddings = new Map();
18-
}
14+
/**
 * @param embeddings Cache of vector field index definitions, keyed by the
 *   `database.collection` namespace string. Defaults to an empty map; a
 *   pre-populated map can be injected (the unit tests in this commit do so).
 */
constructor(private readonly embeddings: Map<EmbeddingNamespace, VectorFieldIndexDefinition[]> = new Map()) {}
1915

2016
cleanupEmbeddingsForNamespace({ database, collection }: { database: string; collection: string }): void {
21-
const embeddingDefKey = `${database}.${collection}` as EmbeddingNamespace;
17+
const embeddingDefKey: EmbeddingNamespace = `${database}.${collection}`;
2218
this.embeddings.delete(embeddingDefKey);
2319
}
2420

@@ -31,7 +27,7 @@ export class VectorSearchEmbeddings {
3127
collection: string;
3228
provider: NodeDriverServiceProvider;
3329
}): Promise<VectorFieldIndexDefinition[] | undefined> {
34-
const embeddingDefKey = `${database}.${collection}` as EmbeddingNamespace;
30+
const embeddingDefKey: EmbeddingNamespace = `${database}.${collection}`;
3531
const definition = this.embeddings.get(embeddingDefKey);
3632

3733
if (!definition) {
@@ -49,7 +45,71 @@ export class VectorSearchEmbeddings {
4945
}
5046
}
5147

52-
isVectorFieldIndexDefinition(doc: Document): doc is VectorFieldIndexDefinition {
48+
/**
 * Lists the vector index definitions of a namespace that `document` violates.
 *
 * The definitions are obtained through `embeddingsForNamespace`; when that
 * lookup yields nothing, no field can be wrong and an empty array is returned.
 *
 * @param document The document whose embedding fields are checked.
 * @returns Every definition for which the document fails validation.
 */
async findFieldsWithWrongEmbeddings(
    {
        database,
        collection,
        provider,
    }: {
        database: string;
        collection: string;
        provider: NodeDriverServiceProvider;
    },
    document: Document
): Promise<VectorFieldIndexDefinition[]> {
    const definitions = (await this.embeddingsForNamespace({ database, collection, provider })) ?? [];

    // Keep only the definitions the document does NOT satisfy.
    const failsValidation = (definition: VectorFieldIndexDefinition): boolean =>
        !this.documentPassesEmbeddingValidation(definition, document);

    return definitions.filter(failsValidation);
}
68+
69+
/**
 * Type guard: a document is treated as a vector field index definition
 * when its `type` property is exactly the string "vector".
 */
private isVectorFieldIndexDefinition(doc: Document): doc is VectorFieldIndexDefinition {
    const declaredType: unknown = doc["type"];
    return declaredType === "vector";
}
72+
73+
/**
 * Checks whether the value stored in `document` at `definition.path`
 * is acceptable for the given vector index definition.
 *
 * Rules, by `definition.quantization`:
 * - "none" / "scalar": the value must be an array of exactly
 *   `definition.numDimensions` elements, ALL of which are JS numbers.
 *   (Previously only the first element was inspected, so arrays such as
 *   `[1, "a", null, ...]` were wrongly accepted.)
 * - "binary": the value must be a `BSON.Binary` whose `toBits()` result has
 *   exactly `definition.numDimensions` entries; a `toBits()` failure counts
 *   as invalid.
 * - any other quantization value: no check is applied and the document passes.
 *
 * A document that does not contain the field at all (any segment of the
 * dotted path missing) passes, since there is nothing to validate.
 *
 * @param definition The vector field index definition to validate against.
 * @param document The candidate document.
 * @returns `true` when the document is compatible with the definition.
 */
private documentPassesEmbeddingValidation(definition: VectorFieldIndexDefinition, document: Document): boolean {
    const fieldPath = definition.path.split(".");
    let fieldRef: unknown = document;

    // Walk the dotted path one segment at a time.
    for (const field of fieldPath) {
        if (fieldRef && typeof fieldRef === "object" && field in fieldRef) {
            fieldRef = (fieldRef as Record<string, unknown>)[field];
        } else {
            // The document does not set this field: nothing to validate.
            return true;
        }
    }

    switch (definition.quantization) {
        case "none":
        case "scalar":
            return (
                Array.isArray(fieldRef) &&
                fieldRef.length === definition.numDimensions &&
                // Every component must be numeric, not just the first one.
                fieldRef.every((element) => typeof element === "number")
            );
        case "binary":
            if (!(fieldRef instanceof BSON.Binary)) {
                return false;
            }

            try {
                return fieldRef.toBits().length === definition.numDimensions;
            } catch {
                // toBits() throws when the binary is not a packed-bit vector.
                return false;
            }
    }

    return true;
}
55115
}
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import { describe, it, expect, vi, beforeEach } from "vitest";
2+
import type { MockedFunction } from "vitest";
3+
import { VectorSearchEmbeddings } from "../../../../src/common/search/vectorSearchEmbeddings.js";
4+
import type { EmbeddingNamespace } from "../../../../src/common/search/vectorSearchEmbeddings.js";
5+
import { BSON } from "bson";
6+
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
7+
8+
type MockedServiceProvider = NodeDriverServiceProvider & {
9+
getSearchIndexes: MockedFunction<NodeDriverServiceProvider["getSearchIndexes"]>;
10+
};
11+
12+
describe("VectorSearchEmbeddings", () => {
    const database = "my" as const;
    const collection = "collection" as const;
    const mapKey: EmbeddingNamespace = `${database}.${collection}`;

    const provider = { getSearchIndexes: vi.fn() } as unknown as MockedServiceProvider;

    // Namespace + provider triple passed to every call under test.
    const target = { database, collection, provider };

    // Runs the validation of `document` against the given instance.
    const wrongFieldsFor = (
        subject: VectorSearchEmbeddings,
        document: Record<string, unknown>
    ): ReturnType<VectorSearchEmbeddings["findFieldsWithWrongEmbeddings"]> =>
        subject.findFieldsWithWrongEmbeddings(target, document);

    beforeEach(() => {
        provider.getSearchIndexes.mockReset();
    });

    describe("embedding retrieval", () => {
        describe("when the embeddings have not been cached", () => {
            beforeEach(() => {
                // One regular search index and one vector search index with a
                // single vector field plus two filter fields.
                provider.getSearchIndexes.mockImplementation(() =>
                    Promise.resolve([
                        {
                            id: "65e8c766d0450e3e7ab9855f",
                            name: "search-test",
                            type: "search",
                            status: "READY",
                            queryable: true,
                            latestDefinition: { dynamic: true },
                        },
                        {
                            id: "65e8c766d0450e3e7ab9855f",
                            name: "vector-search-test",
                            type: "vectorSearch",
                            status: "READY",
                            queryable: true,
                            latestDefinition: {
                                fields: [
                                    {
                                        type: "vector",
                                        path: "plot_embedding",
                                        numDimensions: 1536,
                                        similarity: "euclidean",
                                    },
                                    { type: "filter", path: "genres" },
                                    { type: "filter", path: "year" },
                                ],
                            },
                        },
                    ])
                );
            });

            it("retrieves the list of vector search indexes for that collection from the cluster", async () => {
                const subject = new VectorSearchEmbeddings();
                const definitions = await subject.embeddingsForNamespace(target);

                expect(definitions).toContainEqual({
                    type: "vector",
                    path: "plot_embedding",
                    numDimensions: 1536,
                    similarity: "euclidean",
                });
            });

            it("ignores any other type of index", async () => {
                const subject = new VectorSearchEmbeddings();
                const definitions = await subject.embeddingsForNamespace(target);

                const nonVector = definitions?.filter((emb) => emb.type !== "vector");
                expect(nonVector).toHaveLength(0);
            });
        });
    });

    describe("embedding validation", () => {
        it("when there are no embeddings, all documents are valid", async () => {
            const subject = new VectorSearchEmbeddings(new Map([[mapKey, []]]));
            const wrongFields = await wrongFieldsFor(subject, { field: "yay" });

            expect(wrongFields).toHaveLength(0);
        });

        describe("when there are embeddings", () => {
            // Pre-populated cache: flat and nested paths, each in a
            // non-quantized and a binary-quantized flavour, all 8-dimensional.
            const embeddings = new VectorSearchEmbeddings(
                new Map([
                    [
                        mapKey,
                        [
                            {
                                type: "vector",
                                path: "embedding_field",
                                numDimensions: 8,
                                quantization: "none",
                                similarity: "euclidean",
                            },
                            {
                                type: "vector",
                                path: "embedding_field_binary",
                                numDimensions: 8,
                                quantization: "binary",
                                similarity: "euclidean",
                            },
                            {
                                type: "vector",
                                path: "a.nasty.scalar.field",
                                numDimensions: 8,
                                quantization: "none",
                                similarity: "euclidean",
                            },
                            {
                                type: "vector",
                                path: "a.nasty.binary.field",
                                numDimensions: 8,
                                quantization: "binary",
                                similarity: "euclidean",
                            },
                        ],
                    ],
                ])
            );

            it("documents not inserting the field with embeddings are valid", async () => {
                const wrongFields = await wrongFieldsFor(embeddings, { field: "yay" });

                expect(wrongFields).toHaveLength(0);
            });

            it("documents inserting the field with wrong type are invalid", async () => {
                const wrongFields = await wrongFieldsFor(embeddings, { embedding_field: "some text" });

                expect(wrongFields).toHaveLength(1);
            });

            it("documents inserting the field with wrong dimensions are invalid", async () => {
                const wrongFields = await wrongFieldsFor(embeddings, { embedding_field: [1, 2, 3] });

                expect(wrongFields).toHaveLength(1);
            });

            it("documents inserting the field with correct dimensions, but wrong type are invalid", async () => {
                const wrongFields = await wrongFieldsFor(embeddings, {
                    embedding_field: ["1", "2", "3", "4", "5", "6", "7", "8"],
                });

                expect(wrongFields).toHaveLength(1);
            });

            it("documents inserting the field with correct dimensions, but wrong quantization are invalid", async () => {
                const wrongFields = await wrongFieldsFor(embeddings, {
                    embedding_field_binary: [1, 2, 3, 4, 5, 6, 7, 8],
                });

                expect(wrongFields).toHaveLength(1);
            });

            it("documents inserting the field with correct dimensions and quantization in binary are valid", async () => {
                const wrongFields = await wrongFieldsFor(embeddings, {
                    embedding_field_binary: BSON.Binary.fromBits([0, 0, 0, 0, 0, 0, 0, 0]),
                });

                expect(wrongFields).toHaveLength(0);
            });

            it("documents inserting the field with correct dimensions and quantization in scalar/none are valid", async () => {
                const wrongFields = await wrongFieldsFor(embeddings, {
                    embedding_field: [1, 2, 3, 4, 5, 6, 7, 8],
                });

                expect(wrongFields).toHaveLength(0);
            });

            it("documents inserting the field with correct dimensions and quantization in scalar/none are valid also on nested fields", async () => {
                const wrongFields = await wrongFieldsFor(embeddings, {
                    a: { nasty: { scalar: { field: [1, 2, 3, 4, 5, 6, 7, 8] } } },
                });

                expect(wrongFields).toHaveLength(0);
            });

            it("documents inserting the field with correct dimensions and quantization in binary are valid also on nested fields", async () => {
                const wrongFields = await wrongFieldsFor(embeddings, {
                    a: { nasty: { binary: { field: BSON.Binary.fromBits([0, 0, 0, 0, 0, 0, 0, 0]) } } },
                });

                expect(wrongFields).toHaveLength(0);
            });
        });
    });
});

0 commit comments

Comments
 (0)