Skip to content

Commit e405748

Browse files
authored
Merge pull request #181 from alejoe91/handle-times
Add main_setting to use recording times
2 parents 120907a + 1a49a15 commit e405748

17 files changed

+429
-213
lines changed

spikeinterface_gui/backend_panel.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class SignalNotifier(param.Parameterized):
1313
channel_visibility_changed = param.Event()
1414
manual_curation_updated = param.Event()
1515
time_info_updated = param.Event()
16+
use_times_updated = param.Event()
1617
active_view_updated = param.Event()
1718
unit_color_changed = param.Event()
1819

@@ -35,6 +36,9 @@ def notify_manual_curation_updated(self):
3536
def notify_time_info_updated(self):
3637
self.param.trigger("time_info_updated")
3738

39+
def notify_use_times_updated(self):
40+
self.param.trigger("use_times_updated")
41+
3842
def notify_active_view_updated(self):
3943
# this is used to keep an "active view" in the main window
4044
# when a view triggers this event, it self-declares it as active
@@ -65,6 +69,7 @@ def connect_view(self, view):
6569
view.notifier.param.watch(self.on_channel_visibility_changed, "channel_visibility_changed")
6670
view.notifier.param.watch(self.on_manual_curation_updated, "manual_curation_updated")
6771
view.notifier.param.watch(self.on_time_info_updated, "time_info_updated")
72+
view.notifier.param.watch(self.on_use_times_updated, "use_times_updated")
6873
view.notifier.param.watch(self.on_active_view_updated, "active_view_updated")
6974
view.notifier.param.watch(self.on_unit_color_changed, "unit_color_changed")
7075

@@ -110,6 +115,15 @@ def on_time_info_updated(self, param):
110115
continue
111116
view.on_time_info_updated()
112117

118+
def on_use_times_updated(self, param):
119+
# use times is updated also when a view is not active
120+
if not self._active:
121+
return
122+
for view in self.controller.views:
123+
if param.obj.view == view:
124+
continue
125+
view.on_use_times_updated()
126+
113127
def on_active_view_updated(self, param):
114128
if not self._active:
115129
return

spikeinterface_gui/backend_qt.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class SignalNotifier(QT.QObject):
1919
channel_visibility_changed = QT.pyqtSignal()
2020
manual_curation_updated = QT.pyqtSignal()
2121
time_info_updated = QT.pyqtSignal()
22+
use_times_updated = QT.pyqtSignal()
2223
unit_color_changed = QT.pyqtSignal()
2324

2425
def __init__(self, parent=None, view=None):
@@ -40,6 +41,9 @@ def notify_manual_curation_updated(self):
4041
def notify_time_info_updated(self):
4142
self.time_info_updated.emit()
4243

44+
def notify_use_times_updated(self):
45+
self.use_times_updated.emit()
46+
4347
def notify_unit_color_changed(self):
4448
self.unit_color_changed.emit()
4549

@@ -63,6 +67,7 @@ def connect_view(self, view):
6367
view.notifier.channel_visibility_changed.connect(self.on_channel_visibility_changed)
6468
view.notifier.manual_curation_updated.connect(self.on_manual_curation_updated)
6569
view.notifier.time_info_updated.connect(self.on_time_info_updated)
70+
view.notifier.use_times_updated.connect(self.on_use_times_updated)
6671
view.notifier.unit_color_changed.connect(self.on_unit_color_changed)
6772

6873
def on_spike_selection_changed(self):
@@ -110,7 +115,16 @@ def on_time_info_updated(self):
110115
# do not refresh it self
111116
continue
112117
view.on_time_info_updated()
113-
118+
119+
def on_use_times_updated(self):
120+
if not self._active:
121+
return
122+
for view in self.controller.views:
123+
if view.qt_widget == self.sender().parent():
124+
# do not refresh it self
125+
continue
126+
view.on_use_times_updated()
127+
114128
def on_unit_color_changed(self):
115129
if not self._active:
116130
return
@@ -383,7 +397,6 @@ def open_help(self):
383397
def refresh(self):
384398
view = self._view()
385399
view.refresh()
386-
387400

