Skip to content

Commit 197a732

Browse files
committed
Added plot settings. Improved UI styling
1 parent 22eb568 commit 197a732

File tree

1 file changed

+164
-53
lines changed

1 file changed

+164
-53
lines changed

src/tdamapper/_app.py

Lines changed: 164 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@
6262
CLUSTERING_DBSCAN_MIN_SAMPLES = 5
6363
CLUSTERING_AGGLOMERATIVE_N_CLUSTERS = 2
6464

65+
PLOT_ITERATIONS = 100
66+
PLOT_COLORMAP = "Viridis"
67+
PLOT_NODE_SIZE = 1.0
68+
6569
RANDOM_SEED = 42
6670

6771

@@ -82,18 +86,8 @@ class MapperConfig:
8286
clustering_dbscan_eps: float = CLUSTERING_DBSCAN_EPS
8387
clustering_dbscan_min_samples: int = CLUSTERING_DBSCAN_MIN_SAMPLES
8488
clustering_agglomerative_n_clusters: int = CLUSTERING_AGGLOMERATIVE_N_CLUSTERS
85-
86-
87-
def empty_figure():
88-
fig = go.Figure()
89-
fig.update_layout(
90-
xaxis=dict(visible=False),
91-
yaxis=dict(visible=False),
92-
plot_bgcolor="rgba(0,0,0,0)",
93-
paper_bgcolor="rgba(0,0,0,0)",
94-
margin=dict(l=0, r=0, t=0, b=0),
95-
)
96-
return fig
89+
plot_iterations: int = PLOT_ITERATIONS
90+
plot_seed: int = RANDOM_SEED
9791

9892

9993
def fix_data(data):
@@ -126,7 +120,7 @@ def _umap(X):
126120
return _umap
127121

128122

129-
def run_mapper(df, labels, **kwargs):
123+
def run_mapper(df, **kwargs):
130124
if df is None or df.empty:
131125
logger.error("No data found. Please upload a file first.")
132126
return
@@ -194,25 +188,32 @@ def run_mapper(df, labels, **kwargs):
194188
return
195189

196190
mapper = MapperAlgorithm(cover=cover, clustering=clustering)
197-
df_fixed = fix_data(df)
198-
X = df_fixed.to_numpy()
191+
X = df.to_numpy()
199192
y = lens(X)
200193
df_y = pd.DataFrame(y, columns=[f"{lens_type} {i}" for i in range(y.shape[1])])
201-
df_labels = pd.DataFrame(labels) if labels is not None else pd.DataFrame()
202194
if cover_scale_data:
203195
y = StandardScaler().fit_transform(y)
204196
if clustering_scale_data:
205197
X = StandardScaler().fit_transform(X)
206-
df_colors = pd.concat([df_labels, df_y, df_fixed], axis=1)
207198
mapper_graph = mapper.fit_transform(X, y)
199+
return mapper_graph, df_y
200+
201+
202+
def create_mapper_figure(df_X, df_y, df_target, mapper_graph, **kwargs):
203+
df_colors = pd.concat([df_target, df_y, df_X], axis=1)
204+
mapper_config = MapperConfig(**kwargs)
205+
plot_iterations = mapper_config.plot_iterations
206+
plot_seed = mapper_config.plot_seed
208207
mapper_fig = MapperPlot(
209208
mapper_graph,
210209
dim=3,
210+
iterations=plot_iterations,
211+
seed=plot_seed,
211212
).plot_plotly(
212213
colors=df_colors.to_numpy(),
213214
title=df_colors.columns.to_list(),
214215
cmap=[
215-
"Viridis",
216+
PLOT_COLORMAP,
216217
"Cividis",
217218
"Jet",
218219
"Plasma",
@@ -225,8 +226,11 @@ def run_mapper(df, labels, **kwargs):
225226
"PuOr",
226227
],
227228
height=800,
228-
node_size=[i * 0.125 for i in range(17)],
229+
node_size=[i * 0.125 * PLOT_NODE_SIZE for i in range(17)],
229230
)
231+
mapper_fig.layout.width = None
232+
mapper_fig.layout.height = None
233+
mapper_fig.layout.autosize = True
230234
logger.info("Mapper run completed successfully.")
231235
return mapper_fig
232236

@@ -235,20 +239,29 @@ class App:
235239

236240
def __init__(self, storage):
237241
self.storage = storage
242+
243+
ui.colors(
244+
themelight="#ebedf8",
245+
themedark="#132f48",
246+
)
247+
238248
with ui.left_drawer(elevated=True).classes(
239249
"w-96 h-full overflow-y-auto gap-12"
240250
):
241251
with ui.link(target=GIT_REPO_URL, new_tab=True).classes("w-full"):
242252
ui.image(LOGO_URL)
243253

