Skip to content

Commit 3032112

Browse files
committed
fix random_spikes_selection()
1 parent a09dd57 commit 3032112

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

src/spikeinterface/core/sorting_tools.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,17 +197,21 @@ def random_spikes_selection(
197197
cum_sizes = np.cumsum([0] + [s.size for s in spikes])
198198

199199
# this fast when numba
200-
spike_indices = spike_vector_to_indices(spikes, sorting.unit_ids)
200+
spike_indices = spike_vector_to_indices(spikes, sorting.unit_ids, absolute_index=False)
201201

202202
random_spikes_indices = []
203203
for unit_index, unit_id in enumerate(sorting.unit_ids):
204204
all_unit_indices = []
205205
for segment_index in range(sorting.get_num_segments()):
206-
inds_in_seg = spike_indices[segment_index][unit_id] + cum_sizes[segment_index]
206+
# this is local index
207+
inds_in_seg = spike_indices[segment_index][unit_id]
207208
if margin_size is not None:
208-
inds_in_seg = inds_in_seg[inds_in_seg >= margin_size]
209-
inds_in_seg = inds_in_seg[inds_in_seg < (num_samples[segment_index] - margin_size)]
210-
all_unit_indices.append(inds_in_seg)
209+
local_spikes = spikes[segment_index][inds_in_seg]
210+
mask = (local_spikes["sample_index"] >= margin_size) & (local_spikes["sample_index"] < (num_samples[segment_index] - margin_size))
211+
inds_in_seg = inds_in_seg[mask]
212+
# go back to absolut index
213+
inds_in_seg_abs = inds_in_seg + cum_sizes[segment_index]
214+
all_unit_indices.append(inds_in_seg_abs)
211215
all_unit_indices = np.concatenate(all_unit_indices)
212216
selected_unit_indices = rng.choice(
213217
all_unit_indices, size=min(max_spikes_per_unit, all_unit_indices.size), replace=False, shuffle=False

src/spikeinterface/core/tests/test_sorting_tools.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ def test_generate_unit_ids_for_merge_group():
162162
if __name__ == "__main__":
163163
# test_spike_vector_to_spike_trains()
164164
# test_spike_vector_to_indices()
165-
# test_random_spikes_selection()
165+
test_random_spikes_selection()
166166

167-
test_apply_merges_to_sorting()
168-
test_get_ids_after_merging()
167+
# test_apply_merges_to_sorting()
168+
# test_get_ids_after_merging()
169169
# test_generate_unit_ids_for_merge_group()

0 commit comments

Comments
 (0)