388401
areas = {
389402
'right' : QT.Qt.RightDockWidgetArea,

spikeinterface_gui/basescatterview.py

Lines changed: 71 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,15 @@ def __init__(self, spike_data, y_label, controller=None, parent=None, backend="q
3232
self._lasso_vertices = {segment_index: None for segment_index in range(controller.num_segments)}
3333
# this is used in panel
3434
self._current_selected = 0
35+
self._block_auto_refresh_and_notify = False
3536

3637
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
3738

3839

39-
def get_unit_data(self, unit_id, seg_index=0):
40-
inds = self.controller.get_spike_indices(unit_id, seg_index=seg_index)
41-
spike_times = self.controller.spikes["sample_index"][inds] / self.controller.sampling_frequency
40+
def get_unit_data(self, unit_id, segment_index=0):
41+
inds = self.controller.get_spike_indices(unit_id, segment_index=segment_index)
42+
spike_indices = self.controller.spikes["sample_index"][inds]
43+
spike_times = self.controller.sample_index_to_time(spike_indices)
4244
spike_data = self.spike_data[inds]
4345
ptp = np.ptp(spike_data)
4446
hist_min, hist_max = [np.min(spike_data) - 0.2 * ptp, np.max(spike_data) + 0.2 * ptp]
@@ -53,15 +55,15 @@ def get_unit_data(self, unit_id, seg_index=0):
5355

5456
return spike_times, spike_data, hist_count, hist_bins, inds
5557

56-
def get_selected_spikes_data(self, seg_index=0, visible_inds=None):
57-
sl = self.controller.segment_slices[seg_index]
58+
def get_selected_spikes_data(self, segment_index=0, visible_inds=None):
59+
sl = self.controller.segment_slices[segment_index]
5860
spikes_in_seg = self.controller.spikes[sl]
5961
selected_indices = self.controller.get_indices_spike_selected()
6062
if visible_inds is not None:
6163
selected_indices = np.intersect1d(selected_indices, visible_inds)
6264
mask = np.isin(sl.start + np.arange(len(spikes_in_seg)), selected_indices)
6365
selected_spikes = spikes_in_seg[mask]
64-
spike_times = selected_spikes['sample_index'] / self.controller.sampling_frequency
66+
spike_times = self.controller.sample_index_to_time(selected_spikes['sample_index'])
6567
spike_data = self.spike_data[sl][mask]
6668
return (spike_times, spike_data)
6769

@@ -85,8 +87,8 @@ def select_all_spikes_from_lasso(self, keep_already_selected=False):
8587
for segment_index, vertices in self._lasso_vertices.items():
8688
if vertices is None:
8789
continue
88-
spike_inds = self.controller.get_spike_indices(visible_unit_id, seg_index=segment_index)
89-
spike_times = self.controller.spikes["sample_index"][spike_inds] / fs
90+
spike_inds = self.controller.get_spike_indices(visible_unit_id, segment_index=segment_index)
91+
spike_times = self.controller.sample_index_to_time(self.controller.spikes["sample_index"][spike_inds])
9092
spike_data = self.spike_data[spike_inds]
9193

9294
points = np.column_stack((spike_times, spike_data))
@@ -119,7 +121,7 @@ def split(self):
119121

120122
if self.controller.num_segments > 1:
121123
# check that lasso vertices are defined for all segments
122-
if not all(self._lasso_vertices[seg_index] is not None for seg_index in range(self.controller.num_segments)):
124+
if not all(self._lasso_vertices[segment_index] is not None for segment_index in range(self.controller.num_segments)):
123125
# Use the new continue_from_user pattern
124126
self.continue_from_user(
125127
"Not all segments have lasso selection. "
@@ -163,6 +165,15 @@ def on_unit_visibility_changed(self):
163165
self._current_selected = self.controller.get_indices_spike_selected().size
164166
self.refresh()
165167

168+
def _qt_on_time_info_updated(self):
169+
if self.combo_seg.currentIndex() != self.controller.get_time()[1]:
170+
self._block_auto_refresh_and_notify = True
171+
self.refresh()
172+
self._block_auto_refresh_and_notify = False
173+
174+
def on_use_times_updated(self):
175+
self.refresh()
176+
166177
## QT zone ##
167178
def _qt_make_layout(self):
168179
from .myqt import QT
@@ -174,8 +185,8 @@ def _qt_make_layout(self):
174185
tb = self.qt_widget.view_toolbar
175186
self.combo_seg = QT.QComboBox()
176187
tb.addWidget(self.combo_seg)
177-
self.combo_seg.addItems([ f'Segment {seg_index}' for seg_index in range(self.controller.num_segments) ])
178-
self.combo_seg.currentIndexChanged.connect(self.refresh)
188+
self.combo_seg.addItems([ f'Segment {segment_index}' for segment_index in range(self.controller.num_segments) ])
189+
self.combo_seg.currentIndexChanged.connect(self._qt_change_segment)
179190
add_stretch_to_qtoolbar(tb)
180191
self.lasso_but = QT.QPushButton("select", checkable = True)
181192
tb.addWidget(self.lasso_but)
@@ -184,9 +195,6 @@ def _qt_make_layout(self):
184195
self.split_but = QT.QPushButton("split")
185196
tb.addWidget(self.split_but)
186197
self.split_but.clicked.connect(self.split)
187-
shortcut_split = QT.QShortcut(self.qt_widget)
188-
shortcut_split.setKey(QT.QKeySequence("ctrl+s"))
189-
shortcut_split.activated.connect(self.split)
190198
h = QT.QHBoxLayout()
191199
self.layout.addLayout(h)
192200

@@ -235,6 +243,13 @@ def _qt_initialize_plot(self):
235243
def _qt_on_spike_selection_changed(self):
236244
self.refresh()
237245

246+
def _qt_change_segment(self):
247+
segment_index = self.combo_seg.currentIndex()
248+
self.controller.set_time(segment_index=segment_index)
249+
if not self._block_auto_refresh_and_notify:
250+
self.refresh()
251+
self.notify_time_info_updated()
252+
238253
def _qt_refresh(self):
239254
from .myqt import QT
240255
import pyqtgraph as pg
@@ -246,13 +261,18 @@ def _qt_refresh(self):
246261
if self.spike_data is None:
247262
return
248263

264+
segment_index = self.controller.get_time()[1]
265+
# Update combo_seg if it doesn't match the current segment index
266+
if self.combo_seg.currentIndex() != segment_index:
267+
self.combo_seg.setCurrentIndex(segment_index)
268+
249269
max_count = 1
250270
all_inds = []
251271
for unit_id in self.controller.get_visible_unit_ids():
252272

253273
spike_times, spike_data, hist_count, hist_bins, inds = self.get_unit_data(
254274
unit_id,
255-
seg_index=self.combo_seg.currentIndex()
275+
segment_index=segment_index
256276
)
257277

258278
# make a copy of the color
@@ -276,7 +296,7 @@ def _qt_refresh(self):
276296
y_range_plot_1 = self.plot.getViewBox().viewRange()
277297
self.viewBox2.setYRange(y_range_plot_1[1][0], y_range_plot_1[1][1], padding = 0.0)
278298

279-
spike_times, spike_data = self.get_selected_spikes_data(seg_index=self.combo_seg.currentIndex(), visible_inds=all_inds)
299+
spike_times, spike_data = self.get_selected_spikes_data(segment_index=self.combo_seg.currentIndex(), visible_inds=all_inds)
280300

281301
self.scatter_select.setData(spike_times, spike_data)
282302

@@ -296,8 +316,8 @@ def _qt_on_lasso_finished(self, points, shift_held=False):
296316
self.lasso.setData([], [])
297317
vertices = np.array(points)
298318

299-
seg_index = self.combo_seg.currentIndex()
300-
sl = self.controller.segment_slices[seg_index]
319+
segment_index = self.combo_seg.currentIndex()
320+
sl = self.controller.segment_slices[segment_index]
301321
spikes_in_seg = self.controller.spikes[sl]
302322

303323
# Create mask for visible units
@@ -315,16 +335,16 @@ def _qt_on_lasso_finished(self, points, shift_held=False):
315335
self.notify_spike_selection_changed()
316336
return
317337

318-
if self._lasso_vertices[seg_index] is None:
319-
self._lasso_vertices[seg_index] = []
338+
if self._lasso_vertices[segment_index] is None:
339+
self._lasso_vertices[segment_index] = []
320340

321341
if shift_held:
322342
# If shift is held, append the vertices to the current lasso vertices
323-
self._lasso_vertices[seg_index].append(vertices)
343+
self._lasso_vertices[segment_index].append(vertices)
324344
keep_already_selected = True
325345
else:
326346
# If shift is not held, clear the existing lasso vertices for this segment
327-
self._lasso_vertices[seg_index] = [vertices]
347+
self._lasso_vertices[segment_index] = [vertices]
328348
keep_already_selected = False
329349

330350
self.select_all_spikes_from_lasso(keep_already_selected=keep_already_selected)
@@ -341,11 +361,11 @@ def _panel_make_layout(self):
341361

342362
self.lasso_tool = LassoSelectTool()
343363

344-
self.segment_index = 0
364+
segment_index = self.controller.get_time()[1]
345365
self.segment_selector = pn.widgets.Select(
346366
name="",
347367
options=[f"Segment {i}" for i in range(self.controller.num_segments)],
348-
value=f"Segment {self.segment_index}",
368+
value=f"Segment {segment_index}",
349369
)
350370
self.segment_selector.param.watch(self._panel_change_segment, 'value')
351371

@@ -381,8 +401,8 @@ def _panel_make_layout(self):
381401
self.scatter_fig.toolbar.active_drag = None
382402
self.scatter_fig.xaxis.axis_label = "Time (s)"
383403
self.scatter_fig.yaxis.axis_label = self.y_label
384-
time_max = self.controller.get_num_samples(self.segment_index) / self.controller.sampling_frequency
385-
self.scatter_fig.x_range = Range1d(0., time_max)
404+
t_start, t_stop = self.controller.get_t_start_t_stop()
405+
self.scatter_fig.x_range = Range1d(t_start, t_stop)
386406

387407
# Add SelectionGeometry event handler to capture lasso vertices
388408
self.scatter_fig.on_event('selectiongeometry', self._on_panel_selection_geometry)
@@ -445,11 +465,17 @@ def _panel_refresh(self):
445465
ys = []
446466
colors = []
447467

468+
segment_index = self.controller.get_time()[1]
469+
# get view segment index from segment selector
470+
segment_index_from_selector = self.segment_selector.options.index(self.segment_selector.value)
471+
if segment_index != segment_index_from_selector:
472+
self.segment_selector.value = f"Segment {segment_index}"
473+
448474
visible_unit_ids = self.controller.get_visible_unit_ids()
449475
for unit_id in visible_unit_ids:
450476
spike_times, spike_data, hist_count, hist_bins, inds = self.get_unit_data(
451477
unit_id,
452-
seg_index=self.segment_index
478+
segment_index=segment_index
453479
)
454480
color = self.get_unit_color(unit_id)
455481
xs.extend(spike_times)
@@ -470,6 +496,9 @@ def _panel_refresh(self):
470496
line_width=2,
471497
)
472498
self.hist_lines.append(hist_lines)
499+
t_start, t_end = self.controller.get_t_start_t_stop()
500+
self.scatter_fig.x_range.start = t_start
501+
self.scatter_fig.x_range.end = t_end
473502

474503
self._max_count = max_count
475504

@@ -503,10 +532,13 @@ def _panel_on_select_button(self, event):
503532

504533
def _panel_change_segment(self, event):
505534
self._current_selected = 0
506-
self.segment_index = int(self.segment_selector.value.split()[-1])
507-
time_max = self.controller.get_num_samples(self.segment_index) / self.controller.sampling_frequency
508-
self.scatter_fig.x_range.end = time_max
535+
segment_index = int(self.segment_selector.value.split()[-1])
536+
self.controller.set_time(segment_index=segment_index)
537+
t_start, t_end = self.controller.get_t_start_t_stop()
538+
self.scatter_fig.x_range.start = t_start
539+
self.scatter_fig.x_range.end = t_end
509540
self.refresh()
541+
self.notify_time_info_updated()
510542

511543
def _on_panel_selection_geometry(self, event):
512544
"""
@@ -524,16 +556,16 @@ def _on_panel_selection_geometry(self, event):
524556
return
525557

526558
# Append the current polygon to the lasso vertices if shift is held
527-
seg_index = self.segment_index
528-
if self._lasso_vertices[seg_index] is None:
529-
self._lasso_vertices[seg_index] = []
559+
segment_index = self.controller.get_time()[1]
560+
if self._lasso_vertices[segment_index] is None:
561+
self._lasso_vertices[segment_index] = []
530562
if len(selected) > self._current_selected:
531563
self._current_selected = len(selected)
532564
# Store the current polygon for the current segment
533-
self._lasso_vertices[seg_index].append(polygon)
565+
self._lasso_vertices[segment_index].append(polygon)
534566
keep_already_selected = True
535567
else:
536-
self._lasso_vertices[seg_index] = [polygon]
568+
self._lasso_vertices[segment_index] = [polygon]
537569
keep_already_selected = False
538570

539571
self.select_all_spikes_from_lasso(keep_already_selected)
@@ -551,7 +583,8 @@ def _panel_update_selected_spikes(self):
551583
selected_spike_indices = np.intersect1d(selected_spike_indices, self.plotted_inds)
552584
if len(selected_spike_indices) > 0:
553585
# map absolute indices to visible spikes
554-
sl = self.controller.segment_slices[self.segment_index]
586+
segment_index = self.controller.get_time()[1]
587+
sl = self.controller.segment_slices[segment_index]
555588
spikes_in_seg = self.controller.spikes[sl]
556589
visible_mask = np.zeros(len(spikes_in_seg), dtype=bool)
557590
for unit_index, unit_id in self.controller.iter_visible_units():
@@ -573,7 +606,8 @@ def _panel_on_spike_selection_changed(self):
573606
return
574607
elif len(selected_indices) == 1:
575608
selected_segment = self.controller.spikes[selected_indices[0]]['segment_index']
576-
if selected_segment != self.segment_index:
609+
segment_index = self.controller.get_time()[1]
610+
if selected_segment != segment_index:
577611
self.segment_selector.value = f"Segment {selected_segment}"
578612
self._panel_change_segment(None)
579613
# update selected spikes

0 commit comments

Comments
 (0)