Skip to content

Commit ced1796

Browse files
authored
Merge pull request #198 from alejoe91/performance-panel
Improve web mode "snappiness"
2 parents 2997ecb + 601de0d commit ced1796

12 files changed

+546
-314
lines changed

spikeinterface_gui/basescatterview.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ class BaseScatterView(ViewBase):
99
_depend_on = None
1010
_settings = [
1111
{'name': "auto_decimate", 'type': 'bool', 'value' : True },
12-
{'name': 'max_spikes_per_unit', 'type': 'int', 'value' : 10_000 },
12+
{'name': 'max_spikes_per_unit', 'type': 'int', 'value' : 5_000 },
1313
{'name': 'alpha', 'type': 'float', 'value' : 0.7, 'limits':(0, 1.), 'step':0.05 },
1414
{'name': 'scatter_size', 'type': 'float', 'value' : 2., 'step':0.5 },
15-
{'name': 'num_bins', 'type': 'int', 'value' : 400, 'step': 1 },
15+
{'name': 'num_bins', 'type': 'int', 'value' : 100, 'step': 1 },
1616
]
1717
_need_compute = False
1818

@@ -407,6 +407,8 @@ def _panel_make_layout(self):
407407
# Add SelectionGeometry event handler to capture lasso vertices
408408
self.scatter_fig.on_event('selectiongeometry', self._on_panel_selection_geometry)
409409

410+
self.hist_source = ColumnDataSource(data={"x": [], "y": []})
411+
self.hist_data_source = ColumnDataSource(data=dict(x=[], y=[], color=[]))
410412
self.hist_fig = bpl.figure(
411413
tools="reset,wheel_zoom",
412414
sizing_mode="stretch_both",
@@ -416,6 +418,8 @@ def _panel_make_layout(self):
416418
y_range=self.y_range,
417419
styles={"flex": "1"} # Make histogram narrower than scatter plot
418420
)
421+
self.lines_hist = self.hist_fig.multi_line('x', 'y', source=self.hist_data_source,
422+
line_color='color', line_width=2)
419423
self.hist_fig.toolbar.logo = None
420424
self.hist_fig.yaxis.axis_label = self.y_label
421425
self.hist_fig.xaxis.axis_label = "Count"
@@ -447,24 +451,23 @@ def _panel_make_layout(self):
447451
),
448452
)
449453
)
450-
self.hist_lines = []
454+
# self.hist_lines = []
451455
self.noise_harea = []
452456
self.plotted_inds = []
453457

454458
def _panel_refresh(self):
455459
from bokeh.models import ColumnDataSource, Range1d
456460

457-
# clear figures
458-
for renderer in self.hist_lines:
459-
self.hist_fig.renderers.remove(renderer)
460-
self.hist_lines = []
461461
self.plotted_inds = []
462462

463463
max_count = 1
464464
xs = []
465465
ys = []
466466
colors = []
467467

468+
xh = []
469+
yh = []
470+
colors_h = []
468471
segment_index = self.controller.get_time()[1]
469472
# get view segment index from segment selector
470473
segment_index_from_selector = self.segment_selector.options.index(self.segment_selector.value)
@@ -484,33 +487,33 @@ def _panel_refresh(self):
484487
max_count = max(max_count, np.max(hist_count))
485488
self.plotted_inds.extend(inds)
486489

487-
hist_lines = self.hist_fig.line(
488-
"x",
489-
"y",
490-
source=ColumnDataSource(
491-
{"x":hist_count,
492-
"y":hist_bins[:-1],
493-
}
494-
),
495-
line_color=color,
496-
line_width=2,
497-
)
498-
self.hist_lines.append(hist_lines)
490+
# Prepare data for multi_line
491+
xh.append(hist_count)
492+
yh.append(hist_bins[:-1])
493+
colors_h.append(color)
494+
499495
t_start, t_end = self.controller.get_t_start_t_stop()
500496
self.scatter_fig.x_range.start = t_start
501497
self.scatter_fig.x_range.end = t_end
502498

503499
self._max_count = max_count
504500

505501
# Add scatter plot with correct alpha parameter
506-
self.scatter_source.data = {
507-
"x": xs,
508-
"y": ys,
509-
"color": colors
510-
}
502+
self.scatter_source.data = dict(
503+
x=xs,
504+
y=ys,
505+
color=colors
506+
)
511507
self.scatter.glyph.size = self.settings['scatter_size']
512508
self.scatter.glyph.fill_alpha = self.settings['alpha']
513509

510+
# Update histogram multi_line data
511+
self.hist_data_source.data = dict(
512+
x=xh,
513+
y=yh,
514+
color=colors_h
515+
)
516+
514517
# handle selected spikes
515518
self._panel_update_selected_spikes()
516519

@@ -529,7 +532,6 @@ def _panel_on_select_button(self, event):
529532
self.scatter_fig.toolbar.active_drag = None
530533
self.scatter_source.selected.indices = []
531534

532-
533535
def _panel_change_segment(self, event):
534536
self._current_selected = 0
535537
segment_index = int(self.segment_selector.value.split()[-1])