254+
with ui.column().classes("w-full gap-2"):
255+
self._init_about()
256+
244257
with ui.column().classes("w-full gap-2"):
245258
self._init_file_upload()
246259

247260
ui.button(
248-
"Load Data",
261+
"⬆️ Load Data",
249262
on_click=self.load_file,
250-
color="primary",
251-
).classes("w-full")
263+
color="themelight",
264+
).classes("w-full text-themedark")
252265

253266
with ui.column().classes("w-full gap-2"):
254267
self._init_lens()
@@ -260,22 +273,53 @@ def __init__(self, storage):
260273
self._init_clustering()
261274

262275
ui.button(
263-
" Run Mapper",
276+
"🚀 Run Mapper",
264277
on_click=self.async_run_mapper,
265-
color="primary",
266-
).classes("w-full")
278+
color="themelight",
279+
).classes("w-full text-themedark")
267280

268-
ui.label(
269-
text=(
270-
"If you like this project, please consider giving it a ⭐ on GitHub!"
271-
"Made with ❤️ and ☕️ in Rome."
272-
)
273-
).classes("text-caption text-gray-500").classes(
274-
"text-caption text-gray-500"
275-
)
281+
with ui.column().classes("w-full gap-2"):
282+
self._init_draw()
283+
284+
ui.button(
285+
"🌊 Redraw",
286+
on_click=self.async_draw_mapper,
287+
color="themelight",
288+
).classes("w-full text-themedark")
289+
290+
with ui.column().classes("w-full gap-2"):
291+
self._init_footnotes()
276292

277293
with ui.column().classes("w-full h-screen overflow-hidden"):
278-
self._init_plot()
294+
self._init_draw_area()
295+
296+
def _init_about(self):
297+
with ui.dialog() as dialog, ui.card():
298+
ui.markdown(
299+
"""
300+
### About
301+
302+
**tda-mapper** is a Python library built around the Mapper algorithm, a core
303+
technique in Topological Data Analysis (TDA) for extracting topological
304+
structure from complex data. Designed for computational efficiency and
305+
scalability, it leverages optimized spatial search methods to support
306+
high-dimensional datasets. You can find further details in the
307+
[documentation](https://tda-mapper.readthedocs.io/en/main/)
308+
and in the
309+
[paper](https://openreview.net/pdf?id=lTX4bYREAZ).
310+
"""
311+
)
312+
ui.link(
313+
text="If you like this project, please consider giving it a ⭐ on GitHub!",
314+
target=GIT_REPO_URL,
315+
new_tab=True,
316+
).classes("w-full")
317+
ui.button("Close", on_click=dialog.close, color="themelight").classes(
318+
"w-full text-themedark"
319+
)
320+
ui.button("ℹ️ About", on_click=dialog.open, color="themelight").classes(
321+
"w-full text-themedark"
322+
)
279323

280324
def _init_file_upload(self):
281325
ui.label("📊 Data").classes("text-h6")
@@ -464,11 +508,31 @@ def _init_clustering_settings(self):
464508
value=CLUSTERING_AGGLOMERATIVE,
465509
)
466510

467-
def _init_plot(self):
511+
def _init_draw(self):
512+
ui.label("🎨 Draw").classes("text-h6")
513+
self._init_draw_settings()
514+
515+
def _init_draw_settings(self):
516+
self.plot_iterations = ui.number(
517+
label="Iterations",
518+
value=PLOT_ITERATIONS,
519+
min=1,
520+
max=10 * PLOT_ITERATIONS,
521+
).classes("w-full")
522+
self.plot_seed = ui.number(
523+
label="Seed",
524+
value=RANDOM_SEED,
525+
).classes("w-full")
526+
527+
def _init_footnotes(self):
528+
ui.label(text=("Made in Rome with ❤️ and ☕️")).classes(
529+
"text-caption text-gray-500"
530+
).classes("text-caption text-gray-500")
531+
532+
def _init_draw_area(self):
468533
self.plot_container = ui.element("div").classes("w-full h-full")
469534
with self.plot_container:
470-
fig = empty_figure()
471-
self.draw_area = ui.plotly(fig).classes("w-full h-full")
535+
self.draw_area = None
472536

473537
def get_mapper_config(self):
474538
return MapperConfig(
@@ -531,13 +595,21 @@ def get_mapper_config(self):
531595
if self.clustering_agglomerative_n_clusters.value
532596
else CLUSTERING_AGGLOMERATIVE_N_CLUSTERS
533597
),
598+
plot_iterations=(
599+
int(self.plot_iterations.value)
600+
if self.plot_iterations.value
601+
else PLOT_ITERATIONS
602+
),
603+
plot_seed=(
604+
int(self.plot_seed.value) if self.plot_seed.value else RANDOM_SEED
605+
),
534606
)
535607

