Skip to content

Commit 7379a96

Browse files
authored
Merge branch 'main' into error-get-data
2 parents 3406f85 + 0df1160 commit 7379a96

File tree

9 files changed

+295
-107
lines changed

9 files changed

+295
-107
lines changed

src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -710,12 +710,13 @@ class ComputeNoiseLevels(AnalyzerExtension):
710710
need_recording = True
711711
use_nodepipeline = False
712712
need_job_kwargs = False
713+
need_backward_compatibility_on_load = True
713714

714715
def __init__(self, sorting_analyzer):
715716
AnalyzerExtension.__init__(self, sorting_analyzer)
716717

717-
def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None):
718-
params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, seed=seed)
718+
def _set_params(self, **noise_level_params):
719+
params = noise_level_params.copy()
719720
return params
720721

721722
def _select_extension_data(self, unit_ids):
@@ -736,6 +737,15 @@ def _run(self, verbose=False):
736737
def _get_data(self):
737738
return self.data["noise_levels"]
738739

740+
def _handle_backward_compatibility_on_load(self):
741+
# The old parameters used to be params=dict(num_chunks_per_segment=20, chunk_size=10000, seed=None)
742+
# now it is handle more explicitly using random_slices_kwargs=dict()
743+
for key in ("num_chunks_per_segment", "chunk_size", "seed"):
744+
if key in self.params:
745+
if "random_slices_kwargs" not in self.params:
746+
self.params["random_slices_kwargs"] = dict()
747+
self.params["random_slices_kwargs"][key] = self.params.pop(key)
748+
739749

740750
register_result_extension(ComputeNoiseLevels)
741751
compute_noise_levels = ComputeNoiseLevels.function_factory()

src/spikeinterface/core/job_tools.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,22 @@ def ensure_n_jobs(recording, n_jobs=1):
185185
return n_jobs
186186

187187

188+
def chunk_duration_to_chunk_size(chunk_duration, recording):
189+
if isinstance(chunk_duration, float):
190+
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
191+
elif isinstance(chunk_duration, str):
192+
if chunk_duration.endswith("ms"):
193+
chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
194+
elif chunk_duration.endswith("s"):
195+
chunk_duration = float(chunk_duration.replace("s", ""))
196+
else:
197+
raise ValueError("chunk_duration must ends with s or ms")
198+
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
199+
else:
200+
raise ValueError("chunk_duration must be str or float")
201+
return chunk_size
202+
203+
188204
def ensure_chunk_size(
189205
recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs
190206
):
@@ -231,18 +247,7 @@ def ensure_chunk_size(
231247
num_channels = recording.get_num_channels()
232248
chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs))
233249
elif chunk_duration is not None:
234-
if isinstance(chunk_duration, float):
235-
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
236-
elif isinstance(chunk_duration, str):
237-
if chunk_duration.endswith("ms"):
238-
chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
239-
elif chunk_duration.endswith("s"):
240-
chunk_duration = float(chunk_duration.replace("s", ""))
241-
else:
242-
raise ValueError("chunk_duration must ends with s or ms")
243-
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
244-
else:
245-
raise ValueError("chunk_duration must be str or float")
250+
chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording)
246251
else:
247252
# Edge case to define single chunk per segment for n_jobs=1.
248253
# All chunking parameters equal None mean single chunk per segment
@@ -382,11 +387,13 @@ def __init__(
382387
f"chunk_duration={chunk_duration_str}",
383388
)
384389

385-
def run(self):
390+
def run(self, all_chunks=None):
386391
"""
387392
Runs the defined jobs.
388393
"""
389-
all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)
394+
395+
if all_chunks is None:
396+
all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)
390397

391398
if self.handle_returns:
392399
returns = []

0 commit comments

Comments
 (0)