
Commit c4d2eaa

Merge pull request #3506 from samuelgarcia/nodepipeline_chunks
recording_slices in run_node_pipeline()
2 parents d78a0da + 0e44185 commit c4d2eaa

File tree: 5 files changed, +38 -18 lines

src/spikeinterface/core/job_tools.py

Lines changed: 11 additions & 11 deletions
@@ -149,12 +149,12 @@ def divide_segment_into_chunks(num_frames, chunk_size):


 def divide_recording_into_chunks(recording, chunk_size):
-    all_chunks = []
+    recording_slices = []
     for segment_index in range(recording.get_num_segments()):
         num_frames = recording.get_num_samples(segment_index)
         chunks = divide_segment_into_chunks(num_frames, chunk_size)
-        all_chunks.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks])
-    return all_chunks
+        recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks])
+    return recording_slices


 def ensure_n_jobs(recording, n_jobs=1):
@@ -387,13 +387,13 @@ def __init__(
                 f"chunk_duration={chunk_duration_str}",
             )

-    def run(self, all_chunks=None):
+    def run(self, recording_slices=None):
         """
         Runs the defined jobs.
         """

-        if all_chunks is None:
-            all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)
+        if recording_slices is None:
+            recording_slices = divide_recording_into_chunks(self.recording, self.chunk_size)

         if self.handle_returns:
             returns = []
@@ -402,17 +402,17 @@ def run(self, all_chunks=None):

         if self.n_jobs == 1:
             if self.progress_bar:
-                all_chunks = tqdm(all_chunks, ascii=True, desc=self.job_name)
+                recording_slices = tqdm(recording_slices, ascii=True, desc=self.job_name)

             worker_ctx = self.init_func(*self.init_args)
-            for segment_index, frame_start, frame_stop in all_chunks:
+            for segment_index, frame_start, frame_stop in recording_slices:
                 res = self.func(segment_index, frame_start, frame_stop, worker_ctx)
                 if self.handle_returns:
                     returns.append(res)
                 if self.gather_func is not None:
                     self.gather_func(res)
         else:
-            n_jobs = min(self.n_jobs, len(all_chunks))
+            n_jobs = min(self.n_jobs, len(recording_slices))

             # parallel
             with ProcessPoolExecutor(
@@ -421,10 +421,10 @@ def run(self, all_chunks=None):
                 mp_context=mp.get_context(self.mp_context),
                 initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process),
             ) as executor:
-                results = executor.map(function_wrapper, all_chunks)
+                results = executor.map(function_wrapper, recording_slices)

                 if self.progress_bar:
-                    results = tqdm(results, desc=self.job_name, total=len(all_chunks))
+                    results = tqdm(results, desc=self.job_name, total=len(recording_slices))

                 for res in results:
                     if self.handle_returns:
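For reference, divide_recording_into_chunks() returns a flat list of (segment_index, frame_start, frame_stop) tuples, which is exactly what run() now accepts under the renamed keyword. A minimal sketch of the pattern, not part of this commit: the toy worker functions and the chunk size are illustrative, while the func/init_func calling conventions follow the executor code shown above.

from spikeinterface.core import generate_recording
from spikeinterface.core.job_tools import ChunkRecordingExecutor, divide_recording_into_chunks


def init_worker(recording):
    # build the per-worker context once, as in the executor's init_func contract
    return dict(recording=recording)


def count_samples(segment_index, frame_start, frame_stop, worker_ctx):
    # toy chunk function: just report the chunk length
    return frame_stop - frame_start


recording = generate_recording(durations=[10.0])
executor = ChunkRecordingExecutor(
    recording, count_samples, init_worker, (recording,),
    handle_returns=True, n_jobs=1, chunk_size=10_000,
)

# run on every other chunk only
recording_slices = divide_recording_into_chunks(recording, 10_000)[::2]
chunk_lengths = executor.run(recording_slices=recording_slices)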

src/spikeinterface/core/node_pipeline.py

Lines changed: 6 additions & 1 deletion
@@ -489,6 +489,7 @@ def run_node_pipeline(
     names=None,
     verbose=False,
     skip_after_n_peaks=None,
+    recording_slices=None,
 ):
     """
     Machinery to compute in parallel operations on peaks and traces.
@@ -540,6 +541,10 @@ def run_node_pipeline(
     skip_after_n_peaks : None | int
         Skip the computation after n_peaks.
         This is not exact because internally this skip is done per worker on average.
+    recording_slices : None | list[tuple]
+        Optionally give a list of slices to run the pipeline only on some chunks of the recording.
+        It must be a list of (segment_index, frame_start, frame_stop).
+        If None (default), the function iterates over the entire duration of the recording.

     Returns
     -------
@@ -578,7 +583,7 @@ def run_node_pipeline(
         **job_kwargs,
     )

-    processor.run()
+    processor.run(recording_slices=recording_slices)

     outs = gather_func.finalize_buffers(squeeze_output=squeeze_output)
     return outs
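The new keyword threads straight through to the executor, so a whole node pipeline can be restricted to a subset of chunks. A hedged sketch of the call pattern, modeled on the updated test further below: the GrabAmplitudes node and the synthetic peak vector are illustrative, while PeakRetriever, PipelineNode and base_peak_dtype come from spikeinterface.core.node_pipeline.

import numpy as np

from spikeinterface.core import generate_recording
from spikeinterface.core.job_tools import divide_recording_into_chunks
from spikeinterface.core.node_pipeline import (
    PeakRetriever,
    PipelineNode,
    base_peak_dtype,
    run_node_pipeline,
)


class GrabAmplitudes(PipelineNode):
    # toy node: forward the stored amplitude of each peak in the chunk
    def __init__(self, recording, parents, return_output=True):
        PipelineNode.__init__(self, recording, parents=parents, return_output=return_output)

    def compute(self, traces, peaks):
        return peaks["amplitude"]


recording = generate_recording(num_channels=4, durations=[10.0])

# a sorted synthetic peak vector, only to feed the pipeline
rng = np.random.default_rng(seed=42)
peaks = np.zeros(200, dtype=base_peak_dtype)
peaks["sample_index"] = np.sort(rng.integers(0, recording.get_num_samples(0), 200))
peaks["amplitude"] = rng.normal(size=200)

node0 = PeakRetriever(recording, peaks)
node1 = GrabAmplitudes(recording, parents=[node0])

# keep one chunk out of four
recording_slices = divide_recording_into_chunks(recording, 10_000)[::4]
job_kwargs = dict(chunk_size=10_000, n_jobs=1, progress_bar=False)
amps = run_node_pipeline(
    recording, [node0, node1], job_kwargs, gather_mode="memory", recording_slices=recording_slices
)
# amps covers only the selected slices, so roughly a quarter of the 200 peaks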

src/spikeinterface/core/recording_tools.py

Lines changed: 1 addition & 1 deletion
@@ -806,7 +806,7 @@ def append_noise_chunk(res):
         gather_func=append_noise_chunk,
         **job_kwargs,
     )
-    executor.run(all_chunks=recording_slices)
+    executor.run(recording_slices=recording_slices)
     noise_levels_chunks = np.stack(noise_levels_chunks)
     noise_levels = np.mean(noise_levels_chunks, axis=0)

src/spikeinterface/core/tests/test_node_pipeline.py

Lines changed: 14 additions & 5 deletions
@@ -4,7 +4,7 @@
 import shutil

 from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording
-
+from spikeinterface.core.job_tools import divide_recording_into_chunks

 # from spikeinterface.sortingcomponents.peak_detection import detect_peaks
 from spikeinterface.core.node_pipeline import (
@@ -191,8 +191,8 @@ def test_run_node_pipeline(cache_folder_creation):
     unpickled_node = pickle.loads(pickled_node)


-def test_skip_after_n_peaks():
-    recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])
+def test_skip_after_n_peaks_and_recording_slices():
+    recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0], seed=2205)

     # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False)
     job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False)
@@ -211,18 +211,27 @@ def test_skip_after_n_peaks():
     node1 = AmplitudeExtractionNode(recording, parents=[node0], param0=6.6, return_output=True)
     nodes = [node0, node1]

+    # skip
     skip_after_n_peaks = 30
     some_amplitudes = run_node_pipeline(
         recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks
     )
-
     assert some_amplitudes.size >= skip_after_n_peaks
     assert some_amplitudes.size < spikes.size

+    # slices : 1 every 4
+    recording_slices = divide_recording_into_chunks(recording, 10_000)
+    recording_slices = recording_slices[::4]
+    some_amplitudes = run_node_pipeline(
+        recording, nodes, job_kwargs, gather_mode="memory", recording_slices=recording_slices
+    )
+    tolerance = 1.2
+    assert some_amplitudes.size < (spikes.size // 4) * tolerance
+

 # the following is for testing locally with python or ipython. It is not used in ci or with pytest.
 if __name__ == "__main__":
     # folder = Path("./cache_folder/core")
     # test_run_node_pipeline(folder)

-    test_skip_after_n_peaks()
+    test_skip_after_n_peaks_and_recording_slices()

src/spikeinterface/sortingcomponents/peak_detection.py

Lines changed: 6 additions & 0 deletions
@@ -57,6 +57,7 @@ def detect_peaks(
     folder=None,
     names=None,
     skip_after_n_peaks=None,
+    recording_slices=None,
     **kwargs,
 ):
     """Peak detection based on threshold crossing in terms of k x MAD.
@@ -83,6 +84,10 @@ def detect_peaks(
     skip_after_n_peaks : None | int
         Skip the computation after n_peaks.
         This is not exact because internally this skip is done per worker on average.
+    recording_slices : None | list[tuple]
+        Optionally give a list of slices to run the pipeline only on some chunks of the recording.
+        It must be a list of (segment_index, frame_start, frame_stop).
+        If None (default), the function iterates over the entire duration of the recording.

     {method_doc}
     {job_doc}
@@ -135,6 +140,7 @@ def detect_peaks(
         folder=folder,
         names=names,
         skip_after_n_peaks=skip_after_n_peaks,
+        recording_slices=recording_slices,
     )
     return outs
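End to end, the forwarded parameter lets peak detection run on only part of a recording. A hedged sketch under these assumptions: "by_channel" and detect_threshold are pre-existing detect_peaks options, and the recording, chunk size, and slice selection are purely illustrative.

from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.core.job_tools import divide_recording_into_chunks
from spikeinterface.sortingcomponents.peak_detection import detect_peaks

recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0])

# detect only on the first half of the chunks
recording_slices = divide_recording_into_chunks(recording, 30_000)
recording_slices = recording_slices[: len(recording_slices) // 2]

peaks = detect_peaks(
    recording,
    method="by_channel",
    detect_threshold=5,
    recording_slices=recording_slices,
    chunk_size=30_000,
    n_jobs=1,
    progress_bar=False,
)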
