Skip to content

Commit 27d51f3

Browse files
authored
Merge pull request #220 from lucasimi/add-profiling
Added profiling. Reduced bottlenecks of quickselect using numba
2 parents 93d3a98 + 219a070 commit 27d51f3

28 files changed

+808
-614
lines changed

app/streamlit_app.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,11 @@ def mode(arr):
211211

212212

213213
def quantile(q):
214-
return lambda agg: np.nanquantile(agg, q=q)
214+
215+
def _quantile_q(agg):
216+
return np.nanquantile(agg, q=q)
217+
218+
return _quantile_q
215219

216220

217221
@st.cache_data
@@ -565,12 +569,12 @@ def plot_agg_input_section():
565569
return agg, agg_name
566570

567571

572+
def _hash_networkx_graph(graph):
573+
return _encode_graph(_get_graph_no_attribs(graph))
574+
575+
568576
@st.cache_data(
569-
hash_funcs={
570-
"networkx.classes.graph.Graph": lambda g: _encode_graph(
571-
_get_graph_no_attribs(g)
572-
)
573-
},
577+
hash_funcs={"networkx.classes.graph.Graph": _hash_networkx_graph},
574578
show_spinner="Generating Mapper Layout",
575579
)
576580
def compute_mapper_plot(mapper_graph, dim, seed, iterations):
@@ -610,8 +614,12 @@ def mapper_plot_section(mapper_graph):
610614
return mapper_plot
611615

612616

617+
def _hash_mapper_plot(mapper_plot):
618+
return mapper_plot.positions
619+
620+
613621
@st.cache_data(
614-
hash_funcs={"tdamapper.plot.MapperPlot": lambda mp: mp.positions},
622+
hash_funcs={"tdamapper.plot.MapperPlot": _hash_mapper_plot},
615623
show_spinner="Rendering Mapper",
616624
)
617625
def compute_mapper_fig(mapper_plot, colors, node_size, cmap, _agg, agg_name):

benchmarks/benchmark.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
from tdamapper.core import TrivialClustering
1313

1414

15+
def _identity(x):
16+
return x
17+
18+
1519
def _segment(cardinality, dimension, noise=0.1, start=None, end=None):
1620
if start is None:
1721
start = np.zeros(dimension)
@@ -70,7 +74,7 @@ def fit(self, X, y=None):
7074
def run_gm(X, n, p):
7175
t0 = time.time()
7276
pipe = gm.make_mapper_pipeline(
73-
filter_func=lambda x: x,
77+
filter_func=_identity,
7478
cover=gm.CubicalCover(n_intervals=n, overlap_frac=p),
7579
clusterer=TrivialEstimator(),
7680
)

