Skip to content

Commit 43c4d6c

Browse files
committed
chore: simplify, cleanup, add tests
1 parent 5917de1 commit 43c4d6c

File tree

9 files changed

+714
-176
lines changed

9 files changed

+714
-176
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@ tests/tmp
1313
coverage
1414
# Generated assets by accuracy runs
1515
.accuracy
16+
17+
.DS_Store

CONTRIBUTING.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ npm test -- path/to/test/file.test.ts
7676
npm test -- path/to/directory
7777
```
7878

79+
#### Accuracy Tests and colima
80+
81+
If you use [colima](https://github.com/abiosoft/colima) to run Docker on Mac, you will need to apply [additional configuration](https://node.testcontainers.org/supported-container-runtimes/#colima) to ensure the accuracy tests run correctly.
82+
7983
## Troubleshooting
8084

8185
### Restart Server

src/common/errors.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ export enum ErrorCodes {
77
NoEmbeddingsProviderConfigured = 1_000_005,
88
AtlasVectorSearchIndexNotFound = 1_000_006,
99
AtlasVectorSearchInvalidQuery = 1_000_007,
10+
Unexpected = 1_000_008,
1011
}
1112

1213
export class MongoDBError<ErrorCode extends ErrorCodes = ErrorCodes> extends Error {

src/common/search/vectorSearchEmbeddingsManager.ts

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,28 @@ export class VectorSearchEmbeddingsManager {
8484
return definition;
8585
}
8686

87-
async findFieldsWithWrongEmbeddings(
87+
async assertFieldsHaveCorrectEmbeddings(
88+
{ database, collection }: { database: string; collection: string },
89+
documents: Document[]
90+
): Promise<void> {
91+
const embeddingValidationResults = await Promise.all(
92+
documents.map((document) => this.findFieldsWithWrongEmbeddings({ database, collection }, document))
93+
);
94+
const embeddingValidations = new Set(embeddingValidationResults.flat());
95+
96+
if (embeddingValidations.size > 0) {
97+
const embeddingValidationMessages = Array.from(embeddingValidations).map(
98+
(validation) =>
99+
`- Field ${validation.path} is an embedding with ${validation.expectedNumDimensions} dimensions and ${validation.expectedQuantization}` +
100+
` quantization, and the provided value is not compatible. Actual dimensions: ${validation.actualNumDimensions}, ` +
101+
`actual quantization: ${validation.actualQuantization}. Error: ${validation.error}`
102+
);
103+
104+
throw new MongoDBError(ErrorCodes.AtlasVectorSearchInvalidQuery, embeddingValidationMessages.join("\n"));
105+
}
106+
}
107+
108+
private async findFieldsWithWrongEmbeddings(
88109
{
89110
database,
90111
collection,
@@ -220,21 +241,34 @@ export class VectorSearchEmbeddingsManager {
220241
return undefined;
221242
}
222243

223-
public async generateEmbeddings({
244+
public async assertVectorSearchIndexExists({
224245
database,
225246
collection,
226247
path,
227-
rawValues,
228-
embeddingParameters,
229-
inputType,
230248
}: {
231249
database: string;
232250
collection: string;
233251
path: string;
252+
}): Promise<void> {
253+
const embeddingInfoForCollection = await this.embeddingsForNamespace({ database, collection });
254+
const embeddingInfoForPath = embeddingInfoForCollection.find((definition) => definition.path === path);
255+
if (!embeddingInfoForPath) {
256+
throw new MongoDBError(
257+
ErrorCodes.AtlasVectorSearchIndexNotFound,
258+
`No Vector Search index found for path "${path}" in namespace "${database}.${collection}"`
259+
);
260+
}
261+
}
262+
263+
public async generateEmbeddings({
264+
rawValues,
265+
embeddingParameters,
266+
inputType,
267+
}: {
234268
rawValues: string[];
235269
embeddingParameters: SupportedEmbeddingParameters;
236270
inputType: EmbeddingParameters["inputType"];
237-
}): Promise<unknown[]> {
271+
}): Promise<number[][]> {
238272
const provider = await this.atlasSearchEnabledProvider();
239273
if (!provider) {
240274
throw new MongoDBError(
@@ -256,15 +290,6 @@ export class VectorSearchEmbeddingsManager {
256290
});
257291
}
258292

259-
const embeddingInfoForCollection = await this.embeddingsForNamespace({ database, collection });
260-
const embeddingInfoForPath = embeddingInfoForCollection.find((definition) => definition.path === path);
261-
if (!embeddingInfoForPath) {
262-
throw new MongoDBError(
263-
ErrorCodes.AtlasVectorSearchIndexNotFound,
264-
`No Vector Search index found for path "${path}" in namespace "${database}.${collection}"`
265-
);
266-
}
267-
268293
return await embeddingsProvider.embed(embeddingParameters.model, rawValues, {
269294
inputType,
270295
...embeddingParameters,

src/tools/mongodb/create/insertMany.ts

Lines changed: 63 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@ import { zEJSON } from "../../args.js";
66
import { type Document } from "bson";
77
import { zSupportedEmbeddingParameters } from "../../../common/search/embeddingsProvider.js";
88
import { ErrorCodes, MongoDBError } from "../../../common/errors.js";
9-
import type { VectorFieldIndexDefinition } from "../../../common/search/vectorSearchEmbeddingsManager.js";
9+
10+
const zSupportedEmbeddingParametersWithInput = zSupportedEmbeddingParameters.extend({
11+
input: z
12+
.array(z.object({}).passthrough())
13+
.describe(
14+
"Array of objects with vector search index fields as keys (in dot notation) and the raw text values to generate embeddings for as values. The index of each object corresponds to the index of the document in the documents array."
15+
),
16+
});
1017

1118
export class InsertManyTool extends MongoDBToolBase {
1219
public name = "insert-many";
@@ -16,9 +23,9 @@ export class InsertManyTool extends MongoDBToolBase {
1623
documents: z
1724
.array(zEJSON().describe("An individual MongoDB document"))
1825
.describe(
19-
"The array of documents to insert, matching the syntax of the document argument of db.collection.insertMany(). If you are asked to generate a embedding for a field, you have to explicitly specify the field name with a raw text string value of the field and an embedding will be generated if embeddingParameters is provided."
26+
"The array of documents to insert, matching the syntax of the document argument of db.collection.insertMany()."
2027
),
21-
embeddingParameters: zSupportedEmbeddingParameters
28+
embeddingParameters: zSupportedEmbeddingParametersWithInput
2229
.optional()
2330
.describe(
2431
"The embedding model and its parameters to use to generate embeddings for fields that have vector search indexes. When a field has a vector search index and contains a plain text string in the document, embeddings will be automatically generated from that string value. Note to LLM: If unsure which embedding model to use, ask the user before providing one."
@@ -34,45 +41,14 @@ export class InsertManyTool extends MongoDBToolBase {
3441
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
3542
const provider = await this.ensureConnected();
3643

37-
// Get vector search indexes for the collection
38-
const vectorIndexes = await this.session.vectorSearchEmbeddingsManager.embeddingsForNamespace({
39-
database,
40-
collection,
41-
});
42-
4344
// Process documents to replace raw string values with generated embeddings
4445
documents = await this.replaceRawValuesWithEmbeddingsIfNecessary({
4546
database,
4647
collection,
4748
documents,
48-
vectorIndexes,
4949
embeddingParameters,
5050
});
5151

52-
const embeddingValidationPromises = documents.map((document) =>
53-
this.session.vectorSearchEmbeddingsManager.findFieldsWithWrongEmbeddings({ database, collection }, document)
54-
);
55-
const embeddingValidationResults = await Promise.all(embeddingValidationPromises);
56-
const embeddingValidations = new Set(embeddingValidationResults.flat());
57-
58-
if (embeddingValidations.size > 0) {
59-
// tell the LLM what happened
60-
const embeddingValidationMessages = Array.from(embeddingValidations).map(
61-
(validation) =>
62-
`- Field ${validation.path} is an embedding with ${validation.expectedNumDimensions} dimensions and ${validation.expectedQuantization}` +
63-
` quantization, and the provided value is not compatible. Actual dimensions: ${validation.actualNumDimensions}, ` +
64-
`actual quantization: ${validation.actualQuantization}. Error: ${validation.error}`
65-
);
66-
67-
return {
68-
content: formatUntrustedData(
69-
"There were errors when inserting documents. No document was inserted.",
70-
...embeddingValidationMessages
71-
),
72-
isError: true,
73-
};
74-
}
75-
7652
const result = await provider.insertMany(database, collection, documents);
7753
const content = formatUntrustedData(
7854
"Documents were inserted successfully.",
@@ -88,144 +64,84 @@ export class InsertManyTool extends MongoDBToolBase {
8864
database,
8965
collection,
9066
documents,
91-
vectorIndexes,
9267
embeddingParameters,
9368
}: {
9469
database: string;
9570
collection: string;
9671
documents: Document[];
97-
vectorIndexes: VectorFieldIndexDefinition[];
98-
embeddingParameters?: z.infer<typeof zSupportedEmbeddingParameters>;
72+
embeddingParameters?: z.infer<typeof zSupportedEmbeddingParametersWithInput>;
9973
}): Promise<Document[]> {
100-
// If no vector indexes, return documents as-is
101-
if (vectorIndexes.length === 0) {
74+
// If no embedding parameters or no input specified, return documents as-is
75+
if (!embeddingParameters?.input || embeddingParameters.input.length === 0) {
10276
return documents;
10377
}
10478

105-
const processedDocuments: Document[] = [];
106-
107-
for (let i = 0; i < documents.length; i++) {
108-
const document = documents[i];
109-
if (!document) {
110-
continue;
111-
}
112-
const processedDoc = await this.processDocumentForEmbeddings(
113-
database,
114-
collection,
115-
document,
116-
vectorIndexes,
117-
embeddingParameters
118-
);
119-
processedDocuments.push(processedDoc);
120-
}
121-
122-
return processedDocuments;
123-
}
124-
125-
private async processDocumentForEmbeddings(
126-
database: string,
127-
collection: string,
128-
document: Document,
129-
vectorIndexes: VectorFieldIndexDefinition[],
130-
embeddingParameters?: z.infer<typeof zSupportedEmbeddingParameters>
131-
): Promise<Document> {
132-
// Find all fields in the document that match vector search indexed fields and need embeddings
133-
const fieldsNeedingEmbeddings: Array<{
134-
path: string;
135-
rawValue: string;
136-
indexDef: VectorFieldIndexDefinition;
137-
}> = [];
79+
// Get vector search indexes for the collection
80+
const vectorIndexes = await this.session.vectorSearchEmbeddingsManager.embeddingsForNamespace({
81+
database,
82+
collection,
83+
});
13884

139-
for (const indexDef of vectorIndexes) {
140-
// Check if the field exists in the document and is a string (raw text)
141-
const fieldValue = this.getFieldValue(document, indexDef.path);
142-
if (typeof fieldValue === "string") {
143-
fieldsNeedingEmbeddings.push({
144-
path: indexDef.path,
145-
rawValue: fieldValue,
146-
indexDef,
147-
});
85+
// Ensure that a vector search index exists for every field path provided in the input.
86+
for (const input of embeddingParameters.input) {
87+
for (const fieldPath of Object.keys(input)) {
88+
if (!vectorIndexes.some((index) => index.path === fieldPath)) {
89+
throw new MongoDBError(
90+
ErrorCodes.AtlasVectorSearchInvalidQuery,
91+
`Field '${fieldPath}' does not have a vector search index in collection ${database}.${collection}. Only fields with vector search indexes can have embeddings generated.`
92+
);
93+
}
14894
}
14995
}
15096

151-
// If no fields need embeddings, return document as-is
152-
if (fieldsNeedingEmbeddings.length === 0) {
153-
return document;
154-
}
155-
156-
// Check if embeddingParameters is provided
157-
if (!embeddingParameters) {
158-
const fieldPaths = fieldsNeedingEmbeddings.map((f) => f.path).join(", ");
159-
throw new MongoDBError(
160-
ErrorCodes.AtlasVectorSearchInvalidQuery,
161-
`Fields [${fieldPaths}] have vector search indexes and contain raw text strings. The embeddingParameters parameter is required to generate embeddings for these fields.`
162-
);
163-
}
164-
165-
// Generate embeddings for all fields
166-
const embeddingsMap = new Map<string, number[]>();
167-
168-
for (const field of fieldsNeedingEmbeddings) {
169-
const embeddings = await this.session.vectorSearchEmbeddingsManager.generateEmbeddings({
170-
database,
171-
collection,
172-
path: field.path,
173-
rawValues: [field.rawValue],
174-
embeddingParameters,
175-
inputType: "document",
176-
});
97+
// We make one call to generate embeddings for all documents at once to avoid making too many API calls.
98+
const flattenedEmbeddingsInput = embeddingParameters.input.flatMap((documentInput, index) =>
99+
Object.entries(documentInput).map(([fieldPath, rawTextValue]) => ({
100+
fieldPath,
101+
rawTextValue,
102+
documentIndex: index,
103+
}))
104+
);
177105

178-
if (embeddings.length > 0 && Array.isArray(embeddings[0])) {
179-
embeddingsMap.set(field.path, embeddings[0] as number[]);
180-
}
181-
}
106+
const generatedEmbeddings = await this.session.vectorSearchEmbeddingsManager.generateEmbeddings({
107+
rawValues: flattenedEmbeddingsInput.map(({ rawTextValue }) => rawTextValue) as string[],
108+
embeddingParameters,
109+
inputType: "document",
110+
});
182111

183-
// Replace raw string values with generated embeddings
184-
const processedDoc = { ...document };
112+
const processedDocuments: Document[] = [...documents];
185113

186-
for (const field of fieldsNeedingEmbeddings) {
187-
const embedding = embeddingsMap.get(field.path);
188-
if (embedding) {
189-
this.setFieldValue(processedDoc, field.path, embedding);
114+
for (const [index, { fieldPath, documentIndex }] of flattenedEmbeddingsInput.entries()) {
115+
if (!processedDocuments[documentIndex]) {
116+
throw new MongoDBError(ErrorCodes.Unexpected, `Document at index ${documentIndex} does not exist.`);
190117
}
118+
// Remove any existing value at this dot-notation path before assigning the generated embedding under the `fieldPath` key.
119+
this.deleteFieldPath(processedDocuments[documentIndex], fieldPath);
120+
processedDocuments[documentIndex][fieldPath] = generatedEmbeddings[index];
191121
}
192122

193-
return processedDoc;
194-
}
195-
196-
private getFieldValue(document: Document, path: string): unknown {
197-
const parts = path.split(".");
198-
let current: unknown = document;
199-
200-
for (const part of parts) {
201-
if (current && typeof current === "object" && part in current) {
202-
current = (current as Record<string, unknown>)[part];
203-
} else {
204-
return undefined;
205-
}
206-
}
123+
await this.session.vectorSearchEmbeddingsManager.assertFieldsHaveCorrectEmbeddings(
124+
{ database, collection },
125+
processedDocuments
126+
);
207127

208-
return current;
128+
return processedDocuments;
209129
}
210130

211-
private setFieldValue(document: Document, path: string, value: unknown): void {
212-
const parts = path.split(".");
131+
// Delete the value at a dot-notation field path from a document. Returns without changes if any segment along the path is missing or falsy.
132+
private deleteFieldPath(document: Record<string, unknown>, fieldPath: string): void {
133+
const parts = fieldPath.split(".");
213134
let current: Record<string, unknown> = document;
214-
215-
for (let i = 0; i < parts.length - 1; i++) {
135+
for (let i = 0; i < parts.length; i++) {
216136
const part = parts[i];
217-
if (!part) {
218-
continue;
219-
}
220-
if (!(part in current) || typeof current[part] !== "object") {
221-
current[part] = {};
137+
const key = part as keyof typeof current;
138+
if (!current[key]) {
139+
return;
140+
} else if (i === parts.length - 1) {
141+
delete current[key];
142+
} else {
143+
current = current[key] as Record<string, unknown>;
222144
}
223-
current = current[part] as Record<string, unknown>;
224-
}
225-
226-
const lastPart = parts[parts.length - 1];
227-
if (lastPart) {
228-
current[lastPart] = value;
229145
}
230146
}
231147
}

0 commit comments

Comments
 (0)