Skip to content

Commit f7639d3

Browse files
committed
Updated buttons
1 parent 197a732 commit f7639d3

File tree

2 files changed

+99
-37
lines changed

2 files changed

+99
-37
lines changed

src/tdamapper/_app.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from dataclasses import asdict, dataclass
44

55
import pandas as pd
6-
import plotly.graph_objects as go
76
from nicegui import app, run, ui
87
from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans
98
from sklearn.datasets import load_digits, load_iris
@@ -62,6 +61,7 @@
6261
CLUSTERING_DBSCAN_MIN_SAMPLES = 5
6362
CLUSTERING_AGGLOMERATIVE_N_CLUSTERS = 2
6463

64+
PLOT_DIMENSIONS = 2
6565
PLOT_ITERATIONS = 100
6666
PLOT_COLORMAP = "Viridis"
6767
PLOT_NODE_SIZE = 1.0
@@ -86,6 +86,7 @@ class MapperConfig:
8686
clustering_dbscan_eps: float = CLUSTERING_DBSCAN_EPS
8787
clustering_dbscan_min_samples: int = CLUSTERING_DBSCAN_MIN_SAMPLES
8888
clustering_agglomerative_n_clusters: int = CLUSTERING_AGGLOMERATIVE_N_CLUSTERS
89+
plot_dimensions: int = PLOT_DIMENSIONS
8990
plot_iterations: int = PLOT_ITERATIONS
9091
plot_seed: int = RANDOM_SEED
9192

@@ -202,11 +203,12 @@ def run_mapper(df, **kwargs):
202203
def create_mapper_figure(df_X, df_y, df_target, mapper_graph, **kwargs):
203204
df_colors = pd.concat([df_target, df_y, df_X], axis=1)
204205
mapper_config = MapperConfig(**kwargs)
206+
plot_dimensions = mapper_config.plot_dimensions
205207
plot_iterations = mapper_config.plot_iterations
206208
plot_seed = mapper_config.plot_seed
207209
mapper_fig = MapperPlot(
208210
mapper_graph,
209-
dim=3,
211+
dim=plot_dimensions,
210212
iterations=plot_iterations,
211213
seed=plot_seed,
212214
).plot_plotly(
@@ -217,20 +219,18 @@ def create_mapper_figure(df_X, df_y, df_target, mapper_graph, **kwargs):
217219
"Cividis",
218220
"Jet",
219221
"Plasma",
220-
"Inferno",
221-
"Magma",
222-
"Turbo",
223222
"RdBu",
224-
"BrBG",
225-
"PiYG",
226-
"PuOr",
227223
],
228224
height=800,
229225
node_size=[i * 0.125 * PLOT_NODE_SIZE for i in range(17)],
230226
)
231-
mapper_fig.layout.width = None
232-
mapper_fig.layout.height = None
233-
mapper_fig.layout.autosize = True
227+
mapper_fig.update_layout(
228+
width=None,
229+
height=None,
230+
autosize=True,
231+
xaxis=dict(scaleanchor="y"),
232+
uirevision="constant",
233+
)
234234
logger.info("Mapper run completed successfully.")
235235
return mapper_fig
236236

@@ -513,6 +513,11 @@ def _init_draw(self):
513513
self._init_draw_settings()
514514

