@@ -185,6 +185,22 @@ def ensure_n_jobs(recording, n_jobs=1):
185185 return n_jobs
186186
187187
188+ def chunk_duration_to_chunk_size (chunk_duration , recording ):
189+ if isinstance (chunk_duration , float ):
190+ chunk_size = int (chunk_duration * recording .get_sampling_frequency ())
191+ elif isinstance (chunk_duration , str ):
192+ if chunk_duration .endswith ("ms" ):
193+ chunk_duration = float (chunk_duration .replace ("ms" , "" )) / 1000.0
194+ elif chunk_duration .endswith ("s" ):
195+ chunk_duration = float (chunk_duration .replace ("s" , "" ))
196+ else :
197+ raise ValueError ("chunk_duration must ends with s or ms" )
198+ chunk_size = int (chunk_duration * recording .get_sampling_frequency ())
199+ else :
200+ raise ValueError ("chunk_duration must be str or float" )
201+ return chunk_size
202+
203+
188204def ensure_chunk_size (
189205 recording , total_memory = None , chunk_size = None , chunk_memory = None , chunk_duration = None , n_jobs = 1 , ** other_kwargs
190206):
@@ -231,18 +247,7 @@ def ensure_chunk_size(
231247 num_channels = recording .get_num_channels ()
232248 chunk_size = int (total_memory / (num_channels * n_bytes * n_jobs ))
233249 elif chunk_duration is not None :
234- if isinstance (chunk_duration , float ):
235- chunk_size = int (chunk_duration * recording .get_sampling_frequency ())
236- elif isinstance (chunk_duration , str ):
237- if chunk_duration .endswith ("ms" ):
238- chunk_duration = float (chunk_duration .replace ("ms" , "" )) / 1000.0
239- elif chunk_duration .endswith ("s" ):
240- chunk_duration = float (chunk_duration .replace ("s" , "" ))
241- else :
242- raise ValueError ("chunk_duration must ends with s or ms" )
243- chunk_size = int (chunk_duration * recording .get_sampling_frequency ())
244- else :
245- raise ValueError ("chunk_duration must be str or float" )
250+ chunk_size = chunk_duration_to_chunk_size (chunk_duration , recording )
246251 else :
247252 # Edge case to define single chunk per segment for n_jobs=1.
248253 # All chunking parameters equal None mean single chunk per segment
@@ -382,11 +387,13 @@ def __init__(
382387 f"chunk_duration={ chunk_duration_str } " ,
383388 )
384389
385- def run (self ):
390+ def run (self , all_chunks = None ):
386391 """
387392 Runs the defined jobs.
388393 """
389- all_chunks = divide_recording_into_chunks (self .recording , self .chunk_size )
394+
395+ if all_chunks is None :
396+ all_chunks = divide_recording_into_chunks (self .recording , self .chunk_size )
390397
391398 if self .handle_returns :
392399 returns = []
0 commit comments