@@ -94,8 +94,12 @@ SQLITE_EXTENSION_INIT1
9494#define OPTION_KEY_DIMENSION "dimension"
9595#define OPTION_KEY_NORMALIZED "normalized"
9696#define OPTION_KEY_MAXMEMORY "max_memory"
97- #define OPTION_KEY_QUANTTYPE "quant_type"
9897#define OPTION_KEY_DISTANCE "distance"
98+ #define OPTION_KEY_QUANTTYPE "qtype"
99+ #define OPTION_KEY_QUANTSCALE "qscale" // used only in serialize/unserialize
100+ #define OPTION_KEY_QUANTOFFSET "qoffset" // used only in serialize/unserialize
101+
102+ #define VECTOR_INTERNAL_TABLE "CREATE TABLE IF NOT EXISTS __vector (tblname TEXT, colname TEXT, key TEXT, value ANY, PRIMARY KEY(tblname, colname, key));"
99103
100104typedef struct {
101105 vector_type v_type ; // vector type
@@ -472,9 +476,83 @@ static void *sqlite_common_set_error (sqlite3_context *context, sqlite3_vtab *vt
472476 return NULL ;
473477}
474478
479+ static int sqlite_serialize (sqlite3_context * context , const char * table_name , const char * column_name , int type , const char * key , int64_t ivalue , double fvalue ) {
480+ const char * sql = "REPLACE INTO __vector (tblname, colname, key, value) VALUES (?, ?, ?, ?);" ;
481+ sqlite3 * db = sqlite3_context_db_handle (context );
482+ sqlite3_stmt * vm = NULL ;
483+
484+ int rc = sqlite3_prepare_v2 (db , sql , -1 , & vm , NULL );
485+ if (rc != SQLITE_OK ) goto cleanup ;
486+
487+ rc = sqlite3_bind_text (vm , 1 , table_name , -1 , SQLITE_STATIC );
488+ if (rc != SQLITE_OK ) goto cleanup ;
489+
490+ rc = sqlite3_bind_text (vm , 2 , column_name , -1 , SQLITE_STATIC );
491+ if (rc != SQLITE_OK ) goto cleanup ;
492+
493+ rc = sqlite3_bind_text (vm , 3 , key , -1 , SQLITE_STATIC );
494+ if (rc != SQLITE_OK ) goto cleanup ;
495+
496+ switch (type ) {
497+ case SQLITE_INTEGER : rc = sqlite3_bind_int64 (vm , 4 , (sqlite3_int64 )ivalue ); break ;
498+ case SQLITE_FLOAT : rc = sqlite3_bind_double (vm , 4 , fvalue ); break ;
499+ }
500+ if (rc != SQLITE_OK ) goto cleanup ;
501+
502+ rc = sqlite3_step (vm );
503+ if (rc == SQLITE_DONE ) rc = SQLITE_OK ;
504+
505+ cleanup :
506+ if (rc != SQLITE_OK ) sqlite3_result_error (context , sqlite3_errmsg (db ), -1 );
507+ if (vm ) sqlite3_finalize (vm );
508+ return rc ;
509+ }
510+
511+ static int sqlite_unserialize (sqlite3_context * context , table_context * ctx ) {
512+ const char * sql = "SELECT key, value FROM __vector WHERE tblname = ? AND colname = ?;" ;
513+ sqlite3 * db = sqlite3_context_db_handle (context );
514+ sqlite3_stmt * vm = NULL ;
515+
516+ int rc = sqlite3_prepare_v2 (db , sql , -1 , & vm , NULL );
517+ if (rc != SQLITE_OK ) goto cleanup ;
518+
519+ rc = sqlite3_bind_text (vm , 1 , ctx -> t_name , -1 , SQLITE_STATIC );
520+ if (rc != SQLITE_OK ) goto cleanup ;
521+
522+ rc = sqlite3_bind_text (vm , 2 , ctx -> c_name , -1 , SQLITE_STATIC );
523+ if (rc != SQLITE_OK ) goto cleanup ;
524+
525+ while (1 ) {
526+ rc = sqlite3_step (vm );
527+ if (rc == SQLITE_DONE ) {rc = SQLITE_OK ; break ;}
528+ if (rc != SQLITE_ROW ) break ;
529+
530+ const char * key = (const char * )sqlite3_column_text (vm , 0 );
531+ if (strcmp (key , OPTION_KEY_QUANTTYPE ) == 0 ) {
532+ ctx -> options .q_type = (vector_qtype )sqlite3_column_int (vm , 1 );
533+ continue ;
534+ }
535+
536+ if (strcmp (key , OPTION_KEY_QUANTSCALE ) == 0 ) {
537+ ctx -> scale = (float )sqlite3_column_double (vm , 1 );
538+ continue ;
539+ }
540+
541+ if (strcmp (key , OPTION_KEY_QUANTOFFSET ) == 0 ) {
542+ ctx -> offset = (float )sqlite3_column_double (vm , 1 );
543+ continue ;
544+ }
545+ }
546+
547+ cleanup :
548+ //if (rc != SQLITE_OK) sqlite3_result_error(context, sqlite3_errmsg(db), -1);
549+ if (vm ) sqlite3_finalize (vm );
550+ return rc ;
551+ }
552+
475553// MARK: - General Utils -
476554
477- static inline void quantize_float32_to_u8 (float * v , uint8_t * q , float offset , float scale , int n ) {
555+ static inline void quantize_float32_to_unsigned8bit (float * v , uint8_t * q , float offset , float scale , int n ) {
478556 int i = 0 ;
479557 for (; i + 3 < n ; i += 4 ) {
480558 float s0 = (v [i ] - offset ) * scale ;
@@ -507,6 +585,38 @@ static inline void quantize_float32_to_u8 (float *v, uint8_t *q, float offset, f
507585 }
508586}
509587
588+ static inline void quantize_float32_to_signed8bit (float * v , int8_t * q , float offset , float scale , int n ) {
589+ int i = 0 ;
590+ for (; i + 3 < n ; i += 4 ) {
591+ float s0 = (v [i ] - offset ) * scale ;
592+ float s1 = (v [i + 1 ] - offset ) * scale ;
593+ float s2 = (v [i + 2 ] - offset ) * scale ;
594+ float s3 = (v [i + 3 ] - offset ) * scale ;
595+
596+ int r0 = (int )(s0 + 0.5f * (1.0f - 2.0f * (s0 < 0.0f )));
597+ int r1 = (int )(s1 + 0.5f * (1.0f - 2.0f * (s1 < 0.0f )));
598+ int r2 = (int )(s2 + 0.5f * (1.0f - 2.0f * (s2 < 0.0f )));
599+ int r3 = (int )(s3 + 0.5f * (1.0f - 2.0f * (s3 < 0.0f )));
600+
601+ r0 = r0 > 127 ? 127 : (r0 < -128 ? -128 : r0 );
602+ r1 = r1 > 127 ? 127 : (r1 < -128 ? -128 : r1 );
603+ r2 = r2 > 127 ? 127 : (r2 < -128 ? -128 : r2 );
604+ r3 = r3 > 127 ? 127 : (r3 < -128 ? -128 : r3 );
605+
606+ q [i ] = (int8_t )r0 ;
607+ q [i + 1 ] = (int8_t )r1 ;
608+ q [i + 2 ] = (int8_t )r2 ;
609+ q [i + 3 ] = (int8_t )r3 ;
610+ }
611+
612+ for (; i < n ; ++ i ) {
613+ float scaled = (v [i ] - offset ) * scale ;
614+ int rounded = (int )(scaled + 0.5f * (1.0f - 2.0f * (scaled < 0.0f )));
615+ rounded = rounded > 127 ? 127 : (rounded < -128 ? -128 : rounded );
616+ q [i ] = (int8_t )rounded ;
617+ }
618+ }
619+
510620static size_t vector_type_to_size (vector_type type ) {
511621 switch (type ) {
512622 case VECTOR_TYPE_F32 : return sizeof (float );
@@ -539,8 +649,9 @@ const char *vector_type_to_name (vector_type type) {
539649}
540650
541651static vector_qtype quant_name_to_type (const char * qname ) {
542- if (strcasecmp (qname , "QUANTU8" ) == 0 ) return VECTOR_QUANT_8BIT ;
543- return 0 ;
652+ if (strcasecmp (qname , "UINT8" ) == 0 ) return VECTOR_QUANT_U8BIT ;
653+ if (strcasecmp (qname , "INT8" ) == 0 ) return VECTOR_QUANT_S8BIT ;
654+ return -1 ;
544655}
545656
546657static vector_distance distance_name_to_type (const char * dname ) {
@@ -683,7 +794,7 @@ bool vector_keyvalue_callback (sqlite3_context *context, void *xdata, const char
683794
684795 if (strncasecmp (key , OPTION_KEY_QUANTTYPE , key_len ) == 0 ) {
685796 vector_qtype type = quant_name_to_type (buffer );
686- if (type == 0 ) return context_result_error (context , SQLITE_ERROR , "Invalid quantization type: '%s' is not a recognized or supported quantization type." , buffer );
797+ if (type == -1 ) return context_result_error (context , SQLITE_ERROR , "Invalid quantization type: '%s' is not a recognized or supported quantization type." , buffer );
687798 options -> q_type = type ;
688799 return true;
689800 }
@@ -736,6 +847,7 @@ void *vector_context_create (void) {
736847}
737848
738849void vector_context_free (void * p ) {
850+ return ;
739851 if (p ) {
740852 vector_context * ctx = (vector_context * )p ;
741853 for (int i = 0 ; i < ctx -> table_count ; ++ i ) {
@@ -791,14 +903,16 @@ void vector_context_add (sqlite3_context *context, vector_context *ctx, const ch
791903 ctx -> tables [index ].pk_name = prikey ;
792904 ctx -> tables [index ].options = * options ;
793905 ctx -> table_count ++ ;
906+
907+ sqlite_unserialize (context , & ctx -> tables [index ]);
794908}
795909
796910void vector_options_init (vector_options * options ) {
797911 memset (options , 0 , sizeof (vector_options ));
798912 options -> v_type = VECTOR_TYPE_F32 ;
799913 options -> v_distance = VECTOR_DISTANCE_L2 ;
800914 options -> max_memory = DEFAULT_MAX_MEMORY ;
801- options -> q_type = VECTOR_QUANT_8BIT ;
915+ options -> q_type = VECTOR_QUANT_AUTO ;
802916}
803917
804918vector_options vector_options_create (void ) {
@@ -889,6 +1003,7 @@ static int vector_rebuild_quantization (sqlite3_context *context, const char *ta
8891003 float max_val = - MAXFLOAT ;
8901004 #endif
8911005
1006+ bool contains_negative = false;
8921007 while (1 ) {
8931008 rc = sqlite3_step (vm );
8941009 if (rc == SQLITE_DONE ) {rc = SQLITE_OK ; break ;}
@@ -927,15 +1042,24 @@ static int vector_rebuild_quantization (sqlite3_context *context, const char *ta
9271042 }
9281043 if (val < min_val ) min_val = val ;
9291044 if (val > max_val ) max_val = val ;
1045+ if (val < 0.0 ) contains_negative = true;
9301046 }
9311047 }
9321048
1049+ // set proper format
1050+ if (qtype == VECTOR_QUANT_AUTO ) {
1051+ if (contains_negative == true) qtype = VECTOR_QUANT_S8BIT ;
1052+ else qtype = VECTOR_QUANT_U8BIT ;
1053+ }
1054+
9331055 // STEP 2
934- // compute scale and offset and set table them to table context
935- // standard min-max linear quantization
936- float scale = 255.0f / (max_val - min_val );
937- float offset = min_val ;
1056+ // compute scale and offset and set table them to table context standard min-max linear quantization
1057+ float abs_max = fmaxf (fabsf (min_val ), fabsf (max_val )); // only used in VECTOR_QUANT_S8BIT
1058+ float scale = (qtype == VECTOR_QUANT_U8BIT ) ? (255.0f / (max_val - min_val )) : (127.0f / abs_max );
1059+ // in the VECTOR_QUANT_S8BIT version I am assuming a symmetric quantization, for asymmetric quantization min_val should be used
1060+ float offset = (qtype == VECTOR_QUANT_U8BIT ) ? min_val : 0.0f ;
9381061
1062+ t_ctx -> options .q_type = qtype ;
9391063 t_ctx -> scale = scale ;
9401064 t_ctx -> offset = offset ;
9411065
@@ -986,7 +1110,8 @@ static int vector_rebuild_quantization (sqlite3_context *context, const char *ta
9861110 data += sizeof (int64_t );
9871111
9881112 // quantize vector
989- quantize_float32_to_u8 (v , data , offset , scale , dim );
1113+ if (qtype == VECTOR_QUANT_U8BIT ) quantize_float32_to_unsigned8bit (v , data , offset , scale , dim );
1114+ else quantize_float32_to_signed8bit (v , (int8_t * )data , offset , scale , dim );
9901115 data += (dim * sizeof (uint8_t ));
9911116
9921117 max_rowid = rowid ;
@@ -1048,6 +1173,14 @@ static void vector_quantize (sqlite3_context *context, const char *table_name, c
10481173 rc = sqlite3_exec (db , "COMMIT;" , NULL , NULL , NULL );
10491174 if (rc != SQLITE_OK ) goto quantize_cleanup ;
10501175
1176+ // serialize quantization options
1177+ rc = sqlite_serialize (context , table_name , column_name , SQLITE_INTEGER , OPTION_KEY_QUANTTYPE , t_ctx -> options .q_type , 0 );
1178+ if (rc != SQLITE_OK ) goto quantize_cleanup ;
1179+ rc = sqlite_serialize (context , table_name , column_name , SQLITE_FLOAT , OPTION_KEY_QUANTSCALE , 0 , t_ctx -> scale );
1180+ if (rc != SQLITE_OK ) goto quantize_cleanup ;
1181+ rc = sqlite_serialize (context , table_name , column_name , SQLITE_FLOAT , OPTION_KEY_QUANTOFFSET , 0 , t_ctx -> offset );
1182+ if (rc != SQLITE_OK ) goto quantize_cleanup ;
1183+
10511184quantize_cleanup :
10521185 if (rc != SQLITE_OK ) {
10531186 printf ("%s" , sqlite3_errmsg (db ));
@@ -1658,7 +1791,7 @@ static int vFullScanCursorFilter (sqlite3_vtab_cursor *cur, int idxNum, const ch
16581791
16591792// MARK: -
16601793
1661- static int vQuantRunMemory (vFullScanCursor * c , uint8_t * v , int dim ) {
1794+ static int vQuantRunMemory (vFullScanCursor * c , uint8_t * v , vector_qtype qtype , int dim ) {
16621795 const int counter = c -> table -> precounter ;
16631796 const uint8_t * data = c -> table -> preloaded ;
16641797 const size_t rowid_size = sizeof (int64_t );
@@ -1672,7 +1805,7 @@ static int vQuantRunMemory(vFullScanCursor *c, uint8_t *v, int dim) {
16721805
16731806 // compute distance function
16741807 vector_distance vd = c -> table -> options .v_distance ;
1675- vector_type vt = VECTOR_TYPE_U8 ;
1808+ vector_type vt = ( qtype == VECTOR_QUANT_U8BIT ) ? VECTOR_TYPE_U8 : VECTOR_TYPE_I8 ;
16761809 distance_function_t distance_fn = dispatch_distance_table [vd ][vt ];
16771810
16781811 for (int i = 0 ; i < counter ; ++ i ) {
@@ -1701,9 +1834,14 @@ static int vQuantRun (sqlite3 *db, vFullScanCursor *c, const void *v1, int v1siz
17011834 uint8_t * v = (uint8_t * )sqlite3_malloc (dimension * sizeof (int8_t ));
17021835 if (!v ) return SQLITE_NOMEM ;
17031836
1704- quantize_float32_to_u8 ((float * )v1 , v , c -> table -> offset , c -> table -> scale , dimension );
1837+ vector_qtype qtype = c -> table -> options .q_type ;
1838+ if (qtype == VECTOR_QUANT_U8BIT ) {
1839+ quantize_float32_to_unsigned8bit ((float * )v1 , v , c -> table -> offset , c -> table -> scale , dimension );
1840+ } else {
1841+ quantize_float32_to_signed8bit ((float * )v1 , (int8_t * )v , c -> table -> offset , c -> table -> scale , dimension );
1842+ }
17051843 if (c -> table -> preloaded ) {
1706- int rc = vQuantRunMemory (c , v , dimension );
1844+ int rc = vQuantRunMemory (c , v , qtype , dimension );
17071845 if (v ) sqlite3_free (v );
17081846 return rc ;
17091847 }
@@ -1721,7 +1859,7 @@ static int vQuantRun (sqlite3 *db, vFullScanCursor *c, const void *v1, int v1siz
17211859
17221860 // compute distance function
17231861 vector_distance vd = c -> table -> options .v_distance ;
1724- vector_type vt = VECTOR_TYPE_U8 ;
1862+ vector_type vt = ( qtype == VECTOR_QUANT_U8BIT ) ? VECTOR_TYPE_U8 : VECTOR_TYPE_I8 ;
17251863 distance_function_t distance_fn = dispatch_distance_table [vd ][vt ];
17261864
17271865 while (1 ) {
@@ -1898,6 +2036,11 @@ SQLITE_VECTOR_API int sqlite3_vector_init (sqlite3 *db, char **pzErrMsg, const s
18982036
18992037 init_distance_functions (false);
19002038
2039+ // TODO: error message must be duplicate here?
2040+ // create internal table
2041+ rc = sqlite3_exec (db , VECTOR_INTERNAL_TABLE , NULL , NULL , NULL );
2042+ if (rc != SQLITE_OK ) return rc ;
2043+
19012044 // init context
19022045 void * ctx = vector_context_create ();
19032046 if (!ctx ) {
0 commit comments