@@ -149,12 +149,12 @@ def divide_segment_into_chunks(num_frames, chunk_size):
149149
150150
def divide_recording_into_chunks(recording, chunk_size):
    """
    Build the flat list of slices covering every segment of a recording.

    Parameters
    ----------
    recording
        Object exposing ``get_num_segments()`` and ``get_num_samples(segment_index)``.
    chunk_size
        Chunk length, forwarded unchanged to ``divide_segment_into_chunks``.

    Returns
    -------
    recording_slices : list of tuple
        One ``(segment_index, frame_start, frame_stop)`` tuple per chunk,
        ordered by segment and then by the order in which
        ``divide_segment_into_chunks`` yields the chunks of that segment.
    """
    # Segments are visited in increasing index order; each one is chunked on
    # its own, so frame bounds are relative to that segment.
    return [
        (segment_index, frame_start, frame_stop)
        for segment_index in range(recording.get_num_segments())
        for frame_start, frame_stop in divide_segment_into_chunks(
            recording.get_num_samples(segment_index), chunk_size
        )
    ]
158158
159159
160160def ensure_n_jobs (recording , n_jobs = 1 ):
@@ -185,6 +185,22 @@ def ensure_n_jobs(recording, n_jobs=1):
185185 return n_jobs
186186
187187
def chunk_duration_to_chunk_size(chunk_duration, recording):
    """
    Convert a chunk duration into a chunk size expressed in samples.

    Parameters
    ----------
    chunk_duration : float, int or str
        Duration of one chunk. A number is interpreted as seconds. A string
        must be a number followed by a unit suffix: ``"s"`` for seconds or
        ``"ms"`` for milliseconds (e.g. ``"1s"``, ``"500ms"``).
    recording
        Object exposing ``get_sampling_frequency()``; the returned value is
        multiplied by the duration in seconds.

    Returns
    -------
    chunk_size : int
        ``int(duration_in_seconds * sampling_frequency)`` (truncated).

    Raises
    ------
    ValueError
        If ``chunk_duration`` is neither a number nor a string, if the string
        does not end with ``"s"`` or ``"ms"``, or if its numeric part cannot
        be parsed by ``float``.
    """
    if isinstance(chunk_duration, str):
        # "ms" has to be tested first because it also ends with "s".
        # Only the trailing unit is removed, so malformed inputs such as
        # "s1s" are rejected by float() instead of being silently accepted.
        if chunk_duration.endswith("ms"):
            duration_s = float(chunk_duration[:-2]) / 1000.0
        elif chunk_duration.endswith("s"):
            duration_s = float(chunk_duration[:-1])
        else:
            raise ValueError("chunk_duration must ends with s or ms")
    elif isinstance(chunk_duration, (int, float)) and not isinstance(chunk_duration, bool):
        # Plain integers (e.g. chunk_duration=1) are valid second counts;
        # bool is excluded because it is an int subclass but never a duration.
        duration_s = chunk_duration
    else:
        raise ValueError("chunk_duration must be str or float")

    return int(duration_s * recording.get_sampling_frequency())
202+
203+
188204def ensure_chunk_size (
189205 recording , total_memory = None , chunk_size = None , chunk_memory = None , chunk_duration = None , n_jobs = 1 , ** other_kwargs
190206):
@@ -231,18 +247,7 @@ def ensure_chunk_size(
231247 num_channels = recording .get_num_channels ()
232248 chunk_size = int (total_memory / (num_channels * n_bytes * n_jobs ))
233249 elif chunk_duration is not None :
234- if isinstance (chunk_duration , float ):
235- chunk_size = int (chunk_duration * recording .get_sampling_frequency ())
236- elif isinstance (chunk_duration , str ):
237- if chunk_duration .endswith ("ms" ):
238- chunk_duration = float (chunk_duration .replace ("ms" , "" )) / 1000.0
239- elif chunk_duration .endswith ("s" ):
240- chunk_duration = float (chunk_duration .replace ("s" , "" ))
241- else :
242- raise ValueError ("chunk_duration must ends with s or ms" )
243- chunk_size = int (chunk_duration * recording .get_sampling_frequency ())
244- else :
245- raise ValueError ("chunk_duration must be str or float" )
250+ chunk_size = chunk_duration_to_chunk_size (chunk_duration , recording )
246251 else :
247252 # Edge case to define single chunk per segment for n_jobs=1.
248253 # All chunking parameters equal None mean single chunk per segment
@@ -382,11 +387,13 @@ def __init__(
382387 f"chunk_duration={ chunk_duration_str } " ,
383388 )
384389
385- def run (self ):
390+ def run (self , recording_slices = None ):
386391 """
387392 Runs the defined jobs.
388393 """
389- all_chunks = divide_recording_into_chunks (self .recording , self .chunk_size )
394+
395+ if recording_slices is None :
396+ recording_slices = divide_recording_into_chunks (self .recording , self .chunk_size )
390397
391398 if self .handle_returns :
392399 returns = []
@@ -395,17 +402,17 @@ def run(self):
395402
396403 if self .n_jobs == 1 :
397404 if self .progress_bar :
398- all_chunks = tqdm (all_chunks , ascii = True , desc = self .job_name )
405+ recording_slices = tqdm (recording_slices , ascii = True , desc = self .job_name )
399406
400407 worker_ctx = self .init_func (* self .init_args )
401- for segment_index , frame_start , frame_stop in all_chunks :
408+ for segment_index , frame_start , frame_stop in recording_slices :
402409 res = self .func (segment_index , frame_start , frame_stop , worker_ctx )
403410 if self .handle_returns :
404411 returns .append (res )
405412 if self .gather_func is not None :
406413 self .gather_func (res )
407414 else :
408- n_jobs = min (self .n_jobs , len (all_chunks ))
415+ n_jobs = min (self .n_jobs , len (recording_slices ))
409416
410417 # parallel
411418 with ProcessPoolExecutor (
@@ -414,10 +421,10 @@ def run(self):
414421 mp_context = mp .get_context (self .mp_context ),
415422 initargs = (self .func , self .init_func , self .init_args , self .max_threads_per_process ),
416423 ) as executor :
417- results = executor .map (function_wrapper , all_chunks )
424+ results = executor .map (function_wrapper , recording_slices )
418425
419426 if self .progress_bar :
420- results = tqdm (results , desc = self .job_name , total = len (all_chunks ))
427+ results = tqdm (results , desc = self .job_name , total = len (recording_slices ))
421428
422429 for res in results :
423430 if self .handle_returns :
0 commit comments