Skip to content

Commit 85e9cf2

Browse files
committed
chore: simplify integration with embeddings and make it more configurable
1 parent d1d770d commit 85e9cf2

File tree

5 files changed: +365 additions, -353 deletions.

src/common/search/embeddingsProvider.ts

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,28 @@ import { z } from "zod";
88

99
type EmbeddingsInput = string;
1010
type Embeddings = number[];
11-
12-
type EmbeddingParameters = {
13-
numDimensions: number;
14-
quantization: string;
11+
export type EmbeddingParameters = {
1512
inputType: "query" | "document";
1613
};
1714

18-
interface EmbeddingsProvider<SupportedModels extends string> {
19-
embed(modelId: SupportedModels, content: EmbeddingsInput[], parameters: EmbeddingParameters): Promise<Embeddings[]>;
15+
interface EmbeddingsProvider<SupportedModels extends string, SupportedEmbeddingParameters extends EmbeddingParameters> {
16+
embed(
17+
modelId: SupportedModels,
18+
content: EmbeddingsInput[],
19+
parameters: SupportedEmbeddingParameters
20+
): Promise<Embeddings[]>;
2021
}
2122

22-
const zVoyageModels = z.enum(["voyage-3-large", "voyage-3.5", "voyage-3.5-lite", "voyage-code-3"]);
23+
export const zVoyageModels = z.enum(["voyage-3-large", "voyage-3.5", "voyage-3.5-lite", "voyage-code-3"]);
24+
export const zVoyageEmbeddingParameters = z.object({
25+
outputDimension: z.union([z.literal(256), z.literal(512), z.literal(1024), z.literal(2048), z.literal(4096)]),
26+
outputDType: z.enum(["float", "int8", "uint8", "binary", "ubinary"]),
27+
});
2328

2429
type VoyageModels = z.infer<typeof zVoyageModels>;
25-
class VoyageEmbeddingsProvider implements EmbeddingsProvider<VoyageModels> {
30+
type VoyageEmbeddingParameters = z.infer<typeof zVoyageEmbeddingParameters> & EmbeddingParameters;
31+
32+
class VoyageEmbeddingsProvider implements EmbeddingsProvider<VoyageModels, VoyageEmbeddingParameters> {
2633
private readonly voyage: VoyageProvider;
2734

2835
constructor({ voyageApiKey }: UserConfig, providedFetch?: typeof fetch) {
@@ -44,32 +51,28 @@ class VoyageEmbeddingsProvider implements EmbeddingsProvider<VoyageModels> {
4451
async embed<Model extends VoyageModels>(
4552
modelId: Model,
4653
content: EmbeddingsInput[],
47-
parameters: EmbeddingParameters
54+
parameters: VoyageEmbeddingParameters
4855
): Promise<Embeddings[]> {
49-
const voyageParameters = {
50-
inputType: parameters.inputType,
51-
outputDimension: parameters.numDimensions,
52-
outputDtype: "float", // it is hardcoded on purpose as we don't do quantization yet
53-
};
54-
5556
const model = this.voyage.textEmbeddingModel(modelId);
5657
const { embeddings } = await embedMany({
5758
model,
5859
values: content,
59-
providerOptions: { voyage: voyageParameters },
60+
providerOptions: { voyage: parameters },
6061
});
6162

6263
return embeddings;
6364
}
6465
}
6566

66-
export function getEmbeddingsProvider(userConfig: UserConfig): EmbeddingsProvider<VoyageModels> | undefined {
67+
export function getEmbeddingsProvider(
68+
userConfig: UserConfig
69+
): EmbeddingsProvider<VoyageModels, VoyageEmbeddingParameters> | undefined {
6770
if (VoyageEmbeddingsProvider.isConfiguredIn(userConfig)) {
6871
return new VoyageEmbeddingsProvider(userConfig);
6972
}
7073

7174
return undefined;
7275
}
7376

74-
export const zSupportedEmbeddingModels = zVoyageModels;
75-
export type SupportedEmbeddingModels = z.infer<typeof zSupportedEmbeddingModels>;
77+
export const zSupportedEmbeddingParameters = zVoyageEmbeddingParameters.extend({ model: zVoyageModels });
78+
export type SupportedEmbeddingParameters = z.infer<typeof zSupportedEmbeddingParameters>;

src/common/search/vectorSearchEmbeddingsManager.ts

Lines changed: 6 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import type { ConnectionManager } from "../connectionManager.js";
55
import z from "zod";
66
import { ErrorCodes, MongoDBError } from "../errors.js";
77
import { getEmbeddingsProvider } from "./embeddingsProvider.js";
8-
import type { SupportedEmbeddingModels } from "./embeddingsProvider.js";
8+
import type { EmbeddingParameters, SupportedEmbeddingParameters } from "./embeddingsProvider.js";
99

1010
export const similarityEnum = z.enum(["cosine", "euclidean", "dotProduct"]);
1111
export type Similarity = z.infer<typeof similarityEnum>;
@@ -223,16 +223,16 @@ export class VectorSearchEmbeddingsManager {
223223
database,
224224
collection,
225225
path,
226-
model,
227226
rawValues,
227+
embeddingParameters,
228228
inputType,
229229
}: {
230230
database: string;
231231
collection: string;
232232
path: string;
233-
model: SupportedEmbeddingModels;
234233
rawValues: string[];
235-
inputType: "query" | "document";
234+
embeddingParameters: SupportedEmbeddingParameters;
235+
inputType: EmbeddingParameters["inputType"];
236236
}): Promise<unknown[]> {
237237
const provider = await this.assertAtlasSearchIsAvailable();
238238
if (!provider) {
@@ -258,53 +258,10 @@ export class VectorSearchEmbeddingsManager {
258258
);
259259
}
260260

261-
const providerEmbeddings = await embeddingsProvider.embed(model, rawValues, {
261+
return await embeddingsProvider.embed(embeddingParameters.model, rawValues, {
262262
inputType,
263-
numDimensions: embeddingInfoForPath.numDimensions,
264-
quantization: embeddingInfoForPath.quantization,
263+
...embeddingParameters,
265264
});
266-
267-
if (this.config.disableEmbeddingsValidation) {
268-
return providerEmbeddings;
269-
}
270-
271-
const hasDocuments = await provider.estimatedDocumentCount(database, collection);
272-
if (!hasDocuments) {
273-
return providerEmbeddings;
274-
}
275-
276-
const oneDocument: Document = (await provider
277-
.aggregate(database, collection, [{ $sample: { size: 1 } }, { $project: { embeddings: path } }])
278-
.next()) as Document;
279-
280-
if (!oneDocument) {
281-
return providerEmbeddings;
282-
}
283-
284-
const sampleEmbeddings = oneDocument.embeddings as number[] | BSON.Binary;
285-
const adaptedEmbeddings = providerEmbeddings.map((embeddings) => {
286-
// now map based on the sample embeddings
287-
if (Array.isArray(sampleEmbeddings) && Array.isArray(embeddings)) {
288-
return embeddings;
289-
}
290-
if (sampleEmbeddings instanceof BSON.Binary && Array.isArray(embeddings)) {
291-
if (this.matches(() => sampleEmbeddings.toBits())) {
292-
return BSON.Binary.fromBits(embeddings);
293-
}
294-
if (this.matches(() => sampleEmbeddings.toInt8Array())) {
295-
return BSON.Binary.fromInt8Array(new Int8Array(embeddings));
296-
}
297-
if (this.matches(() => sampleEmbeddings.toFloat32Array())) {
298-
return BSON.Binary.fromFloat32Array(new Float32Array(embeddings));
299-
}
300-
if (this.matches(() => sampleEmbeddings.toPackedBits())) {
301-
return BSON.Binary.fromPackedBits(new Uint8Array(embeddings));
302-
}
303-
}
304-
return embeddings;
305-
});
306-
307-
return adaptedEmbeddings;
308265
}
309266

310267
private isANumber(value: unknown): boolean {
@@ -323,13 +280,4 @@ export class VectorSearchEmbeddingsManager {
323280

324281
return false;
325282
}
326-
327-
private matches(fn: () => unknown): boolean {
328-
try {
329-
fn();
330-
return true;
331-
} catch {
332-
return false;
333-
}
334-
}
335283
}

src/tools/mongodb/read/aggregate.ts

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import { operationWithFallback } from "../../../helpers/operationWithFallback.js
1313
import { AGG_COUNT_MAX_TIME_MS_CAP, ONE_MB, CURSOR_LIMITS_TO_LLM_TEXT } from "../../../helpers/constants.js";
1414
import { zEJSON } from "../../args.js";
1515
import { LogId } from "../../../common/logger.js";
16-
import { zSupportedEmbeddingModels } from "../../../common/search/embeddingsProvider.js";
16+
import { zSupportedEmbeddingParameters } from "../../../common/search/embeddingsProvider.js";
1717

1818
const AnyStage = zEJSON();
1919
const VectorSearchStage = z.object({
@@ -47,10 +47,10 @@ const VectorSearchStage = z.object({
4747
filter: zEJSON()
4848
.optional()
4949
.describe("MQL filter that can only use pre-filter fields from the index definition."),
50-
embeddingModel: zSupportedEmbeddingModels
50+
embeddingParameters: zSupportedEmbeddingParameters
5151
.optional()
5252
.describe(
53-
"The embedding model to use to generate embeddings before search. Note to LLM: If unsure, ask the user before providing one."
53+
"The embedding model and its parameters to use to generate embeddings before searching. It is mandatory if queryVector is a string value. Note to LLM: If unsure, ask the user before providing one."
5454
),
5555
})
5656
.passthrough(),
@@ -233,22 +233,22 @@ export class AggregateTool extends MongoDBToolBase {
233233
continue;
234234
}
235235

236-
if (!vectorSearchStage.embeddingModel) {
236+
if (!vectorSearchStage.embeddingParameters) {
237237
throw new MongoDBError(
238238
ErrorCodes.AtlasVectorSearchInvalidQuery,
239239
"embeddingModel is mandatory if queryVector is a raw string."
240240
);
241241
}
242242

243-
const model = vectorSearchStage.embeddingModel;
244-
delete vectorSearchStage.embeddingModel;
243+
const embeddingParameters = vectorSearchStage.embeddingParameters;
244+
delete vectorSearchStage.embeddingParameters;
245245

246246
const [embeddings] = await this.session.vectorSearchEmbeddingsManager.generateEmbeddings({
247247
database,
248248
collection,
249249
path: vectorSearchStage.path,
250-
model,
251250
rawValues: [vectorSearchStage.queryVector],
251+
embeddingParameters,
252252
inputType: "query",
253253
});
254254

0 commit comments

Comments
 (0)