|
10 | 10 | #include "distance-cpu.h" |
11 | 11 |
|
12 | 12 | #include <math.h> |
| 13 | +#include <float.h> |
13 | 14 | #include <stdio.h> |
14 | 15 | #include <ctype.h> |
15 | 16 | #include <limits.h> |
@@ -959,6 +960,10 @@ bool vector_keyvalue_callback (sqlite3_context *context, void *xdata, const char |
959 | 960 | return true; |
960 | 961 | } |
961 | 962 |
|
| 963 | +static inline int nearly_zero_float32 (float x) { |
| 964 | + return fabsf(x) <= 8.0f * FLT_EPSILON; // tweak factor for your use |
| 965 | +} |
| 966 | + |
962 | 967 | // MARK: - SQL - |
963 | 968 |
|
964 | 969 | static char *generate_create_quant_table (const char *table_name, const char *column_name, char sql[STATIC_SQL_SIZE]) { |
@@ -1930,7 +1935,8 @@ static int vFullScanRun (sqlite3 *db, vFullScanCursor *c, const void *v1, int v1 |
1930 | 1935 | if (rc != SQLITE_ROW) goto cleanup; |
1931 | 1936 |
|
1932 | 1937 | float *v2 = (float *)sqlite3_column_blob(vm, 1); |
1933 | | - double distance = distance_fn((const void *)v1, (const void *)v2, dimension); |
| 1938 | + float distance = distance_fn((const void *)v1, (const void *)v2, dimension); |
| 1939 | + if (nearly_zero_float32(distance)) distance = 0.0; |
1934 | 1940 | VECTOR_PRINT((void*)v2, vt, dimension); |
1935 | 1941 |
|
1936 | 1942 | if (distance < c->distance[c->max_index]) { |
@@ -1974,7 +1980,8 @@ static int vQuantRunMemory(vFullScanCursor *c, uint8_t *v, vector_qtype qtype, i |
1974 | 1980 | const uint8_t *vector_data = current_data + rowid_size; |
1975 | 1981 |
|
1976 | 1982 | float dist = distance_fn((const void *)v, (const void *)vector_data, dim); |
1977 | | - |
| 1983 | + if (nearly_zero_float32(dist)) dist = 0.0; |
| 1984 | + |
1978 | 1985 | if (dist < current_max) { |
1979 | 1986 | distance[max_index] = dist; |
1980 | 1987 | rowids[max_index] = INT64_FROM_INT8PTR(current_data); |
@@ -2046,7 +2053,8 @@ static int vQuantRun (sqlite3 *db, vFullScanCursor *c, const void *v1, int v1siz |
2046 | 2053 | for (int i=0; i<counter; ++i) { |
2047 | 2054 | const uint8_t *current_data = data + (i * total_stride); |
2048 | 2055 | const uint8_t *vector_data = current_data + rowid_size; |
2049 | | - double distance = (double)distance_fn((const void *)v, (const void *)vector_data, dimension); |
| 2056 | + float distance = distance_fn((const void *)v, (const void *)vector_data, dimension); |
| 2057 | + if (nearly_zero_float32(distance)) distance = 0.0; |
2050 | 2058 | VECTOR_PRINT((void*)vector_data, vt, dimension); |
2051 | 2059 |
|
2052 | 2060 | if (distance < current_max_distance) { |
|
0 commit comments