Skip to content

Commit b1326d2

Browse files
committed
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into expose-zarr-compression
2 parents cffb2c9 + 10e90db commit b1326d2

File tree

5 files changed

+24
-15
lines changed

5 files changed

+24
-15
lines changed

src/spikeinterface/core/sorting_tools.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,17 +197,23 @@ def random_spikes_selection(
197197
cum_sizes = np.cumsum([0] + [s.size for s in spikes])
198198

199199
# this fast when numba
200-
spike_indices = spike_vector_to_indices(spikes, sorting.unit_ids)
200+
spike_indices = spike_vector_to_indices(spikes, sorting.unit_ids, absolute_index=False)
201201

202202
random_spikes_indices = []
203203
for unit_index, unit_id in enumerate(sorting.unit_ids):
204204
all_unit_indices = []
205205
for segment_index in range(sorting.get_num_segments()):
206-
inds_in_seg = spike_indices[segment_index][unit_id] + cum_sizes[segment_index]
206+
# this is local index
207+
inds_in_seg = spike_indices[segment_index][unit_id]
207208
if margin_size is not None:
208-
inds_in_seg = inds_in_seg[inds_in_seg >= margin_size]
209-
inds_in_seg = inds_in_seg[inds_in_seg < (num_samples[segment_index] - margin_size)]
210-
all_unit_indices.append(inds_in_seg)
209+
local_spikes = spikes[segment_index][inds_in_seg]
210+
mask = (local_spikes["sample_index"] >= margin_size) & (
211+
local_spikes["sample_index"] < (num_samples[segment_index] - margin_size)
212+
)
213+
inds_in_seg = inds_in_seg[mask]
214+
# go back to absolut index
215+
inds_in_seg_abs = inds_in_seg + cum_sizes[segment_index]
216+
all_unit_indices.append(inds_in_seg_abs)
211217
all_unit_indices = np.concatenate(all_unit_indices)
212218
selected_unit_indices = rng.choice(
213219
all_unit_indices, size=min(max_spikes_per_unit, all_unit_indices.size), replace=False, shuffle=False

src/spikeinterface/core/sortinganalyzer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2062,7 +2062,8 @@ def load_data(self):
20622062
continue
20632063
ext_data_name = ext_data_file.stem
20642064
if ext_data_file.suffix == ".json":
2065-
ext_data = json.load(ext_data_file.open("r"))
2065+
with ext_data_file.open("r") as f:
2066+
ext_data = json.load(f)
20662067
elif ext_data_file.suffix == ".npy":
20672068
# The lazy loading of an extension is complicated because if we compute again
20682069
# and have a link to the old buffer on windows then it fails
@@ -2074,7 +2075,8 @@ def load_data(self):
20742075

20752076
ext_data = pd.read_csv(ext_data_file, index_col=0)
20762077
elif ext_data_file.suffix == ".pkl":
2077-
ext_data = pickle.load(ext_data_file.open("rb"))
2078+
with ext_data_file.open("rb") as f:
2079+
ext_data = pickle.load(f)
20782080
else:
20792081
continue
20802082
self.data[ext_data_name] = ext_data

src/spikeinterface/core/tests/test_sorting_tools.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ def test_generate_unit_ids_for_merge_group():
162162
if __name__ == "__main__":
163163
# test_spike_vector_to_spike_trains()
164164
# test_spike_vector_to_indices()
165-
# test_random_spikes_selection()
165+
test_random_spikes_selection()
166166

167-
test_apply_merges_to_sorting()
168-
test_get_ids_after_merging()
167+
# test_apply_merges_to_sorting()
168+
# test_get_ids_after_merging()
169169
# test_generate_unit_ids_for_merge_group()

src/spikeinterface/qualitymetrics/quality_metric_calculator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ def _run(self, verbose=False, **job_kwargs):
234234
)
235235

236236
existing_metrics = []
237-
qm_extension = self.sorting_analyzer.get_extension("quality_metrics")
237+
# here we get in the loaded via the dict only (to avoid full loading from disk after params reset)
238+
qm_extension = self.sorting_analyzer.extensions.get("quality_metrics", None)
238239
if (
239240
delete_existing_metrics is False
240241
and qm_extension is not None

src/spikeinterface/sorters/utils/shellscript.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,15 @@ def start(self) -> None:
8686
if self._verbose:
8787
print("RUNNING SHELL SCRIPT: " + cmd)
8888
self._start_time = time.time()
89+
encoding = sys.getdefaultencoding()
8990
self._process = subprocess.Popen(
90-
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=1, universal_newlines=True
91+
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=1, universal_newlines=True, encoding=encoding
9192
)
9293
with open(script_log_path, "w+") as script_log_file:
9394
for line in self._process.stdout:
9495
script_log_file.write(line)
95-
if (
96-
self._verbose
97-
): # Print onto console depending on the verbose property passed on from the sorter class
96+
if self._verbose:
97+
# Print onto console depending on the verbose property passed on from the sorter class
9898
print(line)
9999

100100
def wait(self, timeout=None) -> Optional[int]:

0 commit comments

Comments
 (0)