Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions src/spikeinterface/preprocessing/silence_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from spikeinterface.core.core_tools import define_function_handling_dict_from_class
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment

from spikeinterface.core import get_random_data_chunks, get_noise_levels
from spikeinterface.core import get_noise_levels
from spikeinterface.core.generate import NoiseGeneratorRecording
from spikeinterface.core.job_tools import split_job_kwargs


class SilencedPeriodsRecording(BasePreprocessor):
Expand Down Expand Up @@ -36,6 +37,12 @@ class SilencedPeriodsRecording(BasePreprocessor):
- "noise": The periods are filled with a gaussion noise that has the
same variance that the one in the recordings, on a per channel
basis
job_kwargs : dict
Keyword arguments for the joblib parallelization. If you want to use
`job_kwargs`, you need to pass them as a dictionary with the key "job_kwargs".
For example, `job_kwargs={"num_workers": 4}`.
Note that this is not used for the `get_noise_levels` function, which has its own
`random_slices_kwargs` argument.
**random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function

Returns
Expand All @@ -44,7 +51,16 @@ class SilencedPeriodsRecording(BasePreprocessor):
The recording extractor after silencing some periods
"""

def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, seed=None, **random_chunk_kwargs):
def __init__(
self,
recording,
list_periods,
mode="zeros",
noise_levels=None,
seed=None,
job_kwargs=dict(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
job_kwargs=dict(),
**job_kwargs,

Should it be like this?

And then load the _shared_job_kwargs doc?

I'll test this PR now!

**random_chunk_kwargs,
):
available_modes = ("zeros", "noise")
num_seg = recording.get_num_segments()

Expand Down Expand Up @@ -74,7 +90,7 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see
random_slices_kwargs = random_chunk_kwargs.copy()
random_slices_kwargs["seed"] = seed
noise_levels = get_noise_levels(
recording, return_scaled=False, random_slices_kwargs=random_slices_kwargs
recording, return_scaled=False, random_slices_kwargs=random_slices_kwargs, **job_kwargs
)
noise_generator = NoiseGeneratorRecording(
num_channels=recording.get_num_channels(),
Expand All @@ -97,7 +113,9 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see
rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index)
self.add_recording_segment(rec_segment)

self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed)
self._kwargs = dict(
recording=recording, list_periods=list_periods, mode=mode, seed=seed, noise_levels=noise_levels
)
self._kwargs.update(random_chunk_kwargs)


Expand Down
Loading