11import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver" ;
22import { BSON , type Document } from "bson" ;
3+ import type { UserConfig } from "../config.js" ;
34
4- type VectorFieldIndexDefinition = {
5+ export type VectorFieldIndexDefinition = {
56 type : "vector" ;
67 path : string ;
78 numDimensions : number ;
@@ -11,7 +12,10 @@ type VectorFieldIndexDefinition = {
1112
1213export type EmbeddingNamespace = `${string } .${string } `;
1314export class VectorSearchEmbeddings {
14- constructor ( private readonly embeddings : Map < EmbeddingNamespace , VectorFieldIndexDefinition [ ] > = new Map ( ) ) { }
15+ constructor (
16+ private readonly config : UserConfig ,
17+ private readonly embeddings : Map < EmbeddingNamespace , VectorFieldIndexDefinition [ ] > = new Map ( )
18+ ) { }
1519
1620 cleanupEmbeddingsForNamespace ( { database, collection } : { database : string ; collection : string } ) : void {
1721 const embeddingDefKey : EmbeddingNamespace = `${ database } .${ collection } ` ;
@@ -71,6 +75,13 @@ export class VectorSearchEmbeddings {
7175 }
7276
7377 private documentPassesEmbeddingValidation ( definition : VectorFieldIndexDefinition , document : Document ) : boolean {
78+ // We make a best-effort attempt to validate embeddings against the rules described in
79+ // https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-quantization/
80+ // Because that process is complex, we also let the user disable this validation via config.
81+ if ( this . config . disableEmbeddingsValidation ) {
82+ return true ;
83+ }
84+
7485 const fieldPath = definition . path . split ( "." ) ;
7586 let fieldRef : unknown = document ;
7687
@@ -84,30 +95,37 @@ export class VectorSearchEmbeddings {
8495
8596 switch ( definition . quantization ) {
8697 case "none" :
98+ return true ;
8799 case "scalar" :
88- if ( ! Array . isArray ( fieldRef ) ) {
89- return false ;
90- }
91-
92- if ( fieldRef . length !== definition . numDimensions ) {
93- return false ;
94- }
95-
96- if ( typeof fieldRef [ 0 ] !== "number" ) {
97- return false ;
98- }
99- break ;
100100 case "binary" :
101101 if ( fieldRef instanceof BSON . Binary ) {
102102 try {
103- const bits = fieldRef . toBits ( ) ;
104- return bits . length === definition . numDimensions ;
103+ const elements = fieldRef . toFloat32Array ( ) ;
104+ return elements . length === definition . numDimensions ;
105105 } catch {
106- return false ;
106+ // bits are also supported
107+ try {
108+ const bits = fieldRef . toBits ( ) ;
109+ return bits . length === definition . numDimensions ;
110+ } catch {
111+ return false ;
112+ }
107113 }
108114 } else {
109- return false ;
115+ if ( ! Array . isArray ( fieldRef ) ) {
116+ return false ;
117+ }
118+
119+ if ( fieldRef . length !== definition . numDimensions ) {
120+ return false ;
121+ }
122+
123+ if ( typeof fieldRef [ 0 ] !== "number" ) {
124+ return false ;
125+ }
110126 }
127+
128+ break ;
111129 }
112130
113131 return true ;
0 commit comments