Skip to content

Commit e722b5f

Browse files
authored
Merge pull request #226 from lucasimi/improve-plot-plotly
Improve plot plotly
2 parents 1805f01 + 7a5d85b commit e722b5f

File tree

14 files changed

+334
-200
lines changed

14 files changed

+334
-200
lines changed

docs/source/notebooks/circles.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
fig = plot.plot_plotly(
9797
colors=labels,
9898
cmap=["jet", "viridis", "cividis"],
99+
node_size=[0.0, 0.5, 1.0, 1.5, 2.0],
99100
agg=np.nanmean,
100101
width=600,
101102
height=600,
@@ -118,6 +119,7 @@
118119
fig = plot.plot_plotly(
119120
colors=labels,
120121
cmap=["jet", "viridis", "cividis"],
122+
node_size=[0.0, 0.5, 1.0, 1.5, 2.0],
121123
agg=np.nanstd,
122124
width=600,
123125
height=600,

docs/source/notebooks/digits.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@ def mode(arr):
103103
colors=labels,
104104
cmap=["jet", "viridis", "cividis"],
105105
agg=mode,
106+
node_size=[0.0, 0.5, 1.0, 1.5, 2.0],
106107
title="mode of digits",
107108
width=600,
108109
height=600,
109-
node_size=0.5,
110110
)
111111

112112
fig.show(config={"scrollZoom": True}, renderer="notebook_connected")
@@ -134,10 +134,10 @@ def entropy(arr):
134134
colors=labels,
135135
cmap=["jet", "viridis", "cividis"],
136136
agg=entropy,
137+
node_size=[0.0, 0.5, 1.0, 1.5, 2.0],
137138
title="entropy of digits",
138139
width=600,
139140
height=600,
140-
node_size=0.5,
141141
)
142142

143143
fig.show(config={"scrollZoom": True}, renderer="notebook_connected")

src/tdamapper/_common.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ def warn_user(msg):
2929

3030
class EstimatorMixin:
3131

32-
def __is_sparse(self, X):
32+
def _is_sparse(self, X):
3333
# simple alternative use scipy.sparse.issparse
3434
return hasattr(X, "toarray")
3535

3636
def _validate_X_y(self, X, y):
37-
if self.__is_sparse(X):
37+
if self._is_sparse(X):
3838
raise ValueError("Sparse data not supported.")
3939

4040
X = np.asarray(X)
@@ -80,10 +80,10 @@ class ParamsMixin:
8080
scikit-learn `get_params` and `set_params`.
8181
"""
8282

83-
def __is_param_public(self, k):
83+
def _is_param_public(self, k):
8484
return (not k.startswith("_")) and (not k.endswith("_"))
8585

86-
def __split_param(self, k):
86+
def _split_param(self, k):
8787
k_split = k.split("__")
8888
outer = k_split[0]
8989
inner = "__".join(k_split[1:])
@@ -98,7 +98,7 @@ def get_params(self, deep=True):
9898
"""
9999
params = {}
100100
for k, v in self.__dict__.items():
101-
if self.__is_param_public(k):
101+
if self._is_param_public(k):
102102
params[k] = v
103103
if hasattr(v, "get_params") and deep:
104104
for _k, _v in v.get_params().items():
@@ -111,8 +111,8 @@ def set_params(self, **params):
111111
"""
112112
nested_params = []
113113
for k, v in params.items():
114-
if self.__is_param_public(k):
115-
k_outer, k_inner = self.__split_param(k)
114+
if self._is_param_public(k):
115+
k_outer, k_inner = self._split_param(k)
116116
if not k_inner:
117117
if hasattr(self, k_outer):
118118
setattr(self, k_outer, v)
@@ -131,7 +131,7 @@ def __repr__(self):
131131
v_default = getattr(obj_noargs, k)
132132
v_default_repr = repr(v_default)
133133
v_repr = repr(v)
134-
if self.__is_param_public(k) and not v_repr == v_default_repr:
134+
if self._is_param_public(k) and not v_repr == v_default_repr:
135135
args_repr.append(f"{k}={v_repr}")
136136
return f"{self.__class__.__name__}({', '.join(args_repr)})"
137137

src/tdamapper/core.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def fit(self, X):
297297
:return: The object itself.
298298
:rtype: self
299299
"""
300-
self.__X = X
300+
self._X = X
301301
return self
302302

