Skip to content

Commit 8a7895e

Browse files
authored
Merge pull request #3519 from alejoe91/fix-decimate-times
Don't let decimate mess with times and skim tests
2 parents e525d85 + 2d843f8 commit 8a7895e

File tree

2 files changed

+66
-30
lines changed

2 files changed

+66
-30
lines changed

src/spikeinterface/preprocessing/decimate.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,15 @@ def __init__(
6363
f"Consider combining DecimateRecording with FrameSliceRecording for fine control on the recording start/end frames."
6464
)
6565
self._decimation_offset = decimation_offset
66-
resample_rate = self._orig_samp_freq / self._decimation_factor
66+
decimated_sampling_frequency = self._orig_samp_freq / self._decimation_factor
6767

68-
BasePreprocessor.__init__(self, recording, sampling_frequency=resample_rate)
68+
BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency)
6969

70-
# in case there was a time_vector, it will be dropped for sanity.
71-
# This is not necessary but consistent with ResampleRecording
7270
for parent_segment in recording._recording_segments:
73-
parent_segment.time_vector = None
7471
self.add_recording_segment(
7572
DecimateRecordingSegment(
7673
parent_segment,
77-
resample_rate,
74+
decimated_sampling_frequency,
7875
self._orig_samp_freq,
7976
decimation_factor,
8077
decimation_offset,
@@ -93,22 +90,26 @@ class DecimateRecordingSegment(BaseRecordingSegment):
9390
def __init__(
9491
self,
9592
parent_recording_segment,
96-
resample_rate,
93+
decimated_sampling_frequency,
9794
parent_rate,
9895
decimation_factor,
9996
decimation_offset,
10097
dtype,
10198
):
102-
if parent_recording_segment.t_start is None:
103-
new_t_start = None
99+
if parent_recording_segment.time_vector is not None:
100+
time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor]
101+
decimated_sampling_frequency = None
102+
t_start = None
104103
else:
105-
new_t_start = parent_recording_segment.t_start + decimation_offset / parent_rate
104+
time_vector = None
105+
if parent_recording_segment.t_start is None:
106+
t_start = None
107+
else:
108+
t_start = parent_recording_segment.t_start + (decimation_offset / parent_rate)
106109

107110
# Do not use BasePreprocessorSegment bcause we have to reset the sampling rate!
108111
BaseRecordingSegment.__init__(
109-
self,
110-
sampling_frequency=resample_rate,
111-
t_start=new_t_start,
112+
self, sampling_frequency=decimated_sampling_frequency, t_start=t_start, time_vector=time_vector
112113
)
113114
self._parent_segment = parent_recording_segment
114115
self._decimation_factor = decimation_factor

src/spikeinterface/preprocessing/tests/test_decimate.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,39 +8,74 @@
88
import numpy as np
99

1010

11-
@pytest.mark.parametrize("N_segments", [1, 2])
12-
@pytest.mark.parametrize("decimation_offset", [0, 1, 9, 10, 11, 100, 101])
13-
@pytest.mark.parametrize("decimation_factor", [1, 9, 10, 11, 100, 101])
14-
@pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000])
15-
@pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000])
16-
def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, end_frame):
17-
rec = generate_recording()
18-
19-
segment_num_samps = [101 + i for i in range(N_segments)]
20-
11+
@pytest.mark.parametrize("num_segments", [1, 2])
12+
@pytest.mark.parametrize("decimation_offset", [0, 1, 5, 21, 101])
13+
@pytest.mark.parametrize("decimation_factor", [1, 7, 50])
14+
def test_decimate(num_segments, decimation_offset, decimation_factor):
15+
segment_num_samps = [20000, 40000]
2116
rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1)
2217

23-
parent_traces = [rec.get_traces(i) for i in range(N_segments)]
18+
parent_traces = [rec.get_traces(i) for i in range(num_segments)]
2419

2520
if decimation_offset >= min(segment_num_samps) or decimation_offset >= decimation_factor:
2621
with pytest.raises(ValueError):
2722
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
2823
return
2924

3025
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
31-
decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(N_segments)]
26+
decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)]
3227

33-
if start_frame is None:
34-
start_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments))
35-
if end_frame is None:
36-
end_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments))
28+
for start_frame in [0, 1, 5, None, 1000]:
29+
for end_frame in [0, 1, 5, None, 1000]:
30+
if start_frame is None:
31+
start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))
32+
if end_frame is None:
33+
end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))
3734

38-
for i in range(N_segments):
35+
for i in range(num_segments):
36+
assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0]
37+
assert np.all(
38+
decimated_rec.get_traces(i, start_frame, end_frame)
39+
== decimated_parent_traces[i][start_frame:end_frame]
40+
)
41+
42+
for i in range(num_segments):
3943
assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0]
4044
assert np.all(
4145
decimated_rec.get_traces(i, start_frame, end_frame) == decimated_parent_traces[i][start_frame:end_frame]
4246
)
4347

4448

49+
def test_decimate_with_times():
50+
rec = generate_recording(durations=[5, 10])
51+
52+
# test with times
53+
times = [rec.get_times(0) + 10, rec.get_times(1) + 20]
54+
for i, t in enumerate(times):
55+
rec.set_times(t, i)
56+
57+
decimation_factor = 2
58+
decimation_offset = 1
59+
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
60+
61+
for segment_index in range(rec.get_num_segments()):
62+
assert np.allclose(
63+
decimated_rec.get_times(segment_index),
64+
rec.get_times(segment_index)[decimation_offset::decimation_factor],
65+
)
66+
67+
# test with t_start
68+
rec = generate_recording(durations=[5, 10])
69+
t_starts = [10, 20]
70+
for t_start, rec_segment in zip(t_starts, rec._recording_segments):
71+
rec_segment.t_start = t_start
72+
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
73+
for segment_index in range(rec.get_num_segments()):
74+
assert np.allclose(
75+
decimated_rec.get_times(segment_index),
76+
rec.get_times(segment_index)[decimation_offset::decimation_factor],
77+
)
78+
79+
4580
if __name__ == "__main__":
4681
test_decimate()

0 commit comments

Comments
 (0)