Skip to content

Commit 2ba37a8

Browse files
committed
Don't let decimate mess with times and skim tests
1 parent e525d85 commit 2ba37a8

File tree

2 files changed

+23
-23
lines changed

2 files changed

+23
-23
lines changed

src/spikeinterface/preprocessing/decimate.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,15 @@ def __init__(
6363
f"Consider combining DecimateRecording with FrameSliceRecording for fine control on the recording start/end frames."
6464
)
6565
self._decimation_offset = decimation_offset
66-
resample_rate = self._orig_samp_freq / self._decimation_factor
66+
decimated_sampling_frequency = self._orig_samp_freq / self._decimation_factor
6767

68-
BasePreprocessor.__init__(self, recording, sampling_frequency=resample_rate)
68+
BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency)
6969

70-
# in case there was a time_vector, it will be dropped for sanity.
71-
# This is not necessary but consistent with ResampleRecording
7270
for parent_segment in recording._recording_segments:
73-
parent_segment.time_vector = None
7471
self.add_recording_segment(
7572
DecimateRecordingSegment(
7673
parent_segment,
77-
resample_rate,
74+
decimated_sampling_frequency,
7875
self._orig_samp_freq,
7976
decimation_factor,
8077
decimation_offset,
@@ -93,22 +90,25 @@ class DecimateRecordingSegment(BaseRecordingSegment):
9390
def __init__(
9491
self,
9592
parent_recording_segment,
96-
resample_rate,
93+
decimated_sampling_frequency,
9794
parent_rate,
9895
decimation_factor,
9996
decimation_offset,
10097
dtype,
10198
):
102-
if parent_recording_segment.t_start is None:
103-
new_t_start = None
99+
if parent_recording_segment.time_vector is not None:
100+
time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor]
101+
decimated_sampling_frequency = None
104102
else:
105-
new_t_start = parent_recording_segment.t_start + decimation_offset / parent_rate
103+
time_vector = None
104+
if parent_recording_segment.t_start is None:
105+
t_start = None
106+
else:
107+
t_start = parent_recording_segment.t_start + decimation_offset / parent_rate
106108

107109
# Do not use BasePreprocessorSegment bcause we have to reset the sampling rate!
108110
BaseRecordingSegment.__init__(
109-
self,
110-
sampling_frequency=resample_rate,
111-
t_start=new_t_start,
111+
self, sampling_frequency=decimated_sampling_frequency, t_start=t_start, time_vector=time_vector
112112
)
113113
self._parent_segment = parent_recording_segment
114114
self._decimation_factor = decimation_factor

src/spikeinterface/preprocessing/tests/test_decimate.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,34 @@
88
import numpy as np
99

1010

11-
@pytest.mark.parametrize("N_segments", [1, 2])
12-
@pytest.mark.parametrize("decimation_offset", [0, 1, 9, 10, 11, 100, 101])
13-
@pytest.mark.parametrize("decimation_factor", [1, 9, 10, 11, 100, 101])
11+
@pytest.mark.parametrize("num_segments", [1, 2])
12+
@pytest.mark.parametrize("decimation_offset", [0, 5, 21, 101])
13+
@pytest.mark.parametrize("decimation_factor", [1, 7, 50])
1414
@pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000])
1515
@pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000])
16-
def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, end_frame):
16+
def test_decimate(num_segments, decimation_offset, decimation_factor, start_frame, end_frame):
1717
rec = generate_recording()
1818

19-
segment_num_samps = [101 + i for i in range(N_segments)]
19+
segment_num_samps = [101 + i for i in range(num_segments)]
2020

2121
rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1)
2222

23-
parent_traces = [rec.get_traces(i) for i in range(N_segments)]
23+
parent_traces = [rec.get_traces(i) for i in range(num_segments)]
2424

2525
if decimation_offset >= min(segment_num_samps) or decimation_offset >= decimation_factor:
2626
with pytest.raises(ValueError):
2727
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
2828
return
2929

3030
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
31-
decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(N_segments)]
31+
decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)]
3232

3333
if start_frame is None:
34-
start_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments))
34+
start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))
3535
if end_frame is None:
36-
end_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments))
36+
end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))
3737

38-
for i in range(N_segments):
38+
for i in range(num_segments):
3939
assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0]
4040
assert np.all(
4141
decimated_rec.get_traces(i, start_frame, end_frame) == decimated_parent_traces[i][start_frame:end_frame]

0 commit comments

Comments
 (0)