536608
def upload_file(self, file):
537609
if file is not None:
538610
df = pd.read_csv(file.content)
539-
self.storage["df"] = df
540-
self.storage["labels"] = None
611+
self.storage["df"] = fix_data(df)
612+
self.storage["labels"] = pd.DataFrame()
541613
logger.info("File uploaded successfully.")
542614
ui.notify("File uploaded successfully.", type="info")
543615
else:
@@ -552,8 +624,8 @@ def load_file(self):
552624
else:
553625
logger.error("Unknown example dataset selected.")
554626
return
555-
self.storage["df"] = df
556-
self.storage["labels"] = labels
627+
self.storage["df"] = fix_data(df)
628+
self.storage["labels"] = fix_data(labels)
557629
elif self.load_type.value == LOAD_CSV:
558630
df = self.storage.get("df")
559631
if df is None:
@@ -573,29 +645,67 @@ def load_file(self):
573645
ui.notify("No data found. Please upload a file first.", type="warning")
574646

575647
async def async_run_mapper(self):
576-
df = self.storage.get("df")
577-
if df is None or df.empty:
648+
notification = ui.notification(timeout=None, type="ongoing")
649+
notification.message = "Running Mapper..."
650+
notification.spinner = True
651+
df_X = self.storage.get("df")
652+
if df_X is None or df_X.empty:
578653
logger.warning("No data found. Please upload a file first.")
579654
ui.notify("No data found. Please upload a file first.", type="warning")
580655
return
581-
labels = self.storage.get("labels")
656+
mapper_config = self.get_mapper_config()
657+
mapper_graph, df_y = await run.cpu_bound(
658+
run_mapper, df_X, **asdict(mapper_config)
659+
)
660+
self.storage["mapper_graph"] = mapper_graph
661+
self.storage["df_y"] = df_y
662+
notification.message = "Done!"
663+
notification.spinner = False
664+
notification.dismiss()
665+
666+
await self.async_draw_mapper()
667+
668+
async def async_draw_mapper(self):
582669
notification = ui.notification(timeout=None, type="ongoing")
583-
notification.message = "Running Mapper..."
670+
notification.message = "Drawing Mapper..."
584671
notification.spinner = True
672+
585673
mapper_config = self.get_mapper_config()
674+
675+
df_X = self.storage.get("df", pd.DataFrame())
676+
df_y = self.storage.get("df_y", pd.DataFrame())
677+
df_target = self.storage.get("labels", pd.DataFrame())
678+
mapper_graph = self.storage.get("mapper_graph", None)
679+
680+
if df_X.empty or mapper_graph is None:
681+
logger.warning("No data or Mapper graph found. Please run Mapper first.")
682+
ui.notify(
683+
"No data or Mapper graph found. Please run Mapper first.",
684+
type="warning",
685+
)
686+
notification.message = "No data or Mapper graph found."
687+
notification.spinner = False
688+
notification.dismiss()
689+
return
690+
586691
mapper_fig = await run.cpu_bound(
587-
run_mapper, df, labels, **asdict(mapper_config)
692+
create_mapper_figure,
693+
df_X,
694+
df_y,
695+
df_target,
696+
mapper_graph,
697+
**asdict(mapper_config),
588698
)
589-
mapper_fig.layout.width = None
590-
mapper_fig.layout.height = None
591-
mapper_fig.layout.autosize = True
592-
notification.message = "Done!"
593-
notification.spinner = False
594-
self.draw_area.clear()
699+
700+
if self.draw_area is not None:
701+
self.draw_area.clear()
595702
self.plot_container.clear()
596703
with self.plot_container:
597704
logger.info("Displaying Mapper plot.")
598705
self.draw_area = ui.plotly(mapper_fig).classes("w-full h-full")
706+
707+
notification.message = "Done!"
708+
notification.spinner = False
599709
notification.dismiss()
600710

601711

@@ -609,10 +719,11 @@ def main_page():
609719
def main():
610720
port = os.getenv("PORT", "8080")
611721
host = os.getenv("HOST", "0.0.0.0")
722+
production = os.getenv("PRODUCTION", "false").lower() == "true"
612723
storage_secret = os.getenv("STORAGE_SECRET", "storage_secret")
613724
ui.run(
614725
storage_secret=storage_secret,
615-
reload=False,
726+
reload=not production,
616727
host=host,
617728
title="tda-mapper-app",
618729
favicon=ICON_URL,

0 commit comments

Comments
 (0)