@@ -1113,12 +1113,13 @@ static int vector_serialize_quantization (sqlite3 *db, const char *table_name, c
11131113 return rc ;
11141114}
11151115
1116- static int vector_rebuild_quantization (sqlite3_context * context , const char * table_name , const char * column_name , table_context * t_ctx , vector_qtype qtype , uint64_t max_memory ) {
1116+ static int vector_rebuild_quantization (sqlite3_context * context , const char * table_name , const char * column_name , table_context * t_ctx , vector_qtype qtype , uint64_t max_memory , uint32_t * count ) {
11171117
11181118 int rc = SQLITE_NOMEM ;
11191119 sqlite3_stmt * vm = NULL ;
11201120 char sql [STATIC_SQL_SIZE ];
11211121 sqlite3 * db = sqlite3_context_db_handle (context );
1122+ uint32_t tot_processed = 0 ;
11221123
11231124 const char * pk_name = t_ctx -> pk_name ;
11241125 int dim = t_ctx -> options .v_dim ;
@@ -1245,7 +1246,6 @@ static int vector_rebuild_quantization (sqlite3_context *context, const char *ta
12451246 // STEP 3
12461247 // actual quantization (ONLY 8bit is supported in this version)
12471248 uint32_t n_processed = 0 ;
1248- uint32_t tot_processed = 0 ;
12491249 int64_t min_rowid = 0 , max_rowid = 0 ;
12501250 while (1 ) {
12511251 rc = sqlite3_step (vm );
@@ -1299,16 +1299,86 @@ static int vector_rebuild_quantization (sqlite3_context *context, const char *ta
12991299 if (original ) sqlite3_free (original );
13001300 if (tempv ) sqlite3_free (tempv );
13011301 if (vm ) sqlite3_finalize (vm );
1302+ if (count ) * count = tot_processed ;
13021303 return rc ;
13031304}
13041305
1305- static void vector_quantize (sqlite3_context * context , const char * table_name , const char * column_name , const char * arg_options ) {
1306+ static void vector_quantize_preload (sqlite3_context * context , int argc , sqlite3_value * * argv ) {
1307+ int types [] = {SQLITE_TEXT , SQLITE_TEXT };
1308+ if (sanity_check_args (context , "vector_quantize_preload" , argc , argv , 2 , types ) == false) return ;
1309+
1310+ const char * table_name = (const char * )sqlite3_value_text (argv [0 ]);
1311+ const char * column_name = (const char * )sqlite3_value_text (argv [1 ]);
1312+
1313+ vector_context * v_ctx = (vector_context * )sqlite3_user_data (context );
1314+ table_context * t_ctx = vector_context_lookup (v_ctx , table_name , column_name );
1315+ if (!t_ctx ) {
1316+ context_result_error (context , SQLITE_ERROR , "Vector context not found for table '%s' and column '%s'. Ensure that vector_init() has been called before using vector_quantize_preload()." , table_name , column_name );
1317+ return ;
1318+ }
1319+
1320+ if (t_ctx -> preloaded ) {
1321+ sqlite3_free (t_ctx -> preloaded );
1322+ t_ctx -> preloaded = NULL ;
1323+ t_ctx -> precounter = 0 ;
1324+ }
1325+
1326+ char sql [STATIC_SQL_SIZE ];
1327+ generate_memory_quant_table (table_name , column_name , sql );
1328+ sqlite3 * db = sqlite3_context_db_handle (context );
1329+ sqlite3_int64 required = sqlite_read_int64 (db , sql );
1330+ if (required == 0 ) {
1331+ context_result_error (context , SQLITE_ERROR , "Unable to read data from database. Ensure that vector_quantize() has been called before using vector_quantize_preload()." );
1332+ return ;
1333+ }
1334+
1335+ int counter = 0 ;
1336+ void * buffer = (void * )sqlite3_malloc64 (required );
1337+ if (!buffer ) {
1338+ context_result_error (context , SQLITE_NOMEM , "Out of memory: unable to allocate %lld bytes for quant buffer." , (long long )required );
1339+ return ;
1340+ }
1341+
1342+ generate_select_quant_table (table_name , column_name , sql );
1343+
1344+ int rc = SQLITE_NOMEM ;
1345+ sqlite3_stmt * vm = NULL ;
1346+ rc = sqlite3_prepare_v2 (db , sql , -1 , & vm , NULL );
1347+ if (rc != SQLITE_OK ) goto vector_preload_cleanup ;
1348+
1349+ int seek = 0 ;
1350+ while (1 ) {
1351+ rc = sqlite3_step (vm );
1352+ if (rc == SQLITE_DONE ) {rc = SQLITE_OK ; break ;} // return error: rebuild must be call (only if first time run)
1353+ else if (rc != SQLITE_ROW ) goto vector_preload_cleanup ;
1354+
1355+ int n = sqlite3_column_int (vm , 0 );
1356+ int bytes = sqlite3_column_bytes (vm , 1 );
1357+ uint8_t * data = (uint8_t * )sqlite3_column_blob (vm , 1 );
1358+
1359+ memcpy (buffer + seek , data , bytes );
1360+ seek += bytes ;
1361+ counter += n ;
1362+ }
1363+ rc = SQLITE_OK ;
1364+
1365+ t_ctx -> preloaded = buffer ;
1366+ t_ctx -> precounter = counter ;
1367+
1368+ vector_preload_cleanup :
1369+ if (rc != SQLITE_OK ) printf ("Error in vector_quantize_preload: %s\n" , sqlite3_errmsg (db ));
1370+ if (vm ) sqlite3_finalize (vm );
1371+ return ;
1372+ }
1373+
1374+ static int vector_quantize (sqlite3_context * context , const char * table_name , const char * column_name , const char * arg_options , bool * was_preloaded ) {
13061375 table_context * t_ctx = vector_context_lookup ((vector_context * )sqlite3_user_data (context ), table_name , column_name );
13071376 if (!t_ctx ) {
13081377 context_result_error (context , SQLITE_ERROR , "Vector context not found for table '%s' and column '%s'. Ensure that vector_init() has been called before using vector_quantize()." , table_name , column_name );
1309- return ;
1378+ return SQLITE_ERROR ;
13101379 }
13111380
1381+ uint32_t counter = 0 ;
13121382 int rc = SQLITE_ERROR ;
13131383 char sql [STATIC_SQL_SIZE ];
13141384 sqlite3 * db = sqlite3_context_db_handle (context );
@@ -1326,9 +1396,9 @@ static void vector_quantize (sqlite3_context *context, const char *table_name, c
13261396
13271397 vector_options options = vector_options_create ();
13281398 bool res = parse_keyvalue_string (context , arg_options , vector_keyvalue_callback , & options );
1329- if (res == false) return ;
1399+ if (res == false) return SQLITE_ERROR ;
13301400
1331- rc = vector_rebuild_quantization (context , table_name , column_name , t_ctx , options .q_type , options .max_memory );
1401+ rc = vector_rebuild_quantization (context , table_name , column_name , t_ctx , options .q_type , options .max_memory , & counter );
13321402 if (rc != SQLITE_OK ) goto quantize_cleanup ;
13331403
13341404 rc = sqlite3_exec (db , "COMMIT;" , NULL , NULL , NULL );
@@ -1347,8 +1417,13 @@ static void vector_quantize (sqlite3_context *context, const char *table_name, c
13471417 printf ("%s" , sqlite3_errmsg (db ));
13481418 sqlite3_exec (db , "ROLLBACK;" , NULL , NULL , NULL );
13491419 sqlite3_result_error_code (context , rc );
1350- return ;
1420+ return rc ;
13511421 }
1422+
1423+ // returns the total number of quantized rows
1424+ sqlite3_result_int64 (context , (sqlite3_int64 )counter );
1425+ if (was_preloaded ) * was_preloaded = (t_ctx -> preloaded != NULL );
1426+ return SQLITE_OK ;
13521427}
13531428
13541429static void vector_quantize3 (sqlite3_context * context , int argc , sqlite3_value * * argv ) {
@@ -1358,7 +1433,10 @@ static void vector_quantize3 (sqlite3_context *context, int argc, sqlite3_value
13581433 const char * table_name = (const char * )sqlite3_value_text (argv [0 ]);
13591434 const char * column_name = (const char * )sqlite3_value_text (argv [1 ]);
13601435 const char * options = (const char * )sqlite3_value_text (argv [2 ]);
1361- vector_quantize (context , table_name , column_name , options );
1436+
1437+ bool was_preloaded = false;
1438+ int rc = vector_quantize (context , table_name , column_name , options , & was_preloaded );
1439+ if ((rc == SQLITE_OK ) && (was_preloaded )) vector_quantize_preload (context , argc , argv );
13621440}
13631441
13641442static void vector_quantize2 (sqlite3_context * context , int argc , sqlite3_value * * argv ) {
@@ -1367,7 +1445,10 @@ static void vector_quantize2 (sqlite3_context *context, int argc, sqlite3_value
13671445
13681446 const char * table_name = (const char * )sqlite3_value_text (argv [0 ]);
13691447 const char * column_name = (const char * )sqlite3_value_text (argv [1 ]);
1370- vector_quantize (context , table_name , column_name , NULL );
1448+
1449+ bool was_preloaded = false;
1450+ int rc = vector_quantize (context , table_name , column_name , NULL , & was_preloaded );
1451+ if ((rc == SQLITE_OK ) && (was_preloaded )) vector_quantize_preload (context , argc , argv );
13711452}
13721453
13731454static void vector_quantize_memory (sqlite3_context * context , int argc , sqlite3_value * * argv ) {
@@ -1385,99 +1466,29 @@ static void vector_quantize_memory (sqlite3_context *context, int argc, sqlite3_
13851466 sqlite3_result_int64 (context , memory );
13861467}
13871468
1388- static void vector_quantize_preload (sqlite3_context * context , int argc , sqlite3_value * * argv ) {
1469+ static void vector_quantize_cleanup (sqlite3_context * context , int argc , sqlite3_value * * argv ) {
13891470 int types [] = {SQLITE_TEXT , SQLITE_TEXT };
1390- if (sanity_check_args (context , "vector_quantize_preload " , argc , argv , 2 , types ) == false) return ;
1471+ if (sanity_check_args (context , "vector_quantize_cleanup " , argc , argv , 2 , types ) == false) return ;
13911472
13921473 const char * table_name = (const char * )sqlite3_value_text (argv [0 ]);
13931474 const char * column_name = (const char * )sqlite3_value_text (argv [1 ]);
13941475
13951476 vector_context * v_ctx = (vector_context * )sqlite3_user_data (context );
13961477 table_context * t_ctx = vector_context_lookup (v_ctx , table_name , column_name );
1397- if (!t_ctx ) {
1398- context_result_error (context , SQLITE_ERROR , "Vector context not found for table '%s' and column '%s'. Ensure that vector_init() has been called before using vector_quantize_preload()." , table_name , column_name );
1399- return ;
1400- }
1478+ if (!t_ctx ) return ; // if no table context exists then do nothing
14011479
1480+ // release any memory used in quantization
14021481 if (t_ctx -> preloaded ) {
14031482 sqlite3_free (t_ctx -> preloaded );
14041483 t_ctx -> preloaded = NULL ;
14051484 t_ctx -> precounter = 0 ;
14061485 }
14071486
1408- char sql [STATIC_SQL_SIZE ];
1409- generate_memory_quant_table (table_name , column_name , sql );
1410- sqlite3 * db = sqlite3_context_db_handle (context );
1411- sqlite3_int64 required = sqlite_read_int64 (db , sql );
1412- if (required == 0 ) {
1413- context_result_error (context , SQLITE_ERROR , "Unable to read data from database. Ensure that vector_quantize() has been called before using vector_quantize_preload()." );
1414- return ;
1415- }
1416-
1417- int counter = 0 ;
1418- void * buffer = (void * )sqlite3_malloc64 (required );
1419- if (!buffer ) {
1420- context_result_error (context , SQLITE_NOMEM , "Out of memory: unable to allocate %lld bytes for quant buffer." , (long long )required );
1421- return ;
1422- }
1423-
1424- generate_select_quant_table (table_name , column_name , sql );
1425-
1426- int rc = SQLITE_NOMEM ;
1427- sqlite3_stmt * vm = NULL ;
1428- rc = sqlite3_prepare_v2 (db , sql , -1 , & vm , NULL );
1429- if (rc != SQLITE_OK ) goto vector_preload_cleanup ;
1430-
1431- int seek = 0 ;
1432- while (1 ) {
1433- rc = sqlite3_step (vm );
1434- if (rc == SQLITE_DONE ) {rc = SQLITE_OK ; break ;} // return error: rebuild must be call (only if first time run)
1435- else if (rc != SQLITE_ROW ) goto vector_preload_cleanup ;
1436-
1437- int n = sqlite3_column_int (vm , 0 );
1438- int bytes = sqlite3_column_bytes (vm , 1 );
1439- uint8_t * data = (uint8_t * )sqlite3_column_blob (vm , 1 );
1440-
1441- memcpy (buffer + seek , data , bytes );
1442- seek += bytes ;
1443- counter += n ;
1444- }
1445- rc = SQLITE_OK ;
1446-
1447- t_ctx -> preloaded = buffer ;
1448- t_ctx -> precounter = counter ;
1449-
1450- vector_preload_cleanup :
1451- if (rc != SQLITE_OK ) printf ("Error in vector_quantize_preload: %s\n" , sqlite3_errmsg (db ));
1452- if (vm ) sqlite3_finalize (vm );
1453- return ;
1454- }
1455-
1456- static void vector_cleanup (sqlite3_context * context , int argc , sqlite3_value * * argv ) {
1457- int types [] = {SQLITE_TEXT , SQLITE_TEXT };
1458- if (sanity_check_args (context , "vector_cleanup" , argc , argv , 2 , types ) == false) return ;
1459-
1460- const char * table_name = (const char * )sqlite3_value_text (argv [0 ]);
1461- const char * column_name = (const char * )sqlite3_value_text (argv [1 ]);
1462-
1463- vector_context * v_ctx = (vector_context * )sqlite3_user_data (context );
1464- table_context * t_ctx = vector_context_lookup (v_ctx , table_name , column_name );
1465- if (!t_ctx ) return ; // if no table context exists then do nothing
1466-
1467- // release memory
1468- if (t_ctx -> t_name ) sqlite3_free (t_ctx -> t_name );
1469- if (t_ctx -> c_name ) sqlite3_free (t_ctx -> c_name );
1470- if (t_ctx -> pk_name ) sqlite3_free (t_ctx -> pk_name );
1471- if (t_ctx -> preloaded ) sqlite3_free (t_ctx -> preloaded );
1472- memset (t_ctx , 0 , sizeof (table_context ));
1473-
14741487 // drop quant table (if any)
14751488 char sql [STATIC_SQL_SIZE ];
14761489 sqlite3 * db = sqlite3_context_db_handle (context );
14771490 generate_drop_quant_table (table_name , column_name , sql );
14781491 sqlite3_exec (db , sql , NULL , NULL , NULL );
1479-
1480- // do not decrease v_ctx->table_count
14811492}
14821493
14831494// MARK: -
@@ -2260,7 +2271,7 @@ SQLITE_VECTOR_API int sqlite3_vector_init (sqlite3 *db, char **pzErrMsg, const s
22602271 if (rc != SQLITE_OK ) goto cleanup ;
22612272
22622273 // table_name, column_name
2263- rc = sqlite3_create_function (db , "vector_cleanup " , 2 , SQLITE_UTF8 , ctx , vector_cleanup , NULL , NULL );
2274+ rc = sqlite3_create_function (db , "vector_quantize_cleanup " , 2 , SQLITE_UTF8 , ctx , vector_quantize_cleanup , NULL , NULL );
22642275 if (rc != SQLITE_OK ) goto cleanup ;
22652276
22662277 rc = sqlite3_create_function (db , "vector_as_f32" , 1 , SQLITE_UTF8 , ctx , vector_as_f32 , NULL , NULL );
0 commit comments