@@ -186,6 +186,8 @@ typedef int (*vcursor_sort_callback)(vFullScanCursor *c);
186186extern distance_function_t dispatch_distance_table [VECTOR_DISTANCE_MAX ][VECTOR_TYPE_MAX ];
187187extern char * distance_backend_name ;
188188
189+ static sqlite3_mutex * qmutex ;
190+
189191// MARK: - SQLite Utils -
190192
191193bool sqlite_system_exists (sqlite3 * db , const char * name , const char * type ) {
@@ -1349,11 +1351,14 @@ static void vector_quantize_preload (sqlite3_context *context, int argc, sqlite3
13491351 return ;
13501352 }
13511353
1354+ // free previous preload (if any)
1355+ sqlite3_mutex_enter (qmutex );
13521356 if (t_ctx -> preloaded ) {
13531357 sqlite3_free (t_ctx -> preloaded );
13541358 t_ctx -> preloaded = NULL ;
13551359 t_ctx -> precounter = 0 ;
13561360 }
1361+ sqlite3_mutex_leave (qmutex );
13571362
13581363 char sql [STATIC_SQL_SIZE ];
13591364 generate_memory_quant_table (table_name , column_name , sql );
@@ -1371,36 +1376,43 @@ static void vector_quantize_preload (sqlite3_context *context, int argc, sqlite3
13711376 return ;
13721377 }
13731378
1374- generate_select_quant_table (table_name , column_name , sql );
1375-
1376- int rc = SQLITE_NOMEM ;
13771379 sqlite3_stmt * vm = NULL ;
1378- rc = sqlite3_prepare_v2 (db , sql , -1 , & vm , NULL );
1379- if (rc != SQLITE_OK ) goto vector_preload_cleanup ;
1380+ generate_select_quant_table (table_name , column_name , sql );
1381+ int rc = sqlite3_prepare_v2 (db , sql , -1 , & vm , NULL );
1382+ if (rc != SQLITE_OK ) {
1383+ context_result_error (context , rc , "Internal statement error: %s" , sqlite3_errmsg (db ));
1384+ sqlite3_finalize (vm );
1385+ sqlite3_free (buffer );
1386+ return ;
1387+ }
13801388
13811389 int seek = 0 ;
13821390 while (1 ) {
13831391 rc = sqlite3_step (vm );
13841392 if (rc == SQLITE_DONE ) {rc = SQLITE_OK ; break ;} // return error: rebuild must be call (only if first time run)
1385- else if (rc != SQLITE_ROW ) goto vector_preload_cleanup ;
1393+ else if (rc != SQLITE_ROW ) { break ;}
13861394
13871395 int n = sqlite3_column_int (vm , 0 );
13881396 int bytes = sqlite3_column_bytes (vm , 1 );
13891397 uint8_t * data = (uint8_t * )sqlite3_column_blob (vm , 1 );
13901398
1399+ // no check here because I am sure quantization was performed only on non NULL data
13911400 memcpy (buffer + seek , data , bytes );
13921401 seek += bytes ;
13931402 counter += n ;
13941403 }
1395- rc = SQLITE_OK ;
1404+ sqlite3_finalize (vm );
1405+
1406+ if (rc != SQLITE_OK ) {
1407+ sqlite3_free (buffer );
1408+ context_result_error (context , rc , "vector_quantize_preload failed: %s" , sqlite3_errmsg (db ));
1409+ return ;
1410+ }
13961411
1412+ sqlite3_mutex_enter (qmutex );
13971413 t_ctx -> preloaded = buffer ;
13981414 t_ctx -> precounter = counter ;
1399-
1400- vector_preload_cleanup :
1401- if (rc != SQLITE_OK ) printf ("Error in vector_quantize_preload: %s\n" , sqlite3_errmsg (db ));
1402- if (vm ) sqlite3_finalize (vm );
1403- return ;
1415+ sqlite3_mutex_leave (qmutex );
14041416}
14051417
14061418static int vector_quantize (sqlite3_context * context , const char * table_name , const char * column_name , const char * arg_options , bool * was_preloaded ) {
@@ -1415,8 +1427,10 @@ static int vector_quantize (sqlite3_context *context, const char *table_name, co
14151427 char sql [STATIC_SQL_SIZE ];
14161428 sqlite3 * db = sqlite3_context_db_handle (context );
14171429
1418- rc = sqlite3_exec (db , "BEGIN;" , NULL , NULL , NULL );
1430+ bool savepoint_open = false;
1431+ rc = sqlite3_exec (db , "SAVEPOINT quantize;" , NULL , NULL , NULL );
14191432 if (rc != SQLITE_OK ) goto quantize_cleanup ;
1433+ savepoint_open = true;
14201434
14211435 generate_drop_quant_table (table_name , column_name , sql );
14221436 rc = sqlite3_exec (db , sql , NULL , NULL , NULL );
@@ -1428,12 +1442,11 @@ static int vector_quantize (sqlite3_context *context, const char *table_name, co
14281442
14291443 vector_options options = t_ctx -> options ; // t_ctx guarantees to exist
14301444 bool res = parse_keyvalue_string (context , arg_options , vector_keyvalue_callback , & options );
1431- if (res == false) return SQLITE_ERROR ;
1445+ if (res == false) { rc = SQLITE_ERROR ; goto quantize_cleanup ;}
14321446
1447+ sqlite3_mutex_enter (qmutex );
14331448 rc = vector_rebuild_quantization (context , table_name , column_name , t_ctx , options .q_type , options .max_memory , & counter );
1434- if (rc != SQLITE_OK ) goto quantize_cleanup ;
1435-
1436- rc = sqlite3_exec (db , "COMMIT;" , NULL , NULL , NULL );
1449+ sqlite3_mutex_leave (qmutex );
14371450 if (rc != SQLITE_OK ) goto quantize_cleanup ;
14381451
14391452 // serialize quantization options
@@ -1444,18 +1457,26 @@ static int vector_quantize (sqlite3_context *context, const char *table_name, co
14441457 rc = sqlite_serialize (context , table_name , column_name , SQLITE_FLOAT , OPTION_KEY_QUANTOFFSET , 0 , t_ctx -> offset );
14451458 if (rc != SQLITE_OK ) goto quantize_cleanup ;
14461459
1447- quantize_cleanup :
1448- if (rc != SQLITE_OK ) {
1449- printf ("%s" , sqlite3_errmsg (db ));
1450- sqlite3_exec (db , "ROLLBACK;" , NULL , NULL , NULL );
1451- sqlite3_result_error_code (context , rc );
1452- return rc ;
1453- }
1460+ rc = sqlite3_exec (db , "RELEASE quantize;" , NULL , NULL , NULL );
1461+ if (rc != SQLITE_OK ) goto quantize_cleanup ;
1462+ savepoint_open = false;
14541463
1455- // returns the total number of quantized rows
1464+ // success: returns the total number of quantized rows
14561465 sqlite3_result_int64 (context , (sqlite3_int64 )counter );
14571466 if (was_preloaded ) * was_preloaded = (t_ctx -> preloaded != NULL );
14581467 return SQLITE_OK ;
1468+
1469+ quantize_cleanup : {
1470+ const char * errmsg = sqlite3_errmsg (db );
1471+ if (savepoint_open ) {
1472+ sqlite3_exec (db , "ROLLBACK TO quantize;" , NULL , NULL , NULL );
1473+ sqlite3_exec (db , "RELEASE quantize;" , NULL , NULL , NULL );
1474+ }
1475+
1476+ sqlite3_result_error (context , errmsg , -1 );
1477+ sqlite3_result_error_code (context , rc );
1478+ return rc ;
1479+ }
14591480}
14601481
14611482static void vector_quantize3 (sqlite3_context * context , int argc , sqlite3_value * * argv ) {
@@ -2558,6 +2579,15 @@ SQLITE_VECTOR_API int sqlite3_vector_init (sqlite3 *db, char **pzErrMsg, const s
25582579 #endif
25592580 int rc = SQLITE_OK ;
25602581
2582+ // there's no built-in way to verify if sqlite3_vector_init has already been called for this specific database connection
2583+ // the workaround is to attempt to execute vector_version and check for an error
2584+ // an error indicates that initialization has not been performed
2585+ if (sqlite3_exec (db , "SELECT vector_version();" , NULL , NULL , NULL ) == SQLITE_OK ) return SQLITE_OK ;
2586+
2587+ // get an app global static mutex
2588+ qmutex = sqlite3_mutex_alloc (SQLITE_MUTEX_STATIC_APP1 );
2589+
2590+ // init internal distance functions (do not force CPU)
25612591 init_distance_functions (false);
25622592
25632593 // create internal table
0 commit comments