2020)
2121from sklearn .datasets import fetch_openml , load_digits , load_iris
2222from sklearn .decomposition import PCA
23+ from sklearn .preprocessing import StandardScaler
2324from umap import UMAP
2425
2526from tdamapper .core import aggregate_graph
@@ -349,6 +350,7 @@ def mapper_lens_input_section(X):
349350
350351def mapper_cover_input_section ():
351352 st .header ("🌐 Cover" )
353+ scale_cover = st .checkbox ("Apply Scaling" , value = False , key = "scale_cover" )
352354 cover_type = st .selectbox (
353355 "Type" ,
354356 options = [
@@ -395,7 +397,7 @@ def mapper_cover_input_section():
395397 elif cover_type == V_COVER_KNN :
396398 knn_k = st .number_input ("Neighbors" , value = 10 , min_value = 1 )
397399 cover = KNNCover (neighbors = knn_k )
398- return cover
400+ return cover , scale_cover
399401
400402
401403def mapper_clustering_cover ():
@@ -564,6 +566,7 @@ def mapper_clustering_affinityprop():
564566
565567def mapper_clustering_input_section ():
566568 st .header ("🧮 Clustering" )
569+ scale_clustering = st .checkbox ("Apply Scaling" , value = False , key = "scale_clustering" )
567570 clustering_type = st .selectbox (
568571 "Type" ,
569572 options = [
@@ -592,32 +595,36 @@ def mapper_clustering_input_section():
592595 clustering = mapper_clustering_hdbscan ()
593596 elif clustering_type == V_CLUSTERING_AFFINITY_PROPAGATION :
594597 clustering = mapper_clustering_affinityprop ()
595- return clustering
598+ return clustering , scale_clustering
596599
597600
598601@st .cache_data (
599602 hash_funcs = {"tdamapper.learn.MapperAlgorithm" : MapperAlgorithm .__repr__ },
600603 show_spinner = "Computing Mapper" ,
601604)
602- def compute_mapper (mapper , X , y ):
605+ def compute_mapper (mapper , X , y , scale_clustering , scale_cover ):
603606 logger .info ("Generating Mapper graph" )
607+ if scale_clustering :
608+ X = StandardScaler ().fit_transform (X )
609+ if scale_cover :
610+ y = StandardScaler ().fit_transform (y )
604611 mapper_graph = mapper .fit_transform (X , y )
605612 return mapper_graph
606613
607614
608615def mapper_input_section (X ):
609616 lens = mapper_lens_input_section (X )
610617 st .divider ()
611- cover = mapper_cover_input_section ()
618+ cover , scale_cover = mapper_cover_input_section ()
612619 st .divider ()
613- clustering = mapper_clustering_input_section ()
620+ clustering , scale_clustering = mapper_clustering_input_section ()
614621 mapper_algo = MapperAlgorithm (
615622 cover = cover ,
616623 clustering = clustering ,
617624 verbose = False ,
618625 n_jobs = - 2 ,
619626 )
620- mapper_graph = compute_mapper (mapper_algo , X , lens )
627+ mapper_graph = compute_mapper (mapper_algo , X , lens , scale_clustering , scale_cover )
621628 return mapper_graph
622629
623630
@@ -707,20 +714,14 @@ def _hash_mapper_plot(mapper_plot):
707714 hash_funcs = {"tdamapper.plot.MapperPlot" : _hash_mapper_plot },
708715 show_spinner = "Rendering Mapper" ,
709716)
710- def compute_mapper_fig (mapper_plot , colors , node_size , cmap , _agg , agg_name ):
717+ def compute_mapper_fig (mapper_plot , colors , _agg , agg_name ):
711718 logger .info ("Generating Mapper figure" )
712719 mapper_fig = mapper_plot .plot_plotly (
713720 colors ,
714- node_size = [
715- 0.0 ,
716- node_size / 2.0 ,
717- node_size ,
718- node_size * 1.5 ,
719- node_size * 2.0 ,
720- ],
721+ node_size = [0.25 * i for i in range (9 )],
721722 agg = _agg ,
722723 title = [f"{ c } " for c in colors .columns ],
723- cmap = cmap ,
724+ cmap = [ "Jet" , "Viridis" , "Cividis" ] ,
724725 width = 600 ,
725726 height = 600 ,
726727 )
@@ -735,9 +736,7 @@ def mapper_figure_section(df_X, df_y, mapper_plot):
735736 mapper_fig = compute_mapper_fig (
736737 mapper_plot ,
737738 colors = colors ,
738- node_size = 1.0 ,
739739 _agg = agg ,
740- cmap = ["Jet" , "Viridis" , "Cividis" ],
741740 agg_name = agg_name ,
742741 )
743742 mapper_fig .update_layout (
0 commit comments