Skip to content

Commit 41af197

Browse files
committed
Improved zero rounding
1 parent 05a975f commit 41af197

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

src/sqlite-vector.c

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "distance-cpu.h"
1111

1212
#include <math.h>
13+
#include <float.h>
1314
#include <stdio.h>
1415
#include <ctype.h>
1516
#include <limits.h>
@@ -959,6 +960,10 @@ bool vector_keyvalue_callback (sqlite3_context *context, void *xdata, const char
959960
return true;
960961
}
961962

963+
static inline int nearly_zero_float32 (float x) {
964+
return fabsf(x) <= 8.0f * FLT_EPSILON; // tweak factor for your use
965+
}
966+
962967
// MARK: - SQL -
963968

964969
static char *generate_create_quant_table (const char *table_name, const char *column_name, char sql[STATIC_SQL_SIZE]) {
@@ -1930,7 +1935,8 @@ static int vFullScanRun (sqlite3 *db, vFullScanCursor *c, const void *v1, int v1
19301935
if (rc != SQLITE_ROW) goto cleanup;
19311936

19321937
float *v2 = (float *)sqlite3_column_blob(vm, 1);
1933-
double distance = distance_fn((const void *)v1, (const void *)v2, dimension);
1938+
float distance = distance_fn((const void *)v1, (const void *)v2, dimension);
1939+
if (nearly_zero_float32(distance)) distance = 0.0;
19341940
VECTOR_PRINT((void*)v2, vt, dimension);
19351941

19361942
if (distance < c->distance[c->max_index]) {
@@ -1974,7 +1980,8 @@ static int vQuantRunMemory(vFullScanCursor *c, uint8_t *v, vector_qtype qtype, i
19741980
const uint8_t *vector_data = current_data + rowid_size;
19751981

19761982
float dist = distance_fn((const void *)v, (const void *)vector_data, dim);
1977-
1983+
if (nearly_zero_float32(dist)) dist = 0.0;
1984+
19781985
if (dist < current_max) {
19791986
distance[max_index] = dist;
19801987
rowids[max_index] = INT64_FROM_INT8PTR(current_data);
@@ -2046,7 +2053,8 @@ static int vQuantRun (sqlite3 *db, vFullScanCursor *c, const void *v1, int v1siz
20462053
for (int i=0; i<counter; ++i) {
20472054
const uint8_t *current_data = data + (i * total_stride);
20482055
const uint8_t *vector_data = current_data + rowid_size;
2049-
double distance = (double)distance_fn((const void *)v, (const void *)vector_data, dimension);
2056+
float distance = distance_fn((const void *)v, (const void *)vector_data, dimension);
2057+
if (nearly_zero_float32(distance)) distance = 0.0;
20502058
VECTOR_PRINT((void*)vector_data, vt, dimension);
20512059

20522060
if (distance < current_max_distance) {

0 commit comments

Comments
 (0)