515515
def _init_draw_settings(self):
516+
self.plot_dimensions = ui.select(
517+
options=[2, 3],
518+
label="Dimensions",
519+
value=PLOT_DIMENSIONS,
520+
).classes("w-full")
516521
self.plot_iterations = ui.number(
517522
label="Iterations",
518523
value=PLOT_ITERATIONS,
@@ -531,8 +536,7 @@ def _init_footnotes(self):
531536

532537
def _init_draw_area(self):
533538
self.plot_container = ui.element("div").classes("w-full h-full")
534-
with self.plot_container:
535-
self.draw_area = None
539+
self.draw_area = None
536540

537541
def get_mapper_config(self):
538542
return MapperConfig(
@@ -595,6 +599,11 @@ def get_mapper_config(self):
595599
if self.clustering_agglomerative_n_clusters.value
596600
else CLUSTERING_AGGLOMERATIVE_N_CLUSTERS
597601
),
602+
plot_dimensions=(
603+
int(self.plot_dimensions.value)
604+
if self.plot_dimensions.value
605+
else PLOT_DIMENSIONS
606+
),
598607
plot_iterations=(
599608
int(self.plot_iterations.value)
600609
if self.plot_iterations.value
@@ -697,11 +706,11 @@ async def async_draw_mapper(self):
697706
**asdict(mapper_config),
698707
)
699708

709+
logger.info("Displaying Mapper plot.")
700710
if self.draw_area is not None:
701711
self.draw_area.clear()
702-
self.plot_container.clear()
712+
self.plot_container.clear()
703713
with self.plot_container:
704-
logger.info("Displaying Mapper plot.")
705714
self.draw_area = ui.plotly(mapper_fig).classes("w-full h-full")
706715

707716
notification.message = "Done!"

src/tdamapper/plot_backends/plot_plotly.py

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333

3434
_DEFAULT_SPACING = 0.25
3535

36+
FONT_COLOR_LIGHT = "#2a3f5f"
37+
38+
FONT_COLOR_DARK = "#f2f5fa"
39+
3640
DEFAULT_NODE_SIZE = 1
3741

3842
DEFAULT_CMAP = "Jet"
@@ -209,6 +213,7 @@ def __init__(self, mapper_plot, fig: Optional[go.Figure] = None):
209213
self.ui_menu_cmap: Dict = {}
210214
self.ui_menu_color: Dict = {}
211215
self.ui_slider_size: Dict = {}
216+
self.ui_menu_dark_mode: Dict = {}
212217

213218
def plot(
214219
self,
@@ -468,13 +473,15 @@ def _edges_trace(self, edge_pos_arr) -> Union[go.Scatter, go.Scatter3d]:
468473

469474
def _colorbar(self, title: str) -> dict:
470475
cbar = dict(
476+
orientation="v",
471477
showticklabels=True,
472478
outlinewidth=1,
473479
borderwidth=0,
474480
thicknessmode="fraction",
475-
xanchor="left",
476481
title_side="right",
477482
title_text=title,
483+
xanchor="right",
484+
x=1.0,
478485
ypad=0,
479486
xpad=0,
480487
tickwidth=1,
@@ -495,7 +502,6 @@ def _lbl(n):
495502
return [_lbl(n) for n in self.graph.nodes()]
496503

497504
def _layout(self) -> go.Layout:
498-
line_col = "rgba(230, 230, 230, 1.0)"
499505
axis = dict(
500506
showline=False,
501507
linewidth=1,
@@ -513,16 +519,12 @@ def _layout(self) -> go.Layout:
513519
backgroundcolor="rgba(0, 0, 0, 0)",
514520
showaxeslabels=False,
515521
showline=False,
516-
linecolor=line_col,
517-
zerolinecolor=line_col,
518-
gridcolor=line_col,
519522
linewidth=1,
520523
mirror=True,
521524
showticklabels=False,
522525
title="",
523526
)
524527
return go.Layout(
525-
template="plotly_white",
526528
autosize=True,
527529
height=None,
528530
width=None,
@@ -536,6 +538,9 @@ def _layout(self) -> go.Layout:
536538
yaxis=scene_axis,
537539
zaxis=scene_axis,
538540
),
541+
font_color=FONT_COLOR_LIGHT,
542+
paper_bgcolor="white",
543+
plot_bgcolor="white",
539544
)
540545

541546
def set_ui(
@@ -558,19 +563,23 @@ def set_ui(
558563
if node_sizes is not None:
559564
self.ui_slider_size = self._ui_slider_node_size(node_sizes)
560565

566+
self.ui_menu_dark_mode = self.ui_dark_mode()
567+
561568
menus = []
562569
sliders = []
563-
x = 0.0
570+
564571
if self.ui_menu_cmap:
565-
self.ui_menu_cmap["x"] = x
566-
x += _DEFAULT_SPACING
567572
menus.append(self.ui_menu_cmap)
568573
if self.ui_menu_color:
569-
self.ui_menu_color["x"] = x
570574
menus.append(self.ui_menu_color)
575+
if self.ui_menu_dark_mode:
576+
menus.append(self.ui_menu_dark_mode)
571577
if self.ui_slider_size:
572-
self.ui_slider_size["x"] = 0.0
573578
sliders.append(self.ui_slider_size)
579+
580+
# self.fig.layout.updatemenus = [self.ui_menu_dark_mode]
581+
# self.fig.layout.sliders = []
582+
574583
self.fig.update_layout(
575584
updatemenus=menus,
576585
sliders=sliders,
@@ -607,14 +616,14 @@ def _update_cmap(cmap: str) -> dict:
607616

608617
return dict(
609618
buttons=buttons,
610-
x=0.25,
611-
xanchor="left",
619+
direction="down",
620+
x=0.5,
612621
y=1.0,
622+
xanchor="center",
613623
yanchor="top",
614-
direction="down",
615624
)
616625

617-
def _ui_slider_node_size(self, node_sizes: List[float]) -> dict:
626+
def _ui_slider_node_size(self, node_sizes: List[float]) -> Dict:
618627
steps = [
619628
dict(
620629
method="restyle",
@@ -629,16 +638,25 @@ def _ui_slider_node_size(self, node_sizes: List[float]) -> dict:
629638

630639
return dict(
631640
active=len(steps) // 2,
632-
currentvalue={"prefix": "Node size: "},
641+
currentvalue=dict(
642+
prefix="Node size: ",
643+
visible=False,
644+
xanchor="center",
645+
),
633646
steps=steps,
634-
x=0.0,
647+
x=0.5,
635648
y=0.0,
636-
xanchor="left",
637-
len=0.3,
649+
xanchor="center",
638650
yanchor="bottom",
651+
len=0.5,
652+
lenmode="fraction",
653+
ticklen=1,
654+
pad=dict(t=1, b=1, l=1, r=1),
655+
bgcolor="rgba(1.0, 1.0, 1.0, 0.5)",
656+
activebgcolor="rgba(127, 127, 127, 0.5)",
639657
)
640658

641-
def _ui_menu_color(self, colors: np.ndarray, titles: List[str], agg) -> dict:
659+
def _ui_menu_color(self, colors: np.ndarray, titles: List[str], agg) -> Dict:
642660
colors_arr = np.array(colors)
643661
colors_num = colors_arr.shape[1] if colors_arr.ndim == 2 else 1
644662

@@ -699,10 +717,45 @@ def _update_colors(i: int) -> dict:
699717

700718
return dict(
701719
buttons=buttons,
720+
direction="down",
702721
active=0,
703-
x=0.0,
704-
xanchor="left",
722+
x=0.75,
705723
y=1.0,
724+
xanchor="center",
706725
yanchor="top",
726+
)
727+
728+
def ui_dark_mode(self) -> Dict:
729+
buttons = [
730+
dict(
731+
label="Light",
732+
method="relayout",
733+
args=[
734+
{
735+
"font.color": FONT_COLOR_LIGHT,
736+
"paper_bgcolor": "white",
737+
"plot_bgcolor": "white",
738+
}
739+
],
740+
),
741+
dict(
742+
label="Dark",
743+
method="relayout",
744+
args=[
745+
{
746+
"font.color": FONT_COLOR_DARK,
747+
"paper_bgcolor": "black",
748+
"plot_bgcolor": "black",
749+
}
750+
],
751+
),
752+
]
753+
return dict(
754+
buttons=buttons,
707755
direction="down",
756+
active=0,
757+
x=0.25,
758+
y=1.0,
759+
xanchor="center",
760+
yanchor="top",
708761
)

0 commit comments

Comments
 (0)