Skip to content

Commit 5443b2f

Browse files
committed
Fixed node size
1 parent 76eaabf commit 5443b2f

File tree

1 file changed

+25
-24
lines changed

1 file changed

+25
-24
lines changed

src/tdamapper/plot_backends/plot_plotly.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737

3838
_MAX_SIZEREF = 1000000000
3939

40-
_MARKER_SIZE_FACTOR = 400.0
40+
_MARKER_SIZE_FACTOR_2D = 400.0
41+
42+
_MARKER_SIZE_FACTOR_3D = 600.0
4143

4244
MENU_DARK_MODE_NAME = "menu_dark_mode"
4345

@@ -271,9 +273,8 @@ def _edge_pos_array(self):
271273
def _marker_size(self) -> List[float]:
272274
attr_size = nx.get_node_attributes(self.graph, ATTR_SIZE)
273275
max_size = max(attr_size.values(), default=1.0)
274-
marker_size = [
275-
_MARKER_SIZE_FACTOR * attr_size[n] / max_size for n in self.graph.nodes()
276-
]
276+
factor = _MARKER_SIZE_FACTOR_3D if self.dim == 3 else _MARKER_SIZE_FACTOR_2D
277+
marker_size = [factor * attr_size[n] / max_size for n in self.graph.nodes()]
277278
return marker_size
278279

279280
def set_cmap(self, cmap: str) -> None:
@@ -646,21 +647,21 @@ def ui_menu_dark_mode(self) -> Dict:
646647
)
647648

648649
def _ui_menu_cmap(self, cmaps: List[str]) -> dict:
649-
target_traces = [1] if self.dim == 2 else [0, 1]
650+
target_traces = [0, 1] if self.dim == 3 else [1]
650651

651652
def _update_cmap(cmap: str) -> dict:
652653
cmap_rgb = _get_cmap_rgb(cmap)
653-
if self.dim == 2:
654-
return {
655-
"marker.colorscale": [cmap_rgb],
656-
"marker.line.colorscale": [cmap_rgb],
657-
}
658-
elif self.dim == 3:
654+
if self.dim == 3:
659655
return {
660656
"marker.colorscale": [None, cmap_rgb],
661657
"marker.line.colorscale": [None, cmap_rgb],
662658
"line.colorscale": [cmap_rgb, None],
663659
}
660+
elif self.dim == 2:
661+
return {
662+
"marker.colorscale": [cmap_rgb],
663+
"marker.line.colorscale": [cmap_rgb],
664+
}
664665
return {}
665666

666667
buttons = []
@@ -700,18 +701,7 @@ def _update_colors(i: int) -> dict:
700701
node_col_arr = list(node_col_agg.values())
701702
scatter_text = self._text(node_col_agg)
702703
cbar = self._colorbar(titles[i])
703-
if self.dim == 2:
704-
return {
705-
"text": [scatter_text],
706-
**{
707-
f"marker.colorbar.{'.'.join(k.split('_'))}": [v]
708-
for k, v in cbar.items()
709-
},
710-
"marker.color": [node_col_arr],
711-
"marker.cmax": [max(node_col_arr, default=None)],
712-
"marker.cmin": [min(node_col_arr, default=None)],
713-
}
714-
elif self.dim == 3:
704+
if self.dim == 3:
715705
edge_col = self._edge_colors_from_node_colors(
716706
node_col_agg,
717707
)
@@ -728,9 +718,20 @@ def _update_colors(i: int) -> dict:
728718
"line.cmax": [max(node_col_arr, default=None), None],
729719
"line.cmin": [min(node_col_arr, default=None), None],
730720
}
721+
elif self.dim == 2:
722+
return {
723+
"text": [scatter_text],
724+
**{
725+
f"marker.colorbar.{'.'.join(k.split('_'))}": [v]
726+
for k, v in cbar.items()
727+
},
728+
"marker.color": [node_col_arr],
729+
"marker.cmax": [max(node_col_arr, default=None)],
730+
"marker.cmin": [min(node_col_arr, default=None)],
731+
}
731732
return {}
732733

733-
target_traces = [1] if self.dim == 2 else [0, 1]
734+
target_traces = [0, 1] if self.dim == 3 else [1]
734735

735736
buttons = []
736737
if colors.shape[1] > 1:

0 commit comments

Comments
 (0)