@@ -197,17 +197,21 @@ def random_spikes_selection(
197197 cum_sizes = np .cumsum ([0 ] + [s .size for s in spikes ])
198198
199199 # this fast when numba
200- spike_indices = spike_vector_to_indices (spikes , sorting .unit_ids )
200+ spike_indices = spike_vector_to_indices (spikes , sorting .unit_ids , absolute_index = False )
201201
202202 random_spikes_indices = []
203203 for unit_index , unit_id in enumerate (sorting .unit_ids ):
204204 all_unit_indices = []
205205 for segment_index in range (sorting .get_num_segments ()):
206- inds_in_seg = spike_indices [segment_index ][unit_id ] + cum_sizes [segment_index ]
206+ # this is local index
207+ inds_in_seg = spike_indices [segment_index ][unit_id ]
207208 if margin_size is not None :
208- inds_in_seg = inds_in_seg [inds_in_seg >= margin_size ]
209- inds_in_seg = inds_in_seg [inds_in_seg < (num_samples [segment_index ] - margin_size )]
210- all_unit_indices .append (inds_in_seg )
209+ local_spikes = spikes [segment_index ][inds_in_seg ]
210+ mask = (local_spikes ["sample_index" ] >= margin_size ) & (local_spikes ["sample_index" ] < (num_samples [segment_index ] - margin_size ))
211+ inds_in_seg = inds_in_seg [mask ]
212+ # go back to absolut index
213+ inds_in_seg_abs = inds_in_seg + cum_sizes [segment_index ]
214+ all_unit_indices .append (inds_in_seg_abs )
211215 all_unit_indices = np .concatenate (all_unit_indices )
212216 selected_unit_indices = rng .choice (
213217 all_unit_indices , size = min (max_spikes_per_unit , all_unit_indices .size ), replace = False , shuffle = False
0 commit comments