Skip to content

Commit f900118

Browse files
committed
More skimming and test decimate with times
1 parent 2ba37a8 commit f900118

File tree

2 files changed

+47
-11
lines changed

2 files changed

+47
-11
lines changed

src/spikeinterface/preprocessing/decimate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __init__(
9999
if parent_recording_segment.time_vector is not None:
100100
time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor]
101101
decimated_sampling_frequency = None
102+
t_start = None
102103
else:
103104
time_vector = None
104105
if parent_recording_segment.t_start is None:

src/spikeinterface/preprocessing/tests/test_decimate.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,8 @@
1111
@pytest.mark.parametrize("num_segments", [1, 2])
1212
@pytest.mark.parametrize("decimation_offset", [0, 5, 21, 101])
1313
@pytest.mark.parametrize("decimation_factor", [1, 7, 50])
14-
@pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000])
15-
@pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000])
16-
def test_decimate(num_segments, decimation_offset, decimation_factor, start_frame, end_frame):
17-
rec = generate_recording()
18-
19-
segment_num_samps = [101 + i for i in range(num_segments)]
20-
14+
def test_decimate(num_segments, decimation_offset, decimation_factor):
15+
segment_num_samps = [20000, 40000]
2116
rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1)
2217

2318
parent_traces = [rec.get_traces(i) for i in range(num_segments)]
@@ -30,10 +25,19 @@ def test_decimate(num_segments, decimation_offset, decimation_factor, start_fram
3025
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
3126
decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)]
3227

33-
if start_frame is None:
34-
start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))
35-
if end_frame is None:
36-
end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))
28+
for start_frame in [0, 1, 5, None, 1000]:
29+
for end_frame in [0, 1, 5, None, 1000]:
30+
if start_frame is None:
31+
start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))
32+
if end_frame is None:
33+
end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))
34+
35+
for i in range(num_segments):
36+
assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0]
37+
assert np.all(
38+
decimated_rec.get_traces(i, start_frame, end_frame)
39+
== decimated_parent_traces[i][start_frame:end_frame]
40+
)
3741

3842
for i in range(num_segments):
3943
assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0]
@@ -42,5 +46,36 @@ def test_decimate(num_segments, decimation_offset, decimation_factor, start_fram
4246
)
4347

4448

49+
def test_decimate_with_times():
50+
rec = generate_recording(durations=[5, 10])
51+
52+
# test with times
53+
times = [rec.get_times(0) + 10, rec.get_times(1) + 20]
54+
for i, t in enumerate(times):
55+
rec.set_times(t, i)
56+
57+
decimation_factor = 2
58+
decimation_offset = 1
59+
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
60+
61+
for segment_index in range(rec.get_num_segments()):
62+
assert np.allclose(
63+
decimated_rec.get_times(segment_index),
64+
rec.get_times(segment_index)[decimation_offset::decimation_factor],
65+
)
66+
67+
# test with t_start
68+
rec = generate_recording(durations=[5, 10])
69+
t_starts = [10, 20]
70+
for t_start, rec_segment in zip(t_starts, rec._recording_segments):
71+
rec_segment.t_start = t_start
72+
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
73+
for segment_index in range(rec.get_num_segments()):
74+
assert np.allclose(
75+
decimated_rec.get_times(segment_index),
76+
rec.get_times(segment_index)[decimation_offset::decimation_factor],
77+
)
78+
79+
4580
if __name__ == "__main__":
4681
test_decimate()

0 commit comments

Comments
 (0)