spikeinterface_gui/crosscorrelogramview.py

Lines changed: 89 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def __init__(self, controller=None, parent=None, backend="qt"):
1818
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
1919

2020
self.ccg, self.bins = self.controller.get_correlograms()
21+
self.figure_cache = {}
22+
self.max_cache_size = 20
2123

2224

2325
def _on_settings_changed(self):
@@ -64,24 +66,33 @@ def _qt_refresh(self):
6466

6567
for r in range(n):
6668
for c in range(r, n):
67-
68-
i = unit_ids.index(visible_unit_ids[r])
69-
j = unit_ids.index(visible_unit_ids[c])
70-
count = ccg[i, j, :]
71-
72-
plot = pg.PlotItem()
73-
if not self.settings['display_axis']:
74-
plot.hideAxis('bottom')
75-
plot.hideAxis('left')
76-
77-
if r==c:
78-
unit_id = visible_unit_ids[r]
79-
color = colors[unit_id]
69+
unit_id1 = visible_unit_ids[r]
70+
unit_id2 = visible_unit_ids[c]
71+
if (unit_id1, unit_id2) in self.figure_cache:
72+
plot = self.figure_cache[(unit_id1, unit_id2)]
8073
else:
81-
color = (120,120,120,120)
82-
83-
curve = pg.PlotCurveItem(bins, count, stepMode='center', fillLevel=0, brush=color, pen=color)
84-
plot.addItem(curve)
74+
# create new plot
75+
i = unit_ids.index(visible_unit_ids[r])
76+
j = unit_ids.index(visible_unit_ids[c])
77+
count = ccg[i, j, :]
78+
79+
plot = pg.PlotItem()
80+
if not self.settings['display_axis']:
81+
plot.hideAxis('bottom')
82+
plot.hideAxis('left')
83+
84+
if r == c:
85+
unit_id = visible_unit_ids[r]
86+
color = colors[unit_id]
87+
else:
88+
color = (120, 120, 120, 120)
89+
90+
curve = pg.PlotCurveItem(bins, count, stepMode='center', fillLevel=0, brush=color, pen=color)
91+
plot.addItem(curve)
92+
# cache plot
93+
if len(self.figure_cache) >= self.max_cache_size:
94+
self.figure_cache.pop(next(iter(self.figure_cache)))
95+
self.figure_cache[(unit_id1, unit_id2)] = plot
8596
self.grid.addItem(plot, row=r, col=c)
8697

8798
## panel ##
@@ -102,18 +113,12 @@ def _panel_make_layout(self):
102113
self.empty_plot_pane,
103114
sizing_mode="stretch_both",
104115
)
105-
self.is_warning_active = False
106-
107-
self.plots = []
108116

109117
def _panel_refresh(self):
110118
import panel as pn
111119
import bokeh.plotting as bpl
112120
from bokeh.layouts import gridplot
113-
from .utils_panel import _bg_color, insert_warning, clear_warning
114-
115-
# clear previous plot
116-
self.plots = []
121+
from .utils_panel import _bg_color
117122

118123
if self.ccg is None:
119124
return
@@ -127,67 +132,75 @@ def _panel_refresh(self):
127132
}
128133
ccg = self.ccg
129134
bins = self.bins
130-
135+
figures = []
131136
first_fig = None
132137
for r in range(n):
133138
row_plots = []
134139
for c in range(r, n):
135-
i = unit_ids.index(visible_unit_ids[r])
136-
j = unit_ids.index(visible_unit_ids[c])
137-
count = ccg[i, j, :]
140+
unit1 = visible_unit_ids[r]
141+
unit2 = visible_unit_ids[c]
138142

