Skip to content

Commit def408d

Browse files
committed
Fix precision validation gaps and enhance insert() capabilities
This commit addresses multiple precision and validation issues identified in the codebase analysis:

## 1. Input Validation
- Add NaN/Inf validation to insert() methods for both float32 and float64
- Ensures consistency with constructor validation
- Prevents invalid data from entering the tree structure

## 2. Float64 Insert Support
- Add float64 overload for insert() method
- Maintains idx2exact map for dynamically inserted items
- Preserves double-precision refinement capability for inserted boxes
- Uses explicit py::overload_cast in Python bindings to handle overloads

## 3. Precision Testing
- Add comprehensive tests for NaN/Inf validation in insert operations
- Add tests for float64 insert() maintaining precision
- Add tests verifying rebuild() preserves idx2exact
- Add systematic precision boundary tests (adjusted for float32 limits)
- Document float32 precision limitations in test comments

## Technical Notes
- Float64 input is converted to float32 for tree structure
- Double-precision refinement helps reduce false positives
- Precision limits: gaps below ~1e-7 may not be reliably detected
- At large magnitudes (e.g., 1e6), absolute precision degrades

Fixes validation gaps in insert operations and maintains precision capabilities for dynamically updated trees.
1 parent ddb3f0c commit def408d

File tree

4 files changed

+482
-6
lines changed

4 files changed

+482
-6
lines changed

include/prtree/core/prtree.h

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ template <IndexType T, int B = 6, int D = 2> class PRTree {
308308
return obj;
309309
}
310310

311+
// Insert with float32 coordinates (no double-precision refinement)
311312
void insert(const T &idx, const py::array_t<float> &x,
312313
const std::optional<std::string> objdumps = std::nullopt) {
313314
// Phase 1: Thread-safety - protect entire insert operation
@@ -342,6 +343,15 @@ template <IndexType T, int B = 6, int D = 2> class PRTree {
342343
minima[i] = *x.data(i);
343344
maxima[i] = *x.data(i + D);
344345
}
346+
347+
// Validate bounding box (reject NaN/Inf, enforce min <= max)
348+
float coords[2 * D];
349+
for (int j = 0; j < D; ++j) {
350+
coords[j] = minima[j];
351+
coords[j + D] = maxima[j];
352+
}
353+
validate_box(coords, D);
354+
345355
bb = BB<D>(minima, maxima);
346356
}
347357
idx2bb.emplace(idx, bb);
@@ -437,6 +447,152 @@ template <IndexType T, int B = 6, int D = 2> class PRTree {
437447
#endif
438448
}
439449

450+
// Insert with float64 coordinates (maintains double-precision refinement)
451+
/// Insert a single bounding box given as float64 coordinates.
///
/// The exact double values are recorded in idx2exact, while the tree itself
/// receives the coordinates narrowed to Real.
///
/// @param idx       Index of the new item; must not already be in idx2bb.
/// @param x         1-D array of 2*D values: D minima followed by D maxima.
/// @param objdumps  Optional payload forwarded to set_obj().
/// @throws std::runtime_error if x is not a 1-D array of 2*D elements or if
///         idx already exists. validate_box() is defined elsewhere; per the
///         comment at its call site it rejects NaN/Inf and min > max
///         (NOTE(review): confirm which exception type it raises).
void insert(const T &idx, const py::array_t<double> &x,
            const std::optional<std::string> objdumps = std::nullopt) {
  // Phase 1: Thread-safety - protect entire insert operation
  std::lock_guard<std::recursive_mutex> lock(*tree_mutex_);

#ifdef MY_DEBUG
  ProfilerStart("insert.prof");
  std::cout << "profiler start of insert (float64)" << std::endl;
#endif
  vec<size_t> cands;
  BB<D> bb;
  std::array<double, 2 * D> exact_coords;

  const auto &buff_info_x = x.request();
  const auto &shape_x = buff_info_x.shape;
  const auto &ndim = buff_info_x.ndim;
  // Phase 4: Improved error messages with context.
  // ndim is tested first: a 0-d array has an empty shape vector, so
  // shape_x[0] must not be read unless at least one dimension exists.
  if (unlikely(ndim != 1 || shape_x[0] != 2 * D)) {
    const std::string got_shape =
        (ndim >= 1) ? "(" + std::to_string(shape_x[0]) + ",)"
                    : std::string("()");
    throw std::runtime_error(
        "Invalid shape for bounding box array. Expected shape (" +
        std::to_string(2 * D) + ",) but got shape " + got_shape +
        " with ndim=" + std::to_string(ndim));
  }
  auto it = idx2bb.find(idx);
  if (unlikely(it != idx2bb.end())) {
    throw std::runtime_error(
        "Index already exists in tree: " + std::to_string(idx));
  }
  {
    Real minima[D];
    Real maxima[D];

    // Store exact double coordinates
    for (int i = 0; i < D; ++i) {
      exact_coords[i] = *x.data(i);
      exact_coords[i + D] = *x.data(i + D);
    }

    // Validate in double precision before anything is stored, so a rejected
    // box leaves idx2bb / idx2exact untouched.
    validate_box(exact_coords.data(), D);

    // Convert to Real for the tree structure after validation
    for (int i = 0; i < D; ++i) {
      minima[i] = static_cast<Real>(exact_coords[i]);
      maxima[i] = static_cast<Real>(exact_coords[i + D]);
    }

    bb = BB<D>(minima, maxima);
  }
  idx2bb.emplace(idx, bb);
  idx2exact[idx] = exact_coords; // Store exact coordinates for refinement
  set_obj(idx, objdumps);

  // Per-dimension extent plus a small epsilon; used as the step by which the
  // search box is grown below.
  Real delta[D];
  for (int i = 0; i < D; ++i) {
    delta[i] = bb.max(i) - bb.min(i) + 0.00000001;
  }

  // Find the leaf node(s) to insert into: grow the search box until at least
  // one leaf is selected by the elem.leaf->mbb(bb) test.
  Real c = 0.0;
  size_t count = flat_tree.size();
  while (cands.empty()) {
    Real d[D];
    for (int i = 0; i < D; ++i) {
      d[i] = delta[i] * c;
    }
    bb.expand(d);
    c = (c + 1) * 2;

    queue<size_t> que;
    auto qpush_if_intersect = [&](const size_t &i) {
      if (flat_tree[i](bb)) {
        que.emplace(i);
      }
    };

    // Breadth-first walk of the implicit B-ary tree: children of node i are
    // stored at i * B + 1 .. i * B + B.
    qpush_if_intersect(0);
    while (!que.empty()) {
      size_t i = que.front();
      que.pop();
      PRTreeElement<T, B, D> &elem = flat_tree[i];

      if (elem.leaf && elem.leaf->mbb(bb)) {
        cands.push_back(i);
      } else {
        for (size_t offset = 0; offset < B; offset++) {
          size_t j = i * B + offset + 1;
          if (j < count)
            qpush_if_intersect(j);
        }
      }
    }
  }

  // Defensive guard: the loop above only exits once cands is non-empty.
  if (unlikely(cands.empty()))
    throw std::runtime_error("cannot determine where to insert");

  // Now cands is the list of candidate leaf nodes to insert.
  // Restore the original (unexpanded) box before measuring area growth.
  bb = idx2bb.at(idx);
  // Start from the first candidate so min_leaf always names a real leaf, even
  // if no area difference compares smaller (e.g. NaN or overflowed areas).
  size_t min_leaf = cands[0];
  if (cands.size() > 1) {
    // A "have_best" flag replaces the former 1e100 sentinel, which is not
    // representable when Real is a 32-bit float.
    bool have_best = false;
    Real min_diff_area = 0;
    for (const auto &i : cands) {
      PRTreeLeaf<T, B, D> *leaf = flat_tree[i].leaf.get();
      // Work on a copy so the trial push does not modify the real leaf.
      PRTreeLeaf<T, B, D> tmp_leaf = PRTreeLeaf<T, B, D>(*leaf);
      Real diff_area = -tmp_leaf.area();
      tmp_leaf.push(idx, bb);
      diff_area += tmp_leaf.area();
      if (!have_best || diff_area < min_diff_area) {
        have_best = true;
        min_diff_area = diff_area;
        min_leaf = i;
      }
    }
  }
  flat_tree[min_leaf].leaf->push(idx, bb);

  // Propagate the enlarged bounding box from the chosen leaf up to the root;
  // the parent of node i is (i - 1) / B.
  size_t i = min_leaf;
  while (true) {
    PRTreeElement<T, B, D> &elem = flat_tree[i];

    if (elem.leaf)
      elem.mbb += elem.leaf->mbb;

    if (i > 0) {
      size_t j = (i - 1) / B;
      flat_tree[j].mbb += flat_tree[i].mbb;
    }
    if (i == 0)
      break;
    i = (i - 1) / B;
  }

  // Rebuild once the tree has grown past REBUILD_THRE times its built size.
  if (size() > REBUILD_THRE * n_at_build) {
    rebuild();
  }
#ifdef MY_DEBUG
  ProfilerStop();
  std::cout << "profiler end of insert (float64)" << std::endl;
#endif
}
595+
440596
void rebuild() {
441597
// Phase 1: Thread-safety - protect entire rebuild operation
442598
std::lock_guard<std::recursive_mutex> lock(*tree_mutex_);

src/cpp/bindings/python_bindings.cc

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,21 @@ PYBIND11_MODULE(PRTree, m) {
4747
.def("get_obj", &PRTree<T, B, 2>::get_obj, R"pbdoc(
4848
Get string by index
4949
)pbdoc")
50-
.def("insert", &PRTree<T, B, 2>::insert, R"pbdoc(
51-
Insert one to prtree
50+
.def("insert",
51+
py::overload_cast<const T &, const py::array_t<float> &,
52+
const std::optional<std::string>>(
53+
&PRTree<T, B, 2>::insert),
54+
py::arg("idx"), py::arg("bb"), py::arg("obj") = py::none(),
55+
R"pbdoc(
56+
Insert one to prtree (float32)
57+
)pbdoc")
58+
.def("insert",
59+
py::overload_cast<const T &, const py::array_t<double> &,
60+
const std::optional<std::string>>(
61+
&PRTree<T, B, 2>::insert),
62+
py::arg("idx"), py::arg("bb"), py::arg("obj") = py::none(),
63+
R"pbdoc(
64+
Insert one to prtree (float64 with precision)
5265
)pbdoc")
5366
.def("save", &PRTree<T, B, 2>::save, R"pbdoc(
5467
cereal save
@@ -100,8 +113,21 @@ PYBIND11_MODULE(PRTree, m) {
100113
.def("get_obj", &PRTree<T, B, 3>::get_obj, R"pbdoc(
101114
Get string by index
102115
)pbdoc")
103-
.def("insert", &PRTree<T, B, 3>::insert, R"pbdoc(
104-
Insert one to prtree
116+
.def("insert",
117+
py::overload_cast<const T &, const py::array_t<float> &,
118+
const std::optional<std::string>>(
119+
&PRTree<T, B, 3>::insert),
120+
py::arg("idx"), py::arg("bb"), py::arg("obj") = py::none(),
121+
R"pbdoc(
122+
Insert one to prtree (float32)
123+
)pbdoc")
124+
.def("insert",
125+
py::overload_cast<const T &, const py::array_t<double> &,
126+
const std::optional<std::string>>(
127+
&PRTree<T, B, 3>::insert),
128+
py::arg("idx"), py::arg("bb"), py::arg("obj") = py::none(),
129+
R"pbdoc(
130+
Insert one to prtree (float64 with precision)
105131
)pbdoc")
106132
.def("save", &PRTree<T, B, 3>::save, R"pbdoc(
107133
cereal save
@@ -153,8 +179,21 @@ PYBIND11_MODULE(PRTree, m) {
153179
.def("get_obj", &PRTree<T, B, 4>::get_obj, R"pbdoc(
154180
Get string by index
155181
)pbdoc")
156-
.def("insert", &PRTree<T, B, 4>::insert, R"pbdoc(
157-
Insert one to prtree
182+
.def("insert",
183+
py::overload_cast<const T &, const py::array_t<float> &,
184+
const std::optional<std::string>>(
185+
&PRTree<T, B, 4>::insert),
186+
py::arg("idx"), py::arg("bb"), py::arg("obj") = py::none(),
187+
R"pbdoc(
188+
Insert one to prtree (float32)
189+
)pbdoc")
190+
.def("insert",
191+
py::overload_cast<const T &, const py::array_t<double> &,
192+
const std::optional<std::string>>(
193+
&PRTree<T, B, 4>::insert),
194+
py::arg("idx"), py::arg("bb"), py::arg("obj") = py::none(),
195+
R"pbdoc(
196+
Insert one to prtree (float64 with precision)
158197
)pbdoc")
159198
.def("save", &PRTree<T, B, 4>::save, R"pbdoc(
160199
cereal save

tests/unit/test_insert.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,50 @@ def test_insert_with_invalid_box(self, PRTree, dim):
108108
with pytest.raises((ValueError, RuntimeError)):
109109
tree.insert(idx=1, bb=box)
110110

111+
@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)])
def test_insert_with_nan_coordinates_float32(self, PRTree, dim):
    """A float32 box whose first coordinate is NaN must be rejected by insert()."""
    empty_tree = PRTree()

    # All zeros except a NaN in the first slot (the first minimum).
    bad_box = np.array([np.nan] + [0.0] * (2 * dim - 1), dtype=np.float32)

    with pytest.raises((ValueError, RuntimeError)):
        empty_tree.insert(idx=1, bb=bad_box)
121+
122+
@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)])
def test_insert_with_nan_coordinates_float64(self, PRTree, dim):
    """A float64 box whose first coordinate is NaN must be rejected by insert()."""
    empty_tree = PRTree()

    # Start from an all-zero float64 box and poison the first minimum.
    bad_box = np.full(2 * dim, 0.0, dtype=np.float64)
    bad_box[:1] = np.nan

    with pytest.raises((ValueError, RuntimeError)):
        empty_tree.insert(idx=1, bb=bad_box)
132+
133+
@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)])
def test_insert_with_inf_coordinates_float32(self, PRTree, dim):
    """Verify that insert with Inf coordinates (float32) raises an error.

    Three placements of the infinity are checked. With +Inf as the first
    minimum the box also has min > max, so that case alone cannot tell an
    Inf check apart from the ordering check. The other two cases keep
    min <= max, so only Inf rejection can make them fail.
    """
    tree = PRTree()

    # (slot, value): first minimum = +Inf, first maximum = +Inf, first minimum = -Inf.
    cases = [(0, np.inf), (dim, np.inf), (0, -np.inf)]
    for slot, value in cases:
        box = np.zeros(2 * dim, dtype=np.float32)
        box[slot] = value

        # Every insert is expected to raise, so reusing idx=1 is safe.
        with pytest.raises((ValueError, RuntimeError)):
            tree.insert(idx=1, bb=box)
143+
144+
@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)])
def test_insert_with_inf_coordinates_float64(self, PRTree, dim):
    """Verify that insert with Inf coordinates (float64) raises an error.

    Three placements of the infinity are checked. With +Inf as the first
    minimum the box also has min > max, so that case alone cannot tell an
    Inf check apart from the ordering check. The other two cases keep
    min <= max, so only Inf rejection can make them fail.
    """
    tree = PRTree()

    # (slot, value): first minimum = +Inf, first maximum = +Inf, first minimum = -Inf.
    cases = [(0, np.inf), (dim, np.inf), (0, -np.inf)]
    for slot, value in cases:
        box = np.zeros(2 * dim, dtype=np.float64)
        box[slot] = value

        # Every insert is expected to raise, so reusing idx=1 is safe.
        with pytest.raises((ValueError, RuntimeError)):
            tree.insert(idx=1, bb=box)
154+
111155

112156
class TestConsistencyInsert:
113157
"""Test insert consistency."""
@@ -162,3 +206,98 @@ def test_incremental_construction(self, PRTree, dim):
162206
result2 = tree2.query(query_box)
163207

164208
assert set(result1) == set(result2)
209+
210+
211+
class TestPrecisionInsert:
    """Test insert with precision requirements."""

    @pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)])
    def test_insert_float64_maintains_precision(self, PRTree, dim):
        """Verify that float64 insert maintains double-precision refinement."""
        # Create tree with float64 construction: item 0 spans
        # [0, 75.02750896] on the first axis and [0, 100] on every other axis.
        A = np.zeros((1, 2 * dim), dtype=np.float64)
        A[0, 0] = 0.0
        A[0, dim] = 75.02750896
        for i in range(1, dim):
            A[0, i] = 0.0
            A[0, i + dim] = 100.0

        tree = PRTree(np.array([0], dtype=np.int64), A)

        # Insert with float64 (small gap): item 1 starts just past item 0's
        # maximum on the first axis.
        B = np.zeros(2 * dim, dtype=np.float64)
        B[0] = 75.02751435
        B[dim] = 100.0
        for i in range(1, dim):
            B[i] = 0.0
            B[i + dim] = 100.0

        tree.insert(idx=1, bb=B)

        # Query should not find intersection due to small gap
        result = tree.query(B)
        assert 0 not in result, "Should not find item 0 due to small gap with float64 precision"
        assert 1 in result, "Should find item 1 (self)"

    @pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)])
    def test_insert_float32_loses_precision(self, PRTree, dim):
        """Verify that float32 insert may lose precision for small gaps."""
        # Create tree with float64 construction
        A = np.zeros((1, 2 * dim), dtype=np.float64)
        A[0, 0] = 0.0
        A[0, dim] = 75.02750896
        for i in range(1, dim):
            A[0, i] = 0.0
            A[0, i + dim] = 100.0

        tree = PRTree(np.array([0], dtype=np.int64), A)

        # Insert with float32 (small gap, may cause false positive).
        # Assigning into a float32 array rounds 75.02751435 to float32.
        B = np.zeros(2 * dim, dtype=np.float32)
        B[0] = 75.02751435
        B[dim] = 100.0
        for i in range(1, dim):
            B[i] = 0.0
            B[i + dim] = 100.0

        tree.insert(idx=1, bb=B)

        # Query - item 1 won't have exact coordinates, so refinement won't apply to it.
        # Only self-discovery is asserted; whether item 0 shows up is left open.
        result = tree.query(B)
        assert 1 in result, "Should find item 1 (self)"

    @pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3)])
    def test_rebuild_preserves_idx2exact(self, PRTree, dim):
        """Verify that rebuild() preserves idx2exact for precision."""
        # Seeded generator so that any failure is reproducible.
        rng = np.random.default_rng(20240611)

        # Create tree with float64 to populate idx2exact
        n = 10
        idx = np.arange(n, dtype=np.int64)
        boxes = rng.random((n, 2 * dim)) * 100
        boxes = boxes.astype(np.float64)
        for i in range(dim):
            # Turn the second half into maxima strictly above the minima.
            boxes[:, i + dim] += boxes[:, i] + 1

        tree = PRTree(idx, boxes)

        # Insert more items, intended to trigger rebuild
        for i in range(n, n + 100):
            box = rng.random(2 * dim) * 100
            box = box.astype(np.float64)
            for d in range(dim):
                box[d + dim] += box[d] + 1
            tree.insert(idx=i, bb=box)

        # Create a small-gap query that should only work with float64 refinement
        # Query box is to the right of boxes[0] with a small gap
        query = np.zeros(2 * dim, dtype=np.float64)
        query[0] = boxes[0, dim] + 1e-6  # Small gap after original box's max
        query[dim] = boxes[0, dim] + 10.0  # Query max
        for i in range(1, dim):
            # Overlap in other dimensions
            query[i] = boxes[0, i] - 10
            query[i + dim] = boxes[0, i + dim] + 10

        result = tree.query(query)
        # Should not find item 0 if idx2exact is preserved and working
        # The gap of 1e-6 should be detected with float64 precision
        assert 0 not in result, "Should not find item 0 due to small gap (idx2exact should be preserved after rebuild)"

0 commit comments

Comments
 (0)