303303
def search(self, x):
@@ -314,7 +314,7 @@ def search(self, x):
314314
dataset.
315315
:rtype: list[int]
316316
"""
317-
return list(range(0, len(self.__X)))
317+
return list(range(0, len(self._X)))
318318

319319
def apply(self, X):
320320
"""
@@ -385,27 +385,27 @@ def __init__(
385385

386386
def fit(self, X, y=None):
387387
X, y = self._validate_X_y(X, y)
388-
self.__cover = TrivialCover() if self.cover is None else self.cover
389-
self.__clustering = (
388+
self._cover = TrivialCover() if self.cover is None else self.cover
389+
self._clustering = (
390390
TrivialClustering() if self.clustering is None else self.clustering
391391
)
392-
self.__verbose = self.verbose
393-
self.__failsafe = self.failsafe
394-
if self.__failsafe:
395-
self.__clustering = FailSafeClustering(
396-
clustering=self.__clustering,
397-
verbose=self.__verbose,
392+
self._verbose = self.verbose
393+
self._failsafe = self.failsafe
394+
if self._failsafe:
395+
self._clustering = FailSafeClustering(
396+
clustering=self._clustering,
397+
verbose=self._verbose,
398398
)
399-
self.__cover = clone(self.__cover)
400-
self.__clustering = clone(self.__clustering)
401-
self.__n_jobs = self.n_jobs
399+
self._cover = clone(self._cover)
400+
self._clustering = clone(self._clustering)
401+
self._n_jobs = self.n_jobs
402402
y = X if y is None else y
403403
self.graph_ = mapper_graph(
404404
X,
405405
y,
406-
self.__cover,
407-
self.__clustering,
408-
n_jobs=self.__n_jobs,
406+
self._cover,
407+
self._clustering,
408+
n_jobs=self._n_jobs,
409409
)
410410
self._set_n_features_in(X)
411411
return self
@@ -451,16 +451,16 @@ def __init__(self, clustering=None, verbose=True):
451451
self.verbose = verbose
452452

453453
def fit(self, X, y=None):
454-
self.__clustering = (
454+
self._clustering = (
455455
TrivialClustering() if self.clustering is None else self.clustering
456456
)
457-
self.__verbose = self.verbose
457+
self._verbose = self.verbose
458458
self.labels_ = None
459459
try:
460-
self.__clustering.fit(X, y)
461-
self.labels_ = self.__clustering.labels_
460+
self._clustering.fit(X, y)
461+
self.labels_ = self._clustering.labels_
462462
except ValueError as err:
463-
if self.__verbose:
463+
if self._verbose:
464464
_logger.warning("Unable to perform clustering on local chart: %s", err)
465465
self.labels_ = [0 for _ in X]
466466
return self

src/tdamapper/cover.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,10 @@ def fit(self, X):
9797
:rtype: self
9898
"""
9999
metric = get_metric(self.metric, **(self.metric_params or {}))
100-
self.__radius = self.radius
101-
self.__data = list(enumerate(X))
102-
self.__vptree = VPTree(
103-
self.__data,
100+
self._radius = self.radius
101+
self._data = list(enumerate(X))
102+
self._vptree = VPTree(
103+
self._data,
104104
metric=_Pullback(_snd, metric),
105105
metric_params=None,
106106
kind=self.kind,
@@ -121,11 +121,11 @@ def search(self, x):
121121
:return: The indices of the neighbors contained in the dataset.
122122
:rtype: list[int]
123123
"""
124-
if self.__vptree is None:
124+
if self._vptree is None:
125125
return []
126-
neighs = self.__vptree.ball_search(
126+
neighs = self._vptree.ball_search(
127127
(-1, x),
128-
self.__radius,
128+
self._radius,
129129
inclusive=False,
130130
)
131131
return [x for (x, _) in neighs]
@@ -198,10 +198,10 @@ def fit(self, X):
198198
:rtype: self
199199
"""
200200
metric = get_metric(self.metric, **(self.metric_params or {}))
201-
self.__neighbors = self.neighbors
202-
self.__data = list(enumerate(X))
203-
self.__vptree = VPTree(
204-
self.__data,
201+
self._neighbors = self.neighbors
202+
self._data = list(enumerate(X))
203+
self._vptree = VPTree(
204+
self._data,
205205
metric=_Pullback(_snd, metric),
206206
metric_params=None,
207207
kind=self.kind,
@@ -223,9 +223,9 @@ def search(self, x):
223223
:return: The indices of the neighbors contained in the dataset.
224224
:rtype: list[int]
225225
"""
226-
if self.__vptree is None:
226+
if self._vptree is None:
227227
return []
228-
neighs = self.__vptree.knn_search((-1, x), self.__neighbors)
228+
neighs = self._vptree.knn_search((-1, x), self._neighbors)
229229
return [x for (x, _) in neighs]
230230

231231

src/tdamapper/plot.py

Lines changed: 47 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import numpy as np
88

99
from tdamapper._common import deprecated
10-
from tdamapper._plot_matplotlib import plot_matplotlib
11-
from tdamapper._plot_plotly import plot_plotly, plot_plotly_update
12-
from tdamapper._plot_pyvis import plot_pyvis
10+
from tdamapper.plot_backends.plot_matplotlib import plot_matplotlib
11+
from tdamapper.plot_backends.plot_plotly import plot_plotly, plot_plotly_update
12+
from tdamapper.plot_backends.plot_pyvis import plot_pyvis
1313

1414

1515
class MapperPlot:
@@ -206,10 +206,6 @@ def plot_plotly(
206206
cmap=cmap,
207207
)
208208

209-
@deprecated(
210-
"This method is deprecated and will be removed in a future release. "
211-
"Use a new instance of tdamapper.plot.MapperPlot."
212-
)
213209
def plot_plotly_update(
214210
self,
215211
fig,
@@ -382,29 +378,29 @@ def __init__(
382378
height=512,
383379
cmap="jet",
384380
):
385-
self.__graph = graph
386-
self.__dim = dim
387-
self.__iterations = iterations
388-
self.__seed = seed
389-
self.__mapper_plot = MapperPlot(
390-
graph=self.__graph,
391-
dim=self.__dim,
392-
iterations=self.__iterations,
393-
seed=self.__seed,
381+
self._graph = graph
382+
self._dim = dim
383+
self._iterations = iterations
384+
self._seed = seed
385+
self._mapper_plot = MapperPlot(
386+
graph=self._graph,
387+
dim=self._dim,
388+
iterations=self._iterations,
389+
seed=self._seed,
394390
)
395-
self.__colors = colors
396-
self.__agg = agg
397-
self.__title = title
398-
self.__width = width
399-
self.__height = height
400-
self.__cmap = cmap
401-
self.__fig = self.__mapper_plot.plot_plotly(
402-
colors=self.__colors,
403-
agg=self.__agg,
404-
title=self.__title,
405-
width=self.__width,
406-
height=self.__height,
407-
cmap=self.__cmap,
391+
self._colors = colors
392+
self._agg = agg
393+
self._title = title
394+
self._width = width
395+
self._height = height
396+
self._cmap = cmap
397+
self._fig = self._mapper_plot.plot_plotly(
398+
colors=self._colors,
399+
agg=self._agg,
400+
title=self._title,
401+
width=self._width,
402+
height=self._height,
403+
cmap=self._cmap,
408404
)
409405

410406
def update(
@@ -451,20 +447,20 @@ def update(
451447
"""
452448
_update_pos = False
453449
if seed is not None:
454-
self.__seed = seed
450+
self._seed = seed
455451
_update_pos = True
456452
if iterations is not None:
457-
self.__iterations = iterations
453+
self._iterations = iterations
458454
_update_pos = True
459455
if _update_pos:
460-
self.__mapper_plot = MapperPlot(
461-
graph=self.__graph,
462-
dim=self.__dim,
463-
iterations=self.__iterations,
464-
seed=self.__seed,
456+
self._mapper_plot = MapperPlot(
457+
graph=self._graph,
458+
dim=self._dim,
459+
iterations=self._iterations,
460+
seed=self._seed,
465461
)
466-
self.__mapper_plot.plot_plotly_update(
467-
self.__fig,
462+
self._mapper_plot.plot_plotly_update(
463+
self._fig,
468464
colors=colors,
469465
agg=agg,
470466
title=title,
@@ -482,7 +478,7 @@ def plot(self):
482478
context to be shown.
483479
:rtype: :class:`plotly.graph_objects.Figure`
484480
"""
485-
return self.__fig
481+
return self._fig
486482

487483

488484
class MapperLayoutStatic:
@@ -543,12 +539,12 @@ def __init__(
543539
height=512,
544540
cmap="jet",
545541
):
546-
self.__colors = colors
547-
self.__agg = agg
548-
self.__title = title
549-
self.__width = width
550-
self.__height = height
551-
self.__cmap = cmap
542+
self._colors = colors
543+
self._agg = agg
544+
self._title = title
545+
self._width = width
546+
self._height = height
547+
self._cmap = cmap
552548
self.mapper_plot = MapperPlot(
553549
graph=graph,
554550
dim=dim,
@@ -566,10 +562,10 @@ def plot(self):
566562
:class:`matplotlib.axes.Axes`
567563
"""
568564
return self.mapper_plot.plot_matplotlib(
569-
colors=self.__colors,
570-
agg=self.__agg,
571-
title=self.__title,
572-
width=self.__width,
573-
height=self.__height,
574-
cmap=self.__cmap,
565+
colors=self._colors,
566+
agg=self._agg,
567+
title=self._title,
568+
width=self._width,
569+
height=self._height,
570+
cmap=self._cmap,
575571
)

src/tdamapper/plot_backends/__init__.py

Whitespace-only changes.
File renamed without changes.

0 commit comments

Comments
 (0)