Commit d337001

Merge branch 'main' into report_without_waveforms
2 parents 2faed13 + c4d2eaa commit d337001

25 files changed: +853 -349 lines

doc/modules/core.rst

Lines changed: 1 addition & 1 deletion
@@ -385,7 +385,7 @@ and merging unit groups.
 
     sorting_analyzer_select = sorting_analyzer.select_units(unit_ids=[0, 1, 2, 3])
     sorting_analyzer_remove = sorting_analyzer.remove_units(remove_unit_ids=[0])
-    sorting_analyzer_merge = sorting_analyzer.merge_units([0, 1], [2, 3])
+    sorting_analyzer_merge = sorting_analyzer.merge_units([[0, 1], [2, 3]])
 
 All computed extensions will be automatically propagated or merged when curating. Please refer to the
 :ref:`modules/curation:Curation module` documentation for more information.
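
For context, a small hedged sketch of the updated call documented above (unit ids 0-3 and a previously computed "templates" extension are assumed for illustration):

# Each inner list is a group of unit ids merged into a single unit.
sorting_analyzer_merge = sorting_analyzer.merge_units([[0, 1], [2, 3]])

# Extensions computed on the original analyzer are propagated/merged automatically,
# so they remain available on the merged analyzer.
templates_ext = sorting_analyzer_merge.get_extension("templates")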

doc/modules/curation.rst

Lines changed: 8 additions & 3 deletions
@@ -88,7 +88,7 @@ The ``censored_period_ms`` parameter is the time window in milliseconds to consi
 The :py:func:`~spikeinterface.curation.remove_redundand_units` function removes
 redundant units from the sorting output. Redundant units are units that share over
 a certain percentage of spikes, by default 80%.
-The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object.
+The function can act both on a ``BaseSorting`` or a ``SortingAnalyzer`` object.
 
 .. code-block:: python
 
@@ -102,13 +102,18 @@ The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object.
     )
 
     # remove redundant units from SortingAnalyzer object
-    clean_sorting_analyzer = remove_redundant_units(
+    # note this returns a cleaned sorting
+    clean_sorting = remove_redundant_units(
         sorting_analyzer,
         duplicate_threshold=0.9,
         remove_strategy="min_shift"
     )
+    # in order to have a SortingAnalyzer with only the non-redundant units one must
+    # select the desired units, remembering to give format and folder if one wants
+    # a persistent SortingAnalyzer.
+    clean_sorting_analyzer = sorting_analyzer.select_units(clean_sorting.unit_ids)
 
-We recommend usinf the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps
+We recommend using the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps
 the unit (among the redundant ones), with a better template alignment.
 
 
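A hedged sketch of the persistence note added above: `select_units` can be given a format and folder so the cleaned analyzer is written to disk rather than kept in memory ("binary_folder" and the folder name below are only example values):

# Keep only the non-redundant units and persist the result to disk
# (format/folder arguments as referenced in the comment above; values are examples).
clean_sorting_analyzer = sorting_analyzer.select_units(
    clean_sorting.unit_ids,
    format="binary_folder",
    folder="analyzer_clean",
)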
src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 12 additions & 2 deletions
@@ -691,12 +691,13 @@ class ComputeNoiseLevels(AnalyzerExtension):
     need_recording = True
     use_nodepipeline = False
     need_job_kwargs = False
+    need_backward_compatibility_on_load = True
 
     def __init__(self, sorting_analyzer):
         AnalyzerExtension.__init__(self, sorting_analyzer)
 
-    def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None):
-        params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, seed=seed)
+    def _set_params(self, **noise_level_params):
+        params = noise_level_params.copy()
         return params
 
     def _select_extension_data(self, unit_ids):
@@ -717,6 +718,15 @@ def _run(self, verbose=False):
     def _get_data(self):
         return self.data["noise_levels"]
 
+    def _handle_backward_compatibility_on_load(self):
+        # The old parameters used to be params=dict(num_chunks_per_segment=20, chunk_size=10000, seed=None)
+        # now it is handled more explicitly using random_slices_kwargs=dict()
+        for key in ("num_chunks_per_segment", "chunk_size", "seed"):
+            if key in self.params:
+                if "random_slices_kwargs" not in self.params:
+                    self.params["random_slices_kwargs"] = dict()
+                self.params["random_slices_kwargs"][key] = self.params.pop(key)
+
 
 register_result_extension(ComputeNoiseLevels)
 compute_noise_levels = ComputeNoiseLevels.function_factory()
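
A standalone illustration of the compatibility shim added above: on load, old-style noise-level parameters are moved under a "random_slices_kwargs" sub-dict (the dict below is just an example of the legacy layout):

# Legacy params layout (example values, matching the old _set_params defaults)
params = {"num_chunks_per_segment": 20, "chunk_size": 10000, "seed": None}

# Same rewrite as _handle_backward_compatibility_on_load above
for key in ("num_chunks_per_segment", "chunk_size", "seed"):
    if key in params:
        if "random_slices_kwargs" not in params:
            params["random_slices_kwargs"] = dict()
        params["random_slices_kwargs"][key] = params.pop(key)

# params == {"random_slices_kwargs": {"num_chunks_per_segment": 20, "chunk_size": 10000, "seed": None}}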

src/spikeinterface/core/baserecordingsnippets.py

Lines changed: 4 additions & 2 deletions
@@ -172,8 +172,10 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False
         number_of_device_channel_indices = np.max(list(device_channel_indices) + [0])
         if number_of_device_channel_indices >= self.get_num_channels():
             error_msg = (
-                f"The given Probe have 'device_channel_indices' that do not match channel count \n"
-                f"{number_of_device_channel_indices} vs {self.get_num_channels()} \n"
+                f"The given Probe either has 'device_channel_indices' that does not match channel count \n"
+                f"{len(device_channel_indices)} vs {self.get_num_channels()} \n"
+                f"or its max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n"
+                f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n"
                 f"device_channel_indices are the following: {device_channel_indices} \n"
                 f"recording channels are the following: {self.get_channel_ids()} \n"
             )
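
To illustrate the 0-indexing reminder in the new error message, a minimal sketch using probeinterface (the generator helpers and the 16-channel count are assumed example values):

import numpy as np
from probeinterface import generate_linear_probe
from spikeinterface.core import generate_recording

# For a recording with N channels, device_channel_indices must stay in 0 .. N - 1.
recording = generate_recording(num_channels=16, durations=[1.0])
probe = generate_linear_probe(num_elec=16)
probe.set_device_channel_indices(np.arange(16))  # valid: max index is 15, not 16
recording = recording.set_probe(probe)           # passes the check in _set_probes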

src/spikeinterface/core/job_tools.py

Lines changed: 29 additions & 22 deletions
@@ -149,12 +149,12 @@ def divide_segment_into_chunks(num_frames, chunk_size):
 
 
 def divide_recording_into_chunks(recording, chunk_size):
-    all_chunks = []
+    recording_slices = []
     for segment_index in range(recording.get_num_segments()):
         num_frames = recording.get_num_samples(segment_index)
         chunks = divide_segment_into_chunks(num_frames, chunk_size)
-        all_chunks.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks])
-    return all_chunks
+        recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks])
+    return recording_slices
 
 
 def ensure_n_jobs(recording, n_jobs=1):
@@ -185,6 +185,22 @@ def ensure_n_jobs(recording, n_jobs=1):
     return n_jobs
 
 
+def chunk_duration_to_chunk_size(chunk_duration, recording):
+    if isinstance(chunk_duration, float):
+        chunk_size = int(chunk_duration * recording.get_sampling_frequency())
+    elif isinstance(chunk_duration, str):
+        if chunk_duration.endswith("ms"):
+            chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
+        elif chunk_duration.endswith("s"):
+            chunk_duration = float(chunk_duration.replace("s", ""))
+        else:
+            raise ValueError("chunk_duration must ends with s or ms")
+        chunk_size = int(chunk_duration * recording.get_sampling_frequency())
+    else:
+        raise ValueError("chunk_duration must be str or float")
+    return chunk_size
+
+
 def ensure_chunk_size(
     recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs
 ):
@@ -231,18 +247,7 @@
         num_channels = recording.get_num_channels()
         chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs))
     elif chunk_duration is not None:
-        if isinstance(chunk_duration, float):
-            chunk_size = int(chunk_duration * recording.get_sampling_frequency())
-        elif isinstance(chunk_duration, str):
-            if chunk_duration.endswith("ms"):
-                chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
-            elif chunk_duration.endswith("s"):
-                chunk_duration = float(chunk_duration.replace("s", ""))
-            else:
-                raise ValueError("chunk_duration must ends with s or ms")
-            chunk_size = int(chunk_duration * recording.get_sampling_frequency())
-        else:
-            raise ValueError("chunk_duration must be str or float")
+        chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording)
     else:
         # Edge case to define single chunk per segment for n_jobs=1.
         # All chunking parameters equal None mean single chunk per segment
@@ -382,11 +387,13 @@ def __init__(
             f"chunk_duration={chunk_duration_str}",
         )
 
-    def run(self):
+    def run(self, recording_slices=None):
         """
         Runs the defined jobs.
         """
-        all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)
+
+        if recording_slices is None:
+            recording_slices = divide_recording_into_chunks(self.recording, self.chunk_size)
 
         if self.handle_returns:
             returns = []
@@ -395,17 +402,17 @@
 
         if self.n_jobs == 1:
             if self.progress_bar:
-                all_chunks = tqdm(all_chunks, ascii=True, desc=self.job_name)
+                recording_slices = tqdm(recording_slices, ascii=True, desc=self.job_name)
 
             worker_ctx = self.init_func(*self.init_args)
-            for segment_index, frame_start, frame_stop in all_chunks:
+            for segment_index, frame_start, frame_stop in recording_slices:
                 res = self.func(segment_index, frame_start, frame_stop, worker_ctx)
                 if self.handle_returns:
                     returns.append(res)
                 if self.gather_func is not None:
                     self.gather_func(res)
         else:
-            n_jobs = min(self.n_jobs, len(all_chunks))
+            n_jobs = min(self.n_jobs, len(recording_slices))
 
             # parallel
             with ProcessPoolExecutor(
@@ -414,10 +421,10 @@
                 mp_context=mp.get_context(self.mp_context),
                 initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process),
             ) as executor:
-                results = executor.map(function_wrapper, all_chunks)
+                results = executor.map(function_wrapper, recording_slices)
 
                 if self.progress_bar:
-                    results = tqdm(results, desc=self.job_name, total=len(all_chunks))
+                    results = tqdm(results, desc=self.job_name, total=len(recording_slices))
 
                 for res in results:
                     if self.handle_returns:
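
A small usage sketch of the `chunk_duration_to_chunk_size` helper factored out above (the import path follows the file location; the generated recording and its 30 kHz rate are example values):

from spikeinterface.core import generate_recording
from spikeinterface.core.job_tools import chunk_duration_to_chunk_size

# The helper accepts a float (seconds) or a string ending in "s" / "ms"
# and converts it to a chunk size in samples.
recording = generate_recording(num_channels=4, sampling_frequency=30000.0, durations=[2.0])

chunk_duration_to_chunk_size(1.0, recording)      # 30000 samples
chunk_duration_to_chunk_size("500ms", recording)  # 15000 samples
chunk_duration_to_chunk_size("1s", recording)     # 30000 samples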

src/spikeinterface/core/node_pipeline.py

Lines changed: 6 additions & 1 deletion
@@ -489,6 +489,7 @@ def run_node_pipeline(
     names=None,
     verbose=False,
     skip_after_n_peaks=None,
+    recording_slices=None,
 ):
     """
     Machinery to compute in parallel operations on peaks and traces.
@@ -540,6 +541,10 @@
     skip_after_n_peaks : None | int
         Skip the computation after n_peaks.
         This is not an exact because internally this skip is done per worker in average.
+    recording_slices : None | list[tuple]
+        Optionally give a list of slices to run the pipeline only on some chunks of the recording.
+        It must be a list of (segment_index, frame_start, frame_stop).
+        If None (default), the function iterates over the entire duration of the recording.
 
     Returns
     -------
@@ -578,7 +583,7 @@
         **job_kwargs,
    )
 
-    processor.run()
+    processor.run(recording_slices=recording_slices)
 
     outs = gather_func.finalize_buffers(squeeze_output=squeeze_output)
    return outs
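
A hedged sketch of the new `recording_slices` argument to `run_node_pipeline` (the `recording`, `nodes` and `job_kwargs` set-up is assumed to exist as in the usual pipeline usage):

from spikeinterface.core.job_tools import divide_recording_into_chunks
from spikeinterface.core.node_pipeline import run_node_pipeline

# Build (segment_index, frame_start, frame_stop) tuples and keep only a subset,
# so the pipeline runs on a few chunks instead of the whole recording.
recording_slices = divide_recording_into_chunks(recording, chunk_size=30000)
outs = run_node_pipeline(
    recording,
    nodes,
    job_kwargs,
    job_name="peak detection",
    recording_slices=recording_slices[:10],  # None (default) = entire recording
)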
