Skip to content

Commit 1de9b15

Browse files
committed
Added more protection to some critical functions
1 parent c2ca2b4 commit 1de9b15

File tree

2 files changed

+56
-26
lines changed

2 files changed

+56
-26
lines changed

src/sqlite-vector.c

Lines changed: 55 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ typedef int (*vcursor_sort_callback)(vFullScanCursor *c);
186186
extern distance_function_t dispatch_distance_table[VECTOR_DISTANCE_MAX][VECTOR_TYPE_MAX];
187187
extern char *distance_backend_name;
188188

189+
static sqlite3_mutex *qmutex;
190+
189191
// MARK: - SQLite Utils -
190192

191193
bool sqlite_system_exists (sqlite3 *db, const char *name, const char *type) {
@@ -1349,11 +1351,14 @@ static void vector_quantize_preload (sqlite3_context *context, int argc, sqlite3
13491351
return;
13501352
}
13511353

1354+
// free previous preload (if any)
1355+
sqlite3_mutex_enter(qmutex);
13521356
if (t_ctx->preloaded) {
13531357
sqlite3_free(t_ctx->preloaded);
13541358
t_ctx->preloaded = NULL;
13551359
t_ctx->precounter = 0;
13561360
}
1361+
sqlite3_mutex_leave(qmutex);
13571362

13581363
char sql[STATIC_SQL_SIZE];
13591364
generate_memory_quant_table(table_name, column_name, sql);
@@ -1371,36 +1376,43 @@ static void vector_quantize_preload (sqlite3_context *context, int argc, sqlite3
13711376
return;
13721377
}
13731378

1374-
generate_select_quant_table(table_name, column_name, sql);
1375-
1376-
int rc = SQLITE_NOMEM;
13771379
sqlite3_stmt *vm = NULL;
1378-
rc = sqlite3_prepare_v2(db, sql, -1, &vm, NULL);
1379-
if (rc != SQLITE_OK) goto vector_preload_cleanup;
1380+
generate_select_quant_table(table_name, column_name, sql);
1381+
int rc = sqlite3_prepare_v2(db, sql, -1, &vm, NULL);
1382+
if (rc != SQLITE_OK) {
1383+
context_result_error(context, rc, "Internal statement error: %s", sqlite3_errmsg(db));
1384+
sqlite3_finalize(vm);
1385+
sqlite3_free(buffer);
1386+
return;
1387+
}
13801388

13811389
int seek = 0;
13821390
while (1) {
13831391
rc = sqlite3_step(vm);
13841392
if (rc == SQLITE_DONE) {rc = SQLITE_OK; break;} // return error: rebuild must be call (only if first time run)
1385-
else if (rc != SQLITE_ROW) goto vector_preload_cleanup;
1393+
else if (rc != SQLITE_ROW) {break;}
13861394

13871395
int n = sqlite3_column_int(vm, 0);
13881396
int bytes = sqlite3_column_bytes(vm, 1);
13891397
uint8_t *data = (uint8_t *)sqlite3_column_blob(vm, 1);
13901398

1399+
// no check here because I am sure quantization was performed only on non NULL data
13911400
memcpy(buffer+seek, data, bytes);
13921401
seek += bytes;
13931402
counter += n;
13941403
}
1395-
rc = SQLITE_OK;
1404+
sqlite3_finalize(vm);
1405+
1406+
if (rc != SQLITE_OK) {
1407+
sqlite3_free(buffer);
1408+
context_result_error(context, rc, "vector_quantize_preload failed: %s", sqlite3_errmsg(db));
1409+
return;
1410+
}
13961411

1412+
sqlite3_mutex_enter(qmutex);
13971413
t_ctx->preloaded = buffer;
13981414
t_ctx->precounter = counter;
1399-
1400-
vector_preload_cleanup:
1401-
if (rc != SQLITE_OK) printf("Error in vector_quantize_preload: %s\n", sqlite3_errmsg(db));
1402-
if (vm) sqlite3_finalize(vm);
1403-
return;
1415+
sqlite3_mutex_leave(qmutex);
14041416
}
14051417