src/tdamapper/_common.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
This module provides common functionalities for internal use.
33
"""
44

5+
import cProfile
6+
import io
7+
import pstats
58
import warnings
69

710
import numpy as np
@@ -147,3 +150,22 @@ def clone(obj):
147150
obj_noargs = type(obj)()
148151
obj_noargs.set_params(**params)
149152
return obj_noargs
153+
154+
155+
def profile(n_lines=10):
156+
def decorator(func):
157+
def wrapper(*args, **kwargs):
158+
profiler = cProfile.Profile()
159+
profiler.enable()
160+
result = func(*args, **kwargs)
161+
profiler.disable()
162+
163+
s = io.StringIO()
164+
ps = pstats.Stats(profiler, stream=s).sort_stats("cumulative")
165+
ps.print_stats(n_lines)
166+
print(s.getvalue())
167+
return result
168+
169+
return wrapper
170+
171+
return decorator

src/tdamapper/utils/_metrics.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,27 @@
22
from numba import njit
33

44

5-
@njit(fastmath=True)
5+
@njit(fastmath=True) # pragma: no cover
66
def euclidean(x, y):
77
return np.linalg.norm(x - y)
88

99

10-
@njit(fastmath=True)
10+
@njit(fastmath=True) # pragma: no cover
1111
def manhattan(x, y):
1212
return np.linalg.norm(x - y, ord=1)
1313

1414

15-
@njit(fastmath=True)
15+
@njit(fastmath=True) # pragma: no cover
1616
def chebyshev(x, y):
1717
return np.linalg.norm(x - y, ord=np.inf)
1818

1919

20-
@njit(fastmath=True)
20+
@njit(fastmath=True) # pragma: no cover
2121
def minkowski(p, x, y):
2222
return np.linalg.norm(x - y, ord=p)
2323

2424

25-
@njit(fastmath=True)
25+
@njit(fastmath=True) # pragma: no cover
2626
def cosine(x, y):
2727
xy = np.dot(x, y)
2828
xx = np.linalg.norm(x)

src/tdamapper/utils/metrics.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,11 @@ def minkowski(p):
114114
return euclidean()
115115
elif np.isinf(p):
116116
return chebyshev()
117-
return lambda x, y: _metrics.minkowski(p, x, y)
117+
118+
def dist(x, y):
119+
return _metrics.minkowski(p, x, y)
120+
121+
return dist
118122

119123

120124
def cosine():

src/tdamapper/utils/quickselect.py

Lines changed: 82 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,103 @@
1-
def __swap(arr, i, j):
1+
import numpy as np
2+
from numba import njit
3+
4+
_ARR = np.zeros(1)
5+
6+
7+
@njit # pragma: no cover
8+
def swap(arr, i, j):
29
arr[i], arr[j] = arr[j], arr[i]
310

411

5-
def partition(data, start, end, p_ord):
12+
@njit # pragma: no cover
13+
def _swap_all(arr, i, j, extra1, use_extra1, extra2, use_extra2):
14+
swap(arr, i, j)
15+
if use_extra1:
16+
swap(extra1, i, j)
17+
if use_extra2:
18+
swap(extra2, i, j)
19+
20+
21+
@njit # pragma: no cover
22+
def _partition(data, start, end, p_ord, extra1, use_extra1, extra2, use_extra2):
623
higher = start
724
for j in range(start, end):
8-
j_ord, _ = data[j]
25+
j_ord = data[j]
926
if j_ord < p_ord:
10-
__swap(data, higher, j)
27+
_swap_all(data, higher, j, extra1, use_extra1, extra2, use_extra2)
1128
higher += 1
1229
return higher
1330

1431

15-
def quickselect(data, start, end, k):
32+
@njit # pragma: no cover
33+
def _quickselect(data, start, end, k, extra1, use_extra1, extra2, use_extra2):
1634
if (k < start) or (k >= end):
1735
return
18-
start_, end_, higher = start, end, None
36+
start_, end_, higher = start, end, -1
1937
while higher != k + 1:
20-
p, _ = data[k]
21-
__swap(data, start_, k)
22-
higher = partition(data, start_ + 1, end_, p)
23-
__swap(data, start_, higher - 1)
38+
p = data[k]
39+
_swap_all(data, start_, k, extra1, use_extra1, extra2, use_extra2)
40+
higher = _partition(
41+
data, start_ + 1, end_, p, extra1, use_extra1, extra2, use_extra2
42+
)
43+
_swap_all(data, start_, higher - 1, extra1, use_extra1, extra2, use_extra2)
2444
if k <= higher - 1:
2545
end_ = higher
2646
else:
2747
start_ = higher
2848

2949

30-
def partition_tuple(data_ord, data_arr, start, end, p_ord):
31-
higher = start
32-
for j in range(start, end):
33-
j_ord = data_ord[j]
34-
if j_ord < p_ord:
35-
__swap(data_arr, higher, j)
36-
__swap(data_ord, higher, j)
37-
higher += 1
38-
return higher
50+
def _to_array(extra1=None, extra2=None):
51+
extra1_arr = _ARR if extra1 is None else extra1
52+
extra2_arr = _ARR if extra2 is None else extra2
53+
return extra1_arr, extra2_arr
3954

4055

41-
def quickselect_tuple(data_ord, data_arr, start, end, k):
42-
if (k < start) or (k >= end):
43-
return
44-
start_, end_, higher = start, end, None
45-
while higher != k + 1:
46-
p_ord = data_ord[k]
47-
__swap(data_arr, start_, k)
48-
__swap(data_ord, start_, k)
49-
higher = partition_tuple(data_ord, data_arr, start_ + 1, end_, p_ord)
50-
__swap(data_arr, start_, higher - 1)
51-
__swap(data_ord, start_, higher - 1)
52-
if k <= higher - 1:
53-
end_ = higher
54-
else:
55-
start_ = higher
56+
def _use_array(extra1=None, extra2=None):
57+
use_extra1 = extra1 is not None
58+
use_extra2 = extra2 is not None
59+
return use_extra1, use_extra2
60+
61+
62+
def swap_all(arr, i, j, extra1=None, extra2=None):
63+
extra1_arr, extra2_arr = _to_array(extra1, extra2)
64+
use_extra1, use_extra2 = _use_array(extra1, extra2)
65+
_swap_all(
66+
arr,
67+
i,
68+
j,
69+
extra1=extra1_arr,
70+
use_extra1=use_extra1,
71+
extra2=extra2_arr,
72+
use_extra2=use_extra2,
73+
)
74+
75+
76+
def partition(data, start, end, p_ord, extra1=None, extra2=None):
77+
extra1_arr, extra2_arr = _to_array(extra1, extra2)
78+
use_extra1, use_extra2 = _use_array(extra1, extra2)
79+
return _partition(
80+
data,
81+
start,
82+
end,
83+
p_ord,
84+
extra1=extra1_arr,
85+
use_extra1=use_extra1,
86+
extra2=extra2_arr,
87+
use_extra2=use_extra2,
88+
)
89+
90+
91+
def quickselect(data, start, end, k, extra1=None, extra2=None):
92+
extra1_arr, extra2_arr = _to_array(extra1, extra2)
93+
use_extra1, use_extra2 = _use_array(extra1, extra2)
94+
_quickselect(
95+
data,
96+
start,
97+
end,
98+
k,
99+
extra1=extra1_arr,
100+
use_extra1=use_extra1,
101+
extra2=extra2_arr,
102+
use_extra2=use_extra2,
103+
)

src/tdamapper/utils/vptree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
A module for fast knn and range searches, depending only on a given metric
33
"""
44

5-
from tdamapper.utils.vptree_flat import VPTree as FVPT
6-
from tdamapper.utils.vptree_hier import VPTree as HVPT
5+
from tdamapper.utils.vptree_flat.vptree import VPTree as FVPT
6+
from tdamapper.utils.vptree_hier.vptree import VPTree as HVPT
77

88

99
class VPTree:

0 commit comments

Comments
 (0)