Skip to content

Commit 38c1a6d

Browse files
committed
Improved node size settings
1 parent 7dd9051 commit 38c1a6d

File tree

1 file changed

+27
-10
lines changed

1 file changed

+27
-10
lines changed

src/tdamapper/plot_backends/plot_plotly.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
_NODE_OPACITY = 1.0
2121

22-
_EDGE_WIDTH = 0.75
22+
_EDGE_WIDTH_2D = 0.75
23+
24+
_EDGE_WIDTH_3D = 1.5
2325

2426
_EDGE_OPACITY = 1.0
2527

@@ -33,6 +35,10 @@
3335

3436
_DEFAULT_SPACING = 0.25
3537

38+
_MAX_SIZEREF = 1000000000
39+
40+
_MARKER_SIZE_FACTOR = 400.0
41+
3642
MENU_DARK_MODE_NAME = "menu_dark_mode"
3743

3844
MENU_CMAP_NAME = "menu_cmap"
@@ -262,12 +268,11 @@ def _edge_pos_array(self):
262268
edges_arr[i].append(None)
263269
return edges_arr
264270

265-
def _marker_size(self, node_size: float) -> List[float]:
271+
def _marker_size(self) -> List[float]:
266272
attr_size = nx.get_node_attributes(self.graph, ATTR_SIZE)
267273
max_size = max(attr_size.values(), default=1.0)
268-
scale = node_size * (25.0 if self.dim == 2 else 15.0)
269274
marker_size = [
270-
scale * math.sqrt(attr_size[n] / max_size) for n in self.graph.nodes()
275+
_MARKER_SIZE_FACTOR * attr_size[n] / max_size for n in self.graph.nodes()
271276
]
272277
return marker_size
273278

@@ -317,6 +322,7 @@ def _set_colors(self, colors, agg):
317322
patch=dict(
318323
text=scatter_text,
319324
marker=dict(
325+
opacity=_NODE_OPACITY,
320326
color=node_col_arr,
321327
cmin=min(node_col_arr, default=None),
322328
cmax=max(node_col_arr, default=None),
@@ -352,7 +358,11 @@ def set_node_size(self, node_size: float) -> None:
352358
return
353359
self.fig.update_traces(
354360
patch=dict(
355-
marker_size=self._marker_size(node_size),
361+
marker_sizeref=(
362+
_MAX_SIZEREF if node_size == 0.0 else (1.0 / node_size) ** 2
363+
),
364+
marker_sizemin=1.0,
365+
marker_sizemode="area",
356366
),
357367
selector=dict(name=_NODES_TRACE),
358368
)
@@ -432,7 +442,7 @@ def _nodes_trace(self, node_pos_arr) -> Union[go.Scatter, go.Scatter3d]:
432442
marker=dict(
433443
showscale=True,
434444
reversescale=False,
435-
size=self._marker_size(DEFAULT_NODE_SIZE),
445+
size=self._marker_size(),
436446
opacity=_NODE_OPACITY,
437447
line_width=_NODE_OUTER_WIDTH,
438448
line_color=_NODE_OUTER_COLOR,
@@ -454,7 +464,6 @@ def _edges_trace(self, edge_pos_arr) -> Union[go.Scatter, go.Scatter3d]:
454464
y=edge_pos_arr[1],
455465
mode="lines",
456466
opacity=_EDGE_OPACITY,
457-
line_width=_EDGE_WIDTH,
458467
line_color=_EDGE_COLOR,
459468
hoverinfo="skip",
460469
)
@@ -463,14 +472,18 @@ def _edges_trace(self, edge_pos_arr) -> Union[go.Scatter, go.Scatter3d]:
463472
dict(
464473
z=edge_pos_arr[2],
465474
line_colorscale=DEFAULT_CMAP,
475+
line_width=_EDGE_WIDTH_3D,
476+
marker_line_width=_EDGE_WIDTH_3D,
466477
),
467478
)
468479
return go.Scatter3d(scatter)
469480
else:
470481
scatter.update(
471482
dict(
472483
marker_colorscale=DEFAULT_CMAP,
484+
marker_line_width=_EDGE_WIDTH_2D,
473485
marker_line_colorscale=DEFAULT_CMAP,
486+
line_width=_EDGE_WIDTH_2D,
474487
),
475488
)
476489
return go.Scatter(scatter)
@@ -745,13 +758,17 @@ def _ui_slider_node_size(self, node_sizes: List[float]) -> Dict:
745758
steps = [
746759
dict(
747760
method="restyle",
748-
label=f"{size}",
761+
label=f"{node_size}",
749762
args=[
750-
{"marker.size": [self._marker_size(size)]},
763+
{
764+
"marker.sizeref": [
765+
_MAX_SIZEREF if node_size == 0.0 else (1.0 / node_size) ** 2
766+
]
767+
},
751768
[1],
752769
],
753770
)
754-
for size in node_sizes
771+
for node_size in node_sizes
755772
]
756773

757774
return dict(

0 commit comments

Comments
 (0)