14061418
static int vector_quantize (sqlite3_context *context, const char *table_name, const char *column_name, const char *arg_options, bool *was_preloaded) {
@@ -1415,8 +1427,10 @@ static int vector_quantize (sqlite3_context *context, const char *table_name, co
14151427
char sql[STATIC_SQL_SIZE];
14161428
sqlite3 *db = sqlite3_context_db_handle(context);
14171429

1418-
rc = sqlite3_exec(db, "BEGIN;", NULL, NULL, NULL);
1430+
bool savepoint_open = false;
1431+
rc = sqlite3_exec(db, "SAVEPOINT quantize;", NULL, NULL, NULL);
14191432
if (rc != SQLITE_OK) goto quantize_cleanup;
1433+
savepoint_open = true;
14201434

14211435
generate_drop_quant_table(table_name, column_name, sql);
14221436
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
@@ -1428,12 +1442,11 @@ static int vector_quantize (sqlite3_context *context, const char *table_name, co
14281442

14291443
vector_options options = t_ctx->options; // t_ctx guarantees to exist
14301444
bool res = parse_keyvalue_string(context, arg_options, vector_keyvalue_callback, &options);
1431-
if (res == false) return SQLITE_ERROR;
1445+
if (res == false) {rc = SQLITE_ERROR; goto quantize_cleanup;}
14321446

1447+
sqlite3_mutex_enter(qmutex);
14331448
rc = vector_rebuild_quantization(context, table_name, column_name, t_ctx, options.q_type, options.max_memory, &counter);
1434-
if (rc != SQLITE_OK) goto quantize_cleanup;
1435-
1436-
rc = sqlite3_exec(db, "COMMIT;", NULL, NULL, NULL);
1449+
sqlite3_mutex_leave(qmutex);
14371450
if (rc != SQLITE_OK) goto quantize_cleanup;
14381451

14391452
// serialize quantization options
@@ -1444,18 +1457,26 @@ static int vector_quantize (sqlite3_context *context, const char *table_name, co
14441457
rc = sqlite_serialize(context, table_name, column_name, SQLITE_FLOAT, OPTION_KEY_QUANTOFFSET, 0, t_ctx->offset);
14451458
if (rc != SQLITE_OK) goto quantize_cleanup;
14461459

1447-
quantize_cleanup:
1448-
if (rc != SQLITE_OK) {
1449-
printf("%s", sqlite3_errmsg(db));
1450-
sqlite3_exec(db, "ROLLBACK;", NULL, NULL, NULL);
1451-
sqlite3_result_error_code(context, rc);
1452-
return rc;
1453-
}
1460+
rc = sqlite3_exec(db, "RELEASE quantize;", NULL, NULL, NULL);
1461+
if (rc != SQLITE_OK) goto quantize_cleanup;
1462+
savepoint_open = false;
14541463

1455-
// returns the total number of quantized rows
1464+
// success: returns the total number of quantized rows
14561465
sqlite3_result_int64(context, (sqlite3_int64)counter);
14571466
if (was_preloaded) *was_preloaded = (t_ctx->preloaded != NULL);
14581467
return SQLITE_OK;
1468+
1469+
quantize_cleanup: {
1470+
const char *errmsg = sqlite3_errmsg(db);
1471+
if (savepoint_open) {
1472+
sqlite3_exec(db, "ROLLBACK TO quantize;", NULL, NULL, NULL);
1473+
sqlite3_exec(db, "RELEASE quantize;", NULL, NULL, NULL);
1474+
}
1475+
1476+
sqlite3_result_error(context, errmsg, -1);
1477+
sqlite3_result_error_code(context, rc);
1478+
return rc;
1479+
}
14591480
}
14601481

14611482
static void vector_quantize3 (sqlite3_context *context, int argc, sqlite3_value **argv) {
@@ -2558,6 +2579,15 @@ SQLITE_VECTOR_API int sqlite3_vector_init (sqlite3 *db, char **pzErrMsg, const s
25582579
#endif
25592580
int rc = SQLITE_OK;
25602581

2582+
// there's no built-in way to verify if sqlite3_vector_init has already been called for this specific database connection
2583+
// the workaround is to attempt to execute vector_version and check for an error
2584+
// an error indicates that initialization has not been performed
2585+
if (sqlite3_exec(db, "SELECT vector_version();", NULL, NULL, NULL) == SQLITE_OK) return SQLITE_OK;
2586+
2587+
// get an app global static mutex
2588+
qmutex = sqlite3_mutex_alloc(SQLITE_MUTEX_STATIC_APP1);
2589+
2590+
// init internal distance functions (do not force CPU)
25612591
init_distance_functions(false);
25622592

25632593
// create internal table

src/sqlite-vector.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
extern "C" {
2525
#endif
2626

27-
#define SQLITE_VECTOR_VERSION "0.9.29"
27+
#define SQLITE_VECTOR_VERSION "0.9.30"
2828

2929
SQLITE_VECTOR_API int sqlite3_vector_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi);
3030

0 commit comments

Comments
 (0)