139-
# Create Bokeh figure
140-
if first_fig is not None:
141-
extra_kwargs = dict(x_range=first_fig.x_range)
143+
if (unit1, unit2) in self.figure_cache:
144+
fig = self.figure_cache[(unit1, unit2)]
142145
else:
143-
extra_kwargs = dict()
144-
fig = bpl.figure(
145-
width=250,
146-
height=250,
147-
tools="pan,wheel_zoom,reset",
148-
active_drag="pan",
149-
active_scroll="wheel_zoom",
150-
background_fill_color=_bg_color,
151-
border_fill_color=_bg_color,
152-
outline_line_color="white",
153-
**extra_kwargs,
154-
)
155-
fig.toolbar.logo = None
156-
157-
# Get color from controller
158-
if r == c:
159-
unit_id = visible_unit_ids[r]
160-
color = colors[unit_id]
161-
fill_alpha = 0.7
162-
else:
163-
color = "lightgray"
164-
fill_alpha = 0.4
165-
166-
fig.quad(
167-
top=count,
168-
bottom=0,
169-
left=bins[:-1],
170-
right=bins[1:],
171-
fill_color=color,
172-
line_color=color,
173-
alpha=fill_alpha,
174-
)
175-
if first_fig is None:
176-
first_fig = fig
177-
146+
# create new figure
147+
i = unit_ids.index(unit1)
148+
j = unit_ids.index(unit2)
149+
count = ccg[i, j, :]
150+
151+
# Create Bokeh figure
152+
if first_fig is not None:
153+
extra_kwargs = dict(x_range=first_fig.x_range)
154+
else:
155+
extra_kwargs = dict()
156+
fig = bpl.figure(
157+
width=250,
158+
height=250,
159+
tools="pan,wheel_zoom,reset",
160+
active_drag="pan",
161+
active_scroll="wheel_zoom",
162+
background_fill_color=_bg_color,
163+
border_fill_color=_bg_color,
164+
outline_line_color="white",
165+
**extra_kwargs,
166+
)
167+
fig.toolbar.logo = None
168+
169+
# Get color from controller
170+
if r == c:
171+
unit_id = visible_unit_ids[r]
172+
color = colors[unit_id]
173+
fill_alpha = 0.7
174+
else:
175+
color = "lightgray"
176+
fill_alpha = 0.4
177+
178+
fig.quad(
179+
top=count,
180+
bottom=0,
181+
left=bins[:-1],
182+
right=bins[1:],
183+
fill_color=color,
184+
line_color=color,
185+
alpha=fill_alpha,
186+
)
187+
if first_fig is None:
188+
first_fig = fig
189+
# Cache figure
190+
if len(self.figure_cache) >= self.max_cache_size:
191+
self.figure_cache.pop(next(iter(self.figure_cache)))
192+
self.figure_cache[(unit1, unit2)] = fig
178193
row_plots.append(fig)
179194
# Fill row with None for proper spacing
180195
full_row = [None] * r + row_plots + [None] * (n - len(row_plots))
181-
self.plots.append(full_row)
182-
183-
if len(self.plots) > 0:
184-
grid = gridplot(self.plots, toolbar_location="right", sizing_mode="stretch_both")
185-
self.layout[0] = pn.Column(
186-
grid,
187-
styles={'background-color': f'{_bg_color}'}
188-
)
189-
else:
190-
self.layout[0] = self.empty_plot_pane
196+
figures.append(full_row)
197+
198+
grid = gridplot(figures, toolbar_location="right", sizing_mode="stretch_both")
199+
grid.toolbar.logo = None
200+
self.layout[0] = pn.Column(
201+
grid,
202+
styles={'background-color': f'{_bg_color}'}
203+
)
191204

192205

193206

spikeinterface_gui/curationview.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ def _qt_make_layout(self):
8282

8383
v = QT.QVBoxLayout()
8484
h.addLayout(v)
85-
v.addWidget(QT.QLabel("<b>Deleted</b>"))
8685
self.table_delete = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection,
8786
selectionBehavior=QT.QAbstractItemView.SelectRows)
8887
v.addWidget(self.table_delete)
@@ -99,7 +98,6 @@ def _qt_make_layout(self):
9998

10099
v = QT.QVBoxLayout()
101100
h.addLayout(v)
102-
v.addWidget(QT.QLabel("<b>Merges</b>"))
103101
self.table_merge = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection,
104102
selectionBehavior=QT.QAbstractItemView.SelectRows)
105103
# self.table_merge.setContextMenuPolicy(QT.Qt.CustomContextMenu)
@@ -118,7 +116,6 @@ def _qt_make_layout(self):
118116

119117
v = QT.QVBoxLayout()
120118
h.addLayout(v)
121-
v.addWidget(QT.QLabel("<b>Splits</b>"))
122119
self.table_split = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection,
123120
selectionBehavior=QT.QAbstractItemView.SelectRows)
124121
v.addWidget(self.table_split)
@@ -139,7 +136,7 @@ def _qt_refresh(self):
139136
self.table_merge.clear()
140137
self.table_merge.setRowCount(len(merged_units))
141138
self.table_merge.setColumnCount(1)
142-
self.table_merge.setHorizontalHeaderLabels(["Merges"])
139+
self.table_merge.setHorizontalHeaderLabels(["merges"])
143140
self.table_merge.setSortingEnabled(False)
144141
for ix, group in enumerate(merged_units):
145142
item = QT.QTableWidgetItem(str(group))
@@ -153,7 +150,7 @@ def _qt_refresh(self):
153150
self.table_delete.clear()
154151
self.table_delete.setRowCount(len(removed_units))
155152
self.table_delete.setColumnCount(1)
156-
self.table_delete.setHorizontalHeaderLabels(["unit_id"])
153+
self.table_delete.setHorizontalHeaderLabels(["removed"])
157154
self.table_delete.setSortingEnabled(False)
158155
for i, unit_id in enumerate(removed_units):
159156
color = self.get_unit_color(unit_id)
@@ -172,7 +169,7 @@ def _qt_refresh(self):
172169
self.table_split.clear()
173170
self.table_split.setRowCount(len(splits))
174171
self.table_split.setColumnCount(1)
175-
self.table_split.setHorizontalHeaderLabels(["Split units"])
172+
self.table_split.setHorizontalHeaderLabels(["splits"])
176173
self.table_split.setSortingEnabled(False)
177174
for i, split in enumerate(splits):
178175
unit_id = split["unit_id"]

0 commit comments

Comments
 (0)