Skip to content

Commit 15ade9f

Browse files
committed
Complete architectural refactoring for native precision support
This commit eliminates the complex idx2exact post-processing architecture and replaces it with native float32/float64 template specialization, significantly simplifying the codebase and optimizing for each precision level. Key Changes: - Templated PRTree with Real type parameter (float or double) - Removed idx2exact map and refine_candidates() complexity entirely - Exposed 6 separate C++ classes (_PRTree{2D,3D,4D}_{float32,float64}) - Added automatic dtype-based precision selection in Python wrapper - Propagated Real template parameter through all detail classes: - BB (bounding_box.h) - DataType (data_type.h) - PRTreeNode, PRTreeLeaf, PRTreeElement (nodes.h) - PseudoPRTree, PseudoPRTreeNode (pseudo_tree.h) Benefits: - Eliminates "strange post-processing" that forced float32 on all users - Each precision level now uses native types throughout - Simpler, more maintainable codebase - Better performance through compile-time type optimization - Users get the precision they request without conversion overhead All tests passing with new architecture.
1 parent 6548882 commit 15ade9f

File tree

7 files changed

+549
-621
lines changed

7 files changed

+549
-621
lines changed

include/prtree/core/detail/bounding_box.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414

1515
#include "prtree/core/detail/types.h"
1616

17-
using Real = float;
18-
19-
template <int D = 2> class BB {
17+
template <int D = 2, typename Real = float> class BB {
2018
private:
2119
Real values[2 * D];
2220

include/prtree/core/detail/data_type.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,19 @@
1313
#include "prtree/core/detail/types.h"
1414

1515
// Phase 8: Apply C++20 concept constraints
16-
template <IndexType T, int D = 2> class DataType {
16+
template <IndexType T, int D = 2, typename Real = float> class DataType {
1717
public:
18-
BB<D> second;
18+
BB<D, Real> second;
1919
T first;
2020

2121
DataType() noexcept = default;
2222

23-
DataType(const T &f, const BB<D> &s) {
23+
DataType(const T &f, const BB<D, Real> &s) {
2424
first = f;
2525
second = s;
2626
}
2727

28-
DataType(T &&f, BB<D> &&s) noexcept {
28+
DataType(T &&f, BB<D, Real> &&s) noexcept {
2929
first = std::move(f);
3030
second = std::move(s);
3131
}
@@ -39,9 +39,9 @@ template <IndexType T, int D = 2> class DataType {
3939
template <class Archive> void serialize(Archive &ar) { ar(first, second); }
4040
};
4141

42-
template <class T, int D = 2>
43-
void clean_data(DataType<T, D> *b, DataType<T, D> *e) {
44-
for (DataType<T, D> *it = e - 1; it >= b; --it) {
45-
it->~DataType<T, D>();
42+
template <class T, int D = 2, typename Real = float>
43+
void clean_data(DataType<T, D, Real> *b, DataType<T, D, Real> *e) {
44+
for (DataType<T, D, Real> *it = e - 1; it >= b; --it) {
45+
it->~DataType<T, D, Real>();
4646
}
4747
}

include/prtree/core/detail/nodes.h

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
#include "prtree/core/detail/types.h"
1818

1919
// Phase 8: Apply C++20 concept constraints
20-
template <IndexType T, int B = 6, int D = 2> class PRTreeLeaf {
20+
template <IndexType T, int B = 6, int D = 2, typename Real = float> class PRTreeLeaf {
2121
public:
22-
BB<D> mbb;
23-
svec<DataType<T, D>, B> data;
22+
BB<D, Real> mbb;
23+
svec<DataType<T, D, Real>, B> data;
2424

25-
PRTreeLeaf() { mbb = BB<D>(); }
25+
PRTreeLeaf() { mbb = BB<D, Real>(); }
2626

27-
PRTreeLeaf(const Leaf<T, B, D> &leaf) {
27+
PRTreeLeaf(const Leaf<T, B, D, Real> &leaf) {
2828
mbb = leaf.mbb;
2929
data = leaf.data;
3030
}
@@ -38,7 +38,7 @@ template <IndexType T, int B = 6, int D = 2> class PRTreeLeaf {
3838
}
3939
}
4040

41-
void operator()(const BB<D> &target, vec<T> &out) const {
41+
void operator()(const BB<D, Real> &target, vec<T> &out) const {
4242
if (mbb(target)) {
4343
for (const auto &x : data) {
4444
if (x.second(target)) {
@@ -48,7 +48,7 @@ template <IndexType T, int B = 6, int D = 2> class PRTreeLeaf {
4848
}
4949
}
5050

51-
void del(const T &key, const BB<D> &target) {
51+
void del(const T &key, const BB<D, Real> &target) {
5252
if (mbb(target)) {
5353
auto remove_it =
5454
std::remove_if(data.begin(), data.end(), [&](auto &datum) {
@@ -58,21 +58,21 @@ template <IndexType T, int B = 6, int D = 2> class PRTreeLeaf {
5858
}
5959
}
6060

61-
void push(const T &key, const BB<D> &target) {
61+
void push(const T &key, const BB<D, Real> &target) {
6262
data.emplace_back(key, target);
6363
update_mbb();
6464
}
6565

6666
template <class Archive> void save(Archive &ar) const {
67-
vec<DataType<T, D>> _data;
67+
vec<DataType<T, D, Real>> _data;
6868
for (const auto &datum : data) {
6969
_data.push_back(datum);
7070
}
7171
ar(mbb, _data);
7272
}
7373

7474
template <class Archive> void load(Archive &ar) {
75-
vec<DataType<T, D>> _data;
75+
vec<DataType<T, D, Real>> _data;
7676
ar(mbb, _data);
7777
for (const auto &datum : _data) {
7878
data.push_back(datum);
@@ -81,63 +81,63 @@ template <IndexType T, int B = 6, int D = 2> class PRTreeLeaf {
8181
};
8282

8383
// Phase 8: Apply C++20 concept constraints
84-
template <IndexType T, int B = 6, int D = 2> class PRTreeNode {
84+
template <IndexType T, int B = 6, int D = 2, typename Real = float> class PRTreeNode {
8585
public:
86-
BB<D> mbb;
87-
std::unique_ptr<Leaf<T, B, D>> leaf;
88-
std::unique_ptr<PRTreeNode<T, B, D>> head, next;
86+
BB<D, Real> mbb;
87+
std::unique_ptr<Leaf<T, B, D, Real>> leaf;
88+
std::unique_ptr<PRTreeNode<T, B, D, Real>> head, next;
8989

9090
PRTreeNode() {}
91-
PRTreeNode(const BB<D> &_mbb) { mbb = _mbb; }
91+
PRTreeNode(const BB<D, Real> &_mbb) { mbb = _mbb; }
9292

93-
PRTreeNode(BB<D> &&_mbb) noexcept { mbb = std::move(_mbb); }
93+
PRTreeNode(BB<D, Real> &&_mbb) noexcept { mbb = std::move(_mbb); }
9494

95-
PRTreeNode(Leaf<T, B, D> *l) {
96-
leaf = std::make_unique<Leaf<T, B, D>>();
95+
PRTreeNode(Leaf<T, B, D, Real> *l) {
96+
leaf = std::make_unique<Leaf<T, B, D, Real>>();
9797
mbb = l->mbb;
9898
leaf->mbb = std::move(l->mbb);
9999
leaf->data = std::move(l->data);
100100
}
101101

102-
bool operator()(const BB<D> &target) { return mbb(target); }
102+
bool operator()(const BB<D, Real> &target) { return mbb(target); }
103103
};
104104

105105
// Phase 8: Apply C++20 concept constraints
106-
template <IndexType T, int B = 6, int D = 2> class PRTreeElement {
106+
template <IndexType T, int B = 6, int D = 2, typename Real = float> class PRTreeElement {
107107
public:
108-
BB<D> mbb;
109-
std::unique_ptr<PRTreeLeaf<T, B, D>> leaf;
108+
BB<D, Real> mbb;
109+
std::unique_ptr<PRTreeLeaf<T, B, D, Real>> leaf;
110110
bool is_used = false;
111111

112112
PRTreeElement() {
113-
mbb = BB<D>();
113+
mbb = BB<D, Real>();
114114
is_used = false;
115115
}
116116

117-
PRTreeElement(const PRTreeNode<T, B, D> &node) {
118-
mbb = BB<D>(node.mbb);
117+
PRTreeElement(const PRTreeNode<T, B, D, Real> &node) {
118+
mbb = BB<D, Real>(node.mbb);
119119
if (node.leaf) {
120-
Leaf<T, B, D> tmp_leaf = Leaf<T, B, D>(*node.leaf.get());
121-
leaf = std::make_unique<PRTreeLeaf<T, B, D>>(tmp_leaf);
120+
Leaf<T, B, D, Real> tmp_leaf = Leaf<T, B, D, Real>(*node.leaf.get());
121+
leaf = std::make_unique<PRTreeLeaf<T, B, D, Real>>(tmp_leaf);
122122
}
123123
is_used = true;
124124
}
125125

126-
bool operator()(const BB<D> &target) { return is_used && mbb(target); }
126+
bool operator()(const BB<D, Real> &target) { return is_used && mbb(target); }
127127

128128
template <class Archive> void serialize(Archive &archive) {
129129
archive(mbb, leaf, is_used);
130130
}
131131
};
132132

133133
// Phase 8: Apply C++20 concept constraints
134-
template <IndexType T, int B = 6, int D = 2>
134+
template <IndexType T, int B = 6, int D = 2, typename Real = float>
135135
void bfs(
136-
const std::function<void(std::unique_ptr<PRTreeLeaf<T, B, D>> &)> &func,
137-
vec<PRTreeElement<T, B, D>> &flat_tree, const BB<D> target) {
136+
const std::function<void(std::unique_ptr<PRTreeLeaf<T, B, D, Real>> &)> &func,
137+
vec<PRTreeElement<T, B, D, Real>> &flat_tree, const BB<D, Real> target) {
138138
queue<size_t> que;
139139
auto qpush_if_intersect = [&](const size_t &i) {
140-
PRTreeElement<T, B, D> &r = flat_tree[i];
140+
PRTreeElement<T, B, D, Real> &r = flat_tree[i];
141141
// std::cout << "i " << (long int) i << " : " << (bool) r.leaf << std::endl;
142142
if (r(target)) {
143143
// std::cout << " is pushed" << std::endl;
@@ -151,7 +151,7 @@ void bfs(
151151
size_t idx = que.front();
152152
// std::cout << "idx: " << (long int) idx << std::endl;
153153
que.pop();
154-
PRTreeElement<T, B, D> &elem = flat_tree[idx];
154+
PRTreeElement<T, B, D, Real> &elem = flat_tree[idx];
155155

156156
if (elem.leaf) {
157157
// std::cout << "func called for " << (long int) idx << std::endl;

include/prtree/core/detail/pseudo_tree.h

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,22 @@
1919
#include "prtree/core/detail/types.h"
2020

2121
// Phase 8: Apply C++20 concept constraints
22-
template <IndexType T, int B = 6, int D = 2> class Leaf {
22+
template <IndexType T, int B = 6, int D = 2, typename Real = float> class Leaf {
2323
public:
24-
BB<D> mbb;
25-
svec<DataType<T, D>, B> data; // You can swap when filtering
24+
BB<D, Real> mbb;
25+
svec<DataType<T, D, Real>, B> data; // You can swap when filtering
2626
int axis = 0;
2727

2828
// T is type of keys(ids) which will be returned when you post a query.
29-
Leaf() { mbb = BB<D>(); }
29+
Leaf() { mbb = BB<D, Real>(); }
3030
Leaf(const int _axis) {
3131
axis = _axis;
32-
mbb = BB<D>();
32+
mbb = BB<D, Real>();
3333
}
3434

3535
void set_axis(const int &_axis) { axis = _axis; }
3636

37-
void push(const T &key, const BB<D> &target) {
37+
void push(const T &key, const BB<D, Real> &target) {
3838
data.emplace_back(key, target);
3939
update_mbb();
4040
}
@@ -46,15 +46,15 @@ template <IndexType T, int B = 6, int D = 2> class Leaf {
4646
}
4747
}
4848

49-
bool filter(DataType<T, D> &value) { // false means given value is ignored
49+
bool filter(DataType<T, D, Real> &value) { // false means given value is ignored
5050
// Phase 2: C++20 requires explicit 'this' capture
5151
auto comp = [this](const auto &a, const auto &b) noexcept {
5252
return a.second.val_for_comp(axis) < b.second.val_for_comp(axis);
5353
};
5454

5555
if (data.size() < B) { // if there is room, just push the candidate
5656
auto iter = std::lower_bound(data.begin(), data.end(), value, comp);
57-
DataType<T, D> tmp_value = DataType<T, D>(value);
57+
DataType<T, D, Real> tmp_value = DataType<T, D, Real>(value);
5858
data.insert(iter, std::move(tmp_value));
5959
mbb += value.second;
6060
return true;
@@ -76,9 +76,9 @@ template <IndexType T, int B = 6, int D = 2> class Leaf {
7676
};
7777

7878
// Phase 8: Apply C++20 concept constraints
79-
template <IndexType T, int B = 6, int D = 2> class PseudoPRTreeNode {
79+
template <IndexType T, int B = 6, int D = 2, typename Real = float> class PseudoPRTreeNode {
8080
public:
81-
Leaf<T, B, D> leaves[2 * D];
81+
Leaf<T, B, D, Real> leaves[2 * D];
8282
std::unique_ptr<PseudoPRTreeNode> left, right;
8383

8484
PseudoPRTreeNode() {
@@ -98,7 +98,7 @@ template <IndexType T, int B = 6, int D = 2> class PseudoPRTreeNode {
9898
archive(left, right, leaves);
9999
}
100100

101-
void address_of_leaves(vec<Leaf<T, B, D> *> &out) {
101+
void address_of_leaves(vec<Leaf<T, B, D, Real> *> &out) {
102102
for (auto &leaf : leaves) {
103103
if (leaf.data.size() > 0) {
104104
out.emplace_back(&leaf);
@@ -120,20 +120,20 @@ template <IndexType T, int B = 6, int D = 2> class PseudoPRTreeNode {
120120
};
121121

122122
// Phase 8: Apply C++20 concept constraints
123-
template <IndexType T, int B = 6, int D = 2> class PseudoPRTree {
123+
template <IndexType T, int B = 6, int D = 2, typename Real = float> class PseudoPRTree {
124124
public:
125-
std::unique_ptr<PseudoPRTreeNode<T, B, D>> root;
126-
vec<Leaf<T, B, D> *> cache_children;
125+
std::unique_ptr<PseudoPRTreeNode<T, B, D, Real>> root;
126+
vec<Leaf<T, B, D, Real> *> cache_children;
127127
const int nthreads = std::max(1, (int)std::thread::hardware_concurrency());
128128

129-
PseudoPRTree() { root = std::make_unique<PseudoPRTreeNode<T, B, D>>(); }
129+
PseudoPRTree() { root = std::make_unique<PseudoPRTreeNode<T, B, D, Real>>(); }
130130

131131
template <class iterator> PseudoPRTree(const iterator &b, const iterator &e) {
132132
if (!root) {
133-
root = std::make_unique<PseudoPRTreeNode<T, B, D>>();
133+
root = std::make_unique<PseudoPRTreeNode<T, B, D, Real>>();
134134
}
135135
construct(root.get(), b, e, 0);
136-
clean_data<T, D>(b, e);
136+
clean_data<T, D, Real>(b, e);
137137
}
138138

139139
template <class Archive> void serialize(Archive &archive) {
@@ -142,7 +142,7 @@ template <IndexType T, int B = 6, int D = 2> class PseudoPRTree {
142142
}
143143

144144
template <class iterator>
145-
void construct(PseudoPRTreeNode<T, B, D> *node, const iterator &b,
145+
void construct(PseudoPRTreeNode<T, B, D, Real> *node, const iterator &b,
146146
const iterator &e, const int depth) {
147147
if (e - b > 0 && node != nullptr) {
148148
bool use_recursive_threads = std::pow(2, depth + 1) <= nthreads;
@@ -152,20 +152,20 @@ template <IndexType T, int B = 6, int D = 2> class PseudoPRTree {
152152

153153
vec<std::thread> threads;
154154
threads.reserve(2);
155-
PseudoPRTreeNode<T, B, D> *node_left, *node_right;
155+
PseudoPRTreeNode<T, B, D, Real> *node_left, *node_right;
156156

157157
const int axis = depth % (2 * D);
158158
auto ee = node->filter(b, e);
159159
auto m = b;
160160
std::advance(m, (ee - b) / 2);
161161
std::nth_element(b, m, ee,
162-
[axis](const DataType<T, D> &lhs,
163-
const DataType<T, D> &rhs) noexcept {
162+
[axis](const DataType<T, D, Real> &lhs,
163+
const DataType<T, D, Real> &rhs) noexcept {
164164
return lhs.second[axis] < rhs.second[axis];
165165
});
166166

167167
if (m - b > 0) {
168-
node->left = std::make_unique<PseudoPRTreeNode<T, B, D>>(axis);
168+
node->left = std::make_unique<PseudoPRTreeNode<T, B, D, Real>>(axis);
169169
node_left = node->left.get();
170170
if (use_recursive_threads) {
171171
threads.push_back(
@@ -175,7 +175,7 @@ template <IndexType T, int B = 6, int D = 2> class PseudoPRTree {
175175
}
176176
}
177177
if (ee - m > 0) {
178-
node->right = std::make_unique<PseudoPRTreeNode<T, B, D>>(axis);
178+
node->right = std::make_unique<PseudoPRTreeNode<T, B, D, Real>>(axis);
179179
node_right = node->right.get();
180180
if (use_recursive_threads) {
181181
threads.push_back(
@@ -191,7 +191,7 @@ template <IndexType T, int B = 6, int D = 2> class PseudoPRTree {
191191

192192
auto get_all_leaves(const int hint) {
193193
if (cache_children.empty()) {
194-
using U = PseudoPRTreeNode<T, B, D>;
194+
using U = PseudoPRTreeNode<T, B, D, Real>;
195195
cache_children.reserve(hint);
196196
auto node = root.get();
197197
queue<U *> que;
@@ -210,15 +210,15 @@ template <IndexType T, int B = 6, int D = 2> class PseudoPRTree {
210210
return cache_children;
211211
}
212212

213-
std::pair<DataType<T, D> *, DataType<T, D> *> as_X(void *placement,
213+
std::pair<DataType<T, D, Real> *, DataType<T, D, Real> *> as_X(void *placement,
214214
const int hint) {
215-
DataType<T, D> *b, *e;
215+
DataType<T, D, Real> *b, *e;
216216
auto children = get_all_leaves(hint);
217217
T total = children.size();
218-
b = reinterpret_cast<DataType<T, D> *>(placement);
218+
b = reinterpret_cast<DataType<T, D, Real> *>(placement);
219219
e = b + total;
220220
for (T i = 0; i < total; i++) {
221-
new (b + i) DataType<T, D>{i, children[i]->mbb};
221+
new (b + i) DataType<T, D, Real>{i, children[i]->mbb};
222222
}
223223
return {b, e};
224224
}

0 commit comments

Comments
 (0)