Skip to content

Commit 91af775

Browse files
committed
Added support for signed 8bit quantization
1 parent ecfd56e commit 91af775

File tree

3 files changed

+163
-18
lines changed

3 files changed

+163
-18
lines changed

src/distance-cpu.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ typedef enum {
2121
#define VECTOR_TYPE_MAX 6
2222

2323
typedef enum {
24-
VECTOR_QUANT_8BIT = 1
24+
VECTOR_QUANT_AUTO = 0,
25+
VECTOR_QUANT_U8BIT = 1,
26+
VECTOR_QUANT_S8BIT = 2
2527
} vector_qtype;
2628

2729
typedef enum {

src/sqlite-vector.c

Lines changed: 159 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,12 @@ SQLITE_EXTENSION_INIT1
9494
#define OPTION_KEY_DIMENSION "dimension"
9595
#define OPTION_KEY_NORMALIZED "normalized"
9696
#define OPTION_KEY_MAXMEMORY "max_memory"
97-
#define OPTION_KEY_QUANTTYPE "quant_type"
9897
#define OPTION_KEY_DISTANCE "distance"
98+
#define OPTION_KEY_QUANTTYPE "qtype"
99+
#define OPTION_KEY_QUANTSCALE "qscale" // used only in serialize/unserialize
100+
#define OPTION_KEY_QUANTOFFSET "qoffset" // used only in serialize/unserialize
101+
102+
#define VECTOR_INTERNAL_TABLE "CREATE TABLE IF NOT EXISTS __vector (tblname TEXT, colname TEXT, key TEXT, value ANY, PRIMARY KEY(tblname, colname, key));"
99103

100104
typedef struct {
101105
vector_type v_type; // vector type
@@ -472,9 +476,83 @@ static void *sqlite_common_set_error (sqlite3_context *context, sqlite3_vtab *vt
472476
return NULL;
473477
}
474478

479+
static int sqlite_serialize (sqlite3_context *context, const char *table_name, const char *column_name, int type, const char *key, int64_t ivalue, double fvalue) {
480+
const char *sql = "REPLACE INTO __vector (tblname, colname, key, value) VALUES (?, ?, ?, ?);";
481+
sqlite3 *db = sqlite3_context_db_handle(context);
482+
sqlite3_stmt *vm = NULL;
483+
484+
int rc = sqlite3_prepare_v2(db, sql, -1, &vm, NULL);
485+
if (rc != SQLITE_OK) goto cleanup;
486+
487+
rc = sqlite3_bind_text(vm, 1, table_name, -1, SQLITE_STATIC);
488+
if (rc != SQLITE_OK) goto cleanup;
489+
490+
rc = sqlite3_bind_text(vm, 2, column_name, -1, SQLITE_STATIC);
491+
if (rc != SQLITE_OK) goto cleanup;
492+
493+
rc = sqlite3_bind_text(vm, 3, key, -1, SQLITE_STATIC);
494+
if (rc != SQLITE_OK) goto cleanup;
495+
496+
switch (type) {
497+
case SQLITE_INTEGER: rc = sqlite3_bind_int64(vm, 4, (sqlite3_int64)ivalue); break;
498+
case SQLITE_FLOAT: rc = sqlite3_bind_double(vm, 4, fvalue); break;
499+
}
500+
if (rc != SQLITE_OK) goto cleanup;
501+
502+
rc = sqlite3_step(vm);
503+
if (rc == SQLITE_DONE) rc = SQLITE_OK;
504+
505+
cleanup:
506+
if (rc != SQLITE_OK) sqlite3_result_error(context, sqlite3_errmsg(db), -1);
507+
if (vm) sqlite3_finalize(vm);
508+
return rc;
509+
}
510+
511+
static int sqlite_unserialize (sqlite3_context *context, table_context *ctx) {
512+
const char *sql = "SELECT key, value FROM __vector WHERE tblname = ? AND colname = ?;";
513+
sqlite3 *db = sqlite3_context_db_handle(context);
514+
sqlite3_stmt *vm = NULL;
515+
516+
int rc = sqlite3_prepare_v2(db, sql, -1, &vm, NULL);
517+
if (rc != SQLITE_OK) goto cleanup;
518+
519+
rc = sqlite3_bind_text(vm, 1, ctx->t_name, -1, SQLITE_STATIC);
520+
if (rc != SQLITE_OK) goto cleanup;
521+
522+
rc = sqlite3_bind_text(vm, 2, ctx->c_name, -1, SQLITE_STATIC);
523+
if (rc != SQLITE_OK) goto cleanup;
524+
525+
while (1) {
526+
rc = sqlite3_step(vm);
527+
if (rc == SQLITE_DONE) {rc = SQLITE_OK; break;}
528+
if (rc != SQLITE_ROW) break;
529+
530+
const char *key = (const char *)sqlite3_column_text(vm, 0);
531+
if (strcmp(key, OPTION_KEY_QUANTTYPE) == 0) {
532+
ctx->options.q_type = (vector_qtype)sqlite3_column_int(vm, 1);
533+
continue;
534+
}
535+
536+
if (strcmp(key, OPTION_KEY_QUANTSCALE) == 0) {
537+
ctx->scale = (float)sqlite3_column_double(vm, 1);
538+
continue;
539+
}
540+
541+
if (strcmp(key, OPTION_KEY_QUANTOFFSET) == 0) {
542+
ctx->offset = (float)sqlite3_column_double(vm, 1);
543+
continue;
544+
}
545+
}
546+
547+
cleanup:
548+
//if (rc != SQLITE_OK) sqlite3_result_error(context, sqlite3_errmsg(db), -1);
549+
if (vm) sqlite3_finalize(vm);
550+
return rc;
551+
}
552+
475553
// MARK: - General Utils -
476554

477-
static inline void quantize_float32_to_u8 (float *v, uint8_t *q, float offset, float scale, int n) {
555+
static inline void quantize_float32_to_unsigned8bit (float *v, uint8_t *q, float offset, float scale, int n) {
478556
int i = 0;
479557
for (; i + 3 < n; i += 4) {
480558
float s0 = (v[i] - offset) * scale;
@@ -507,6 +585,38 @@ static inline void quantize_float32_to_u8 (float *v, uint8_t *q, float offset, f
507585
}
508586
}
509587

588+
static inline void quantize_float32_to_signed8bit (float *v, int8_t *q, float offset, float scale, int n) {
589+
int i = 0;
590+
for (; i + 3 < n; i += 4) {
591+
float s0 = (v[i] - offset) * scale;
592+
float s1 = (v[i + 1] - offset) * scale;
593+
float s2 = (v[i + 2] - offset) * scale;
594+
float s3 = (v[i + 3] - offset) * scale;
595+
596+
int r0 = (int)(s0 + 0.5f * (1.0f - 2.0f * (s0 < 0.0f)));
597+
int r1 = (int)(s1 + 0.5f * (1.0f - 2.0f * (s1 < 0.0f)));
598+
int r2 = (int)(s2 + 0.5f * (1.0f - 2.0f * (s2 < 0.0f)));
599+
int r3 = (int)(s3 + 0.5f * (1.0f - 2.0f * (s3 < 0.0f)));
600+
601+
r0 = r0 > 127 ? 127 : (r0 < -128 ? -128 : r0);
602+
r1 = r1 > 127 ? 127 : (r1 < -128 ? -128 : r1);
603+
r2 = r2 > 127 ? 127 : (r2 < -128 ? -128 : r2);
604+
r3 = r3 > 127 ? 127 : (r3 < -128 ? -128 : r3);
605+
606+
q[i] = (int8_t)r0;
607+
q[i + 1] = (int8_t)r1;
608+
q[i + 2] = (int8_t)r2;
609+
q[i + 3] = (int8_t)r3;
610+
}
611+
612+
for (; i < n; ++i) {
613+
float scaled = (v[i] - offset) * scale;
614+
int rounded = (int)(scaled + 0.5f * (1.0f - 2.0f * (scaled < 0.0f)));
615+
rounded = rounded > 127 ? 127 : (rounded < -128 ? -128 : rounded);
616+
q[i] = (int8_t)rounded;
617+
}
618+
}
619+
510620
static size_t vector_type_to_size (vector_type type) {
511621
switch (type) {
512622
case VECTOR_TYPE_F32: return sizeof(float);
@@ -539,8 +649,9 @@ const char *vector_type_to_name (vector_type type) {
539649
}
540650

541651
static vector_qtype quant_name_to_type (const char *qname) {
542-
if (strcasecmp(qname, "QUANTU8") == 0) return VECTOR_QUANT_8BIT;
543-
return 0;
652+
if (strcasecmp(qname, "UINT8") == 0) return VECTOR_QUANT_U8BIT;
653+
if (strcasecmp(qname, "INT8") == 0) return VECTOR_QUANT_S8BIT;
654+
return -1;
544655
}
545656

546657
static vector_distance distance_name_to_type (const char *dname) {
@@ -683,7 +794,7 @@ bool vector_keyvalue_callback (sqlite3_context *context, void *xdata, const char
683794

684795
if (strncasecmp(key, OPTION_KEY_QUANTTYPE, key_len) == 0) {
685796
vector_qtype type = quant_name_to_type(buffer);
686-
if (type == 0) return context_result_error(context, SQLITE_ERROR, "Invalid quantization type: '%s' is not a recognized or supported quantization type.", buffer);
797+
if (type == -1) return context_result_error(context, SQLITE_ERROR, "Invalid quantization type: '%s' is not a recognized or supported quantization type.", buffer);
687798
options->q_type = type;
688799
return true;
689800
}
@@ -736,6 +847,7 @@ void *vector_context_create (void) {
736847
}
737848

738849
void vector_context_free (void *p) {
850+
return;
739851
if (p) {
740852
vector_context *ctx = (vector_context *)p;
741853
for (int i=0; i<ctx->table_count; ++i) {
@@ -791,14 +903,16 @@ void vector_context_add (sqlite3_context *context, vector_context *ctx, const ch
791903
ctx->tables[index].pk_name = prikey;
792904
ctx->tables[index].options = *options;
793905
ctx->table_count++;
906+
907+
sqlite_unserialize(context, &ctx->tables[index]);
794908
}
795909

796910
void vector_options_init (vector_options *options) {
797911
memset(options, 0, sizeof(vector_options));
798912
options->v_type = VECTOR_TYPE_F32;
799913
options->v_distance = VECTOR_DISTANCE_L2;
800914
options->max_memory = DEFAULT_MAX_MEMORY;
801-
options->q_type = VECTOR_QUANT_8BIT;
915+
options->q_type = VECTOR_QUANT_AUTO;
802916
}
803917

804918
vector_options vector_options_create (void) {
@@ -889,6 +1003,7 @@ static int vector_rebuild_quantization (sqlite3_context *context, const char *ta
8891003
float max_val = -MAXFLOAT;
8901004
#endif
8911005

1006+
bool contains_negative = false;
8921007
while (1) {
8931008
rc = sqlite3_step(vm);
8941009
if (rc == SQLITE_DONE) {rc = SQLITE_OK; break;}
@@ -927,15 +1042,24 @@ static int vector_rebuild_quantization (sqlite3_context *context, const char *ta
9271042
}
9281043
if (val < min_val) min_val = val;
9291044
if (val > max_val) max_val = val;
1045+
if (val < 0.0) contains_negative = true;
9301046
}
9311047
}
9321048

1049+
// set proper format
1050+
if (qtype == VECTOR_QUANT_AUTO) {
1051+
if (contains_negative == true) qtype = VECTOR_QUANT_S8BIT;
1052+
else qtype = VECTOR_QUANT_U8BIT;
1053+
}
1054+
9331055
// STEP 2
934-
// compute scale and offset and set table them to table context
935-
// standard min-max linear quantization
936-
float scale = 255.0f / (max_val - min_val);
937-
float offset = min_val;
1056+
// compute scale and offset and set table them to table context standard min-max linear quantization
1057+
float abs_max = fmaxf(fabsf(min_val), fabsf(max_val)); // only used in VECTOR_QUANT_S8BIT
1058+
float scale = (qtype == VECTOR_QUANT_U8BIT) ? (255.0f / (max_val - min_val)) : (127.0f / abs_max);
1059+
// in the VECTOR_QUANT_S8BIT version I am assuming a symmetric quantization, for asymmetric quantization min_val should be used
1060+
float offset = (qtype == VECTOR_QUANT_U8BIT) ? min_val : 0.0f;
9381061

1062+
t_ctx->options.q_type = qtype;
9391063
t_ctx->scale = scale;
9401064
t_ctx->offset = offset;
9411065

@@ -986,7 +1110,8 @@ static int vector_rebuild_quantization (sqlite3_context *context, const char *ta
9861110
data += sizeof(int64_t);
9871111

9881112
// quantize vector
989-
quantize_float32_to_u8(v, data, offset, scale, dim);
1113+
if (qtype == VECTOR_QUANT_U8BIT) quantize_float32_to_unsigned8bit(v, data, offset, scale, dim);
1114+
else quantize_float32_to_signed8bit(v, (int8_t *)data, offset, scale, dim);
9901115
data += (dim * sizeof(uint8_t));
9911116

9921117
max_rowid = rowid;
@@ -1048,6 +1173,14 @@ static void vector_quantize (sqlite3_context *context, const char *table_name, c
10481173
rc = sqlite3_exec(db, "COMMIT;", NULL, NULL, NULL);
10491174
if (rc != SQLITE_OK) goto quantize_cleanup;
10501175

1176+
// serialize quantization options
1177+
rc = sqlite_serialize(context, table_name, column_name, SQLITE_INTEGER, OPTION_KEY_QUANTTYPE, t_ctx->options.q_type, 0);
1178+
if (rc != SQLITE_OK) goto quantize_cleanup;
1179+
rc = sqlite_serialize(context, table_name, column_name, SQLITE_FLOAT, OPTION_KEY_QUANTSCALE, 0, t_ctx->scale);
1180+
if (rc != SQLITE_OK) goto quantize_cleanup;
1181+
rc = sqlite_serialize(context, table_name, column_name, SQLITE_FLOAT, OPTION_KEY_QUANTOFFSET, 0, t_ctx->offset);
1182+
if (rc != SQLITE_OK) goto quantize_cleanup;
1183+
10511184
quantize_cleanup:
10521185
if (rc != SQLITE_OK) {
10531186
printf("%s", sqlite3_errmsg(db));
@@ -1658,7 +1791,7 @@ static int vFullScanCursorFilter (sqlite3_vtab_cursor *cur, int idxNum, const ch
16581791

16591792
// MARK: -
16601793

1661-
static int vQuantRunMemory(vFullScanCursor *c, uint8_t *v, int dim) {
1794+
static int vQuantRunMemory(vFullScanCursor *c, uint8_t *v, vector_qtype qtype, int dim) {
16621795
const int counter = c->table->precounter;
16631796
const uint8_t *data = c->table->preloaded;
16641797
const size_t rowid_size = sizeof(int64_t);
@@ -1672,7 +1805,7 @@ static int vQuantRunMemory(vFullScanCursor *c, uint8_t *v, int dim) {
16721805

16731806
// compute distance function
16741807
vector_distance vd = c->table->options.v_distance;
1675-
vector_type vt = VECTOR_TYPE_U8;
1808+
vector_type vt = (qtype == VECTOR_QUANT_U8BIT) ? VECTOR_TYPE_U8 : VECTOR_TYPE_I8;
16761809
distance_function_t distance_fn = dispatch_distance_table[vd][vt];
16771810

16781811
for (int i = 0; i < counter; ++i) {
@@ -1701,9 +1834,14 @@ static int vQuantRun (sqlite3 *db, vFullScanCursor *c, const void *v1, int v1siz
17011834
uint8_t *v = (uint8_t *)sqlite3_malloc(dimension * sizeof(int8_t));
17021835
if (!v) return SQLITE_NOMEM;
17031836

1704-
quantize_float32_to_u8((float *)v1, v, c->table->offset, c->table->scale, dimension);
1837+
vector_qtype qtype = c->table->options.q_type;
1838+
if (qtype == VECTOR_QUANT_U8BIT) {
1839+
quantize_float32_to_unsigned8bit((float *)v1, v, c->table->offset, c->table->scale, dimension);
1840+
} else {
1841+
quantize_float32_to_signed8bit((float *)v1, (int8_t *)v, c->table->offset, c->table->scale, dimension);
1842+
}
17051843
if (c->table->preloaded) {
1706-
int rc = vQuantRunMemory(c, v, dimension);
1844+
int rc = vQuantRunMemory(c, v, qtype, dimension);
17071845
if (v) sqlite3_free(v);
17081846
return rc;
17091847
}
@@ -1721,7 +1859,7 @@ static int vQuantRun (sqlite3 *db, vFullScanCursor *c, const void *v1, int v1siz
17211859

17221860
// compute distance function
17231861
vector_distance vd = c->table->options.v_distance;
1724-
vector_type vt = VECTOR_TYPE_U8;
1862+
vector_type vt = (qtype == VECTOR_QUANT_U8BIT) ? VECTOR_TYPE_U8 : VECTOR_TYPE_I8;
17251863
distance_function_t distance_fn = dispatch_distance_table[vd][vt];
17261864

17271865
while (1) {
@@ -1898,6 +2036,11 @@ SQLITE_VECTOR_API int sqlite3_vector_init (sqlite3 *db, char **pzErrMsg, const s
18982036

18992037
init_distance_functions(false);
19002038

2039+
// TODO: error message must be duplicate here?
2040+
// create internal table
2041+
rc = sqlite3_exec(db, VECTOR_INTERNAL_TABLE, NULL, NULL, NULL);
2042+
if (rc != SQLITE_OK) return rc;
2043+
19012044
// init context
19022045
void *ctx = vector_context_create();
19032046
if (!ctx) {

src/sqlite-vector.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
extern "C" {
2525
#endif
2626

27-
#define SQLITE_VECTOR_VERSION "0.8.8"
27+
#define SQLITE_VECTOR_VERSION "0.8.9"
2828

2929
SQLITE_VECTOR_API int sqlite3_vector_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi);
3030

0 commit comments

Comments
 (0)