Skip to content

Commit 082fce9

Browse files
committed
chore: add the ability to disable embedding validation
While we do our best to make sure we don't break anything, there might be situations where users want to disable the validation and insert documents as they please.
1 parent cb52116 commit 082fce9

File tree

9 files changed

+184
-124
lines changed

9 files changed

+184
-124
lines changed

src/common/config.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ const OPTIONS = {
5858
boolean: [
5959
"apiDeprecationErrors",
6060
"apiStrict",
61+
"disableEmbeddingsValidation",
6162
"help",
6263
"indexCheck",
6364
"ipv6",
@@ -183,6 +184,7 @@ export interface UserConfig extends CliOptions {
183184
maxBytesPerQuery: number;
184185
atlasTemporaryDatabaseUserLifetimeMs: number;
185186
voyageApiKey: string;
187+
disableEmbeddingsValidation: boolean;
186188
}
187189

188190
export const defaultUserConfig: UserConfig = {
@@ -213,6 +215,7 @@ export const defaultUserConfig: UserConfig = {
213215
maxBytesPerQuery: 16 * 1024 * 1024, // By default, we only return ~16 mb of data per query / aggregation
214216
atlasTemporaryDatabaseUserLifetimeMs: 4 * 60 * 60 * 1000, // 4 hours
215217
voyageApiKey: "",
218+
disableEmbeddingsValidation: false,
216219
};
217220

218221
export const config = setupUserConfig({

src/common/search/vectorSearchEmbeddings.ts

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
22
import { BSON, type Document } from "bson";
3+
import type { UserConfig } from "../config.js";
34

4-
type VectorFieldIndexDefinition = {
5+
export type VectorFieldIndexDefinition = {
56
type: "vector";
67
path: string;
78
numDimensions: number;
@@ -11,7 +12,10 @@ type VectorFieldIndexDefinition = {
1112

1213
export type EmbeddingNamespace = `${string}.${string}`;
1314
export class VectorSearchEmbeddings {
14-
constructor(private readonly embeddings: Map<EmbeddingNamespace, VectorFieldIndexDefinition[]> = new Map()) {}
15+
constructor(
16+
private readonly config: UserConfig,
17+
private readonly embeddings: Map<EmbeddingNamespace, VectorFieldIndexDefinition[]> = new Map()
18+
) {}
1519

1620
cleanupEmbeddingsForNamespace({ database, collection }: { database: string; collection: string }): void {
1721
const embeddingDefKey: EmbeddingNamespace = `${database}.${collection}`;
@@ -71,6 +75,13 @@ export class VectorSearchEmbeddings {
7175
}
7276

7377
private documentPassesEmbeddingValidation(definition: VectorFieldIndexDefinition, document: Document): boolean {
78+
// While we can do our best effort to ensure that the embedding validation is correct
79+
// based on https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-quantization/
80+
// it's a complex process so we will also give the user the ability to disable this validation
81+
if (this.config.disableEmbeddingsValidation) {
82+
return true;
83+
}
84+
7485
const fieldPath = definition.path.split(".");
7586
let fieldRef: unknown = document;
7687

@@ -84,30 +95,37 @@ export class VectorSearchEmbeddings {
8495

8596
switch (definition.quantization) {
8697
case "none":
98+
return true;
8799
case "scalar":
88-
if (!Array.isArray(fieldRef)) {
89-
return false;
90-
}
91-
92-
if (fieldRef.length !== definition.numDimensions) {
93-
return false;
94-
}
95-
96-
if (typeof fieldRef[0] !== "number") {
97-
return false;
98-
}
99-
break;
100100
case "binary":
101101
if (fieldRef instanceof BSON.Binary) {
102102
try {
103-
const bits = fieldRef.toBits();
104-
return bits.length === definition.numDimensions;
103+
const elements = fieldRef.toFloat32Array();
104+
return elements.length === definition.numDimensions;
105105
} catch {
106-
return false;
106+
// bits are also supported
107+
try {
108+
const bits = fieldRef.toBits();
109+
return bits.length === definition.numDimensions;
110+
} catch {
111+
return false;
112+
}
107113
}
108114
} else {
109-
return false;
115+
if (!Array.isArray(fieldRef)) {
116+
return false;
117+
}
118+
119+
if (fieldRef.length !== definition.numDimensions) {
120+
return false;
121+
}
122+
123+
if (typeof fieldRef[0] !== "number") {
124+
return false;
125+
}
110126
}
127+
128+
break;
111129
}
112130

113131
return true;

src/transports/base.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ export abstract class TransportRunnerBase {
9090
exportsManager,
9191
connectionManager,
9292
keychain: Keychain.root,
93-
vectorSearchEmbeddings: new VectorSearchEmbeddings(),
93+
vectorSearchEmbeddings: new VectorSearchEmbeddings(this.userConfig),
9494
});
9595

9696
const telemetry = Telemetry.create(session, this.userConfig, this.deviceId, {

tests/integration/helpers.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ export function setupIntegrationTest(
102102
exportsManager,
103103
connectionManager,
104104
keychain: new Keychain(),
105-
vectorSearchEmbeddings: new VectorSearchEmbeddings(),
105+
vectorSearchEmbeddings: new VectorSearchEmbeddings(userConfig),
106106
});
107107

108108
// Mock hasValidAccessToken for tests

tests/integration/telemetry.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ describe("Telemetry", () => {
2424
exportsManager: ExportsManager.init(config, logger),
2525
connectionManager: new MCPConnectionManager(config, driverOptions, logger, deviceId),
2626
keychain: new Keychain(),
27-
vectorSearchEmbeddings: new VectorSearchEmbeddings(),
27+
vectorSearchEmbeddings: new VectorSearchEmbeddings(config),
2828
}),
2929
config,
3030
deviceId

tests/integration/tools/mongodb/mongodbTool.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ describe("MongoDBTool implementations", () => {
109109
exportsManager,
110110
connectionManager,
111111
keychain: new Keychain(),
112-
vectorSearchEmbeddings: new VectorSearchEmbeddings(),
112+
vectorSearchEmbeddings: new VectorSearchEmbeddings(userConfig),
113113
});
114114
const telemetry = Telemetry.create(session, userConfig, deviceId);
115115

0 commit comments

Comments
 (0)