|
8 | 8 | import numpy as np |
9 | 9 |
|
10 | 10 |
|
11 | | -@pytest.mark.parametrize("N_segments", [1, 2]) |
12 | | -@pytest.mark.parametrize("decimation_offset", [0, 1, 9, 10, 11, 100, 101]) |
13 | | -@pytest.mark.parametrize("decimation_factor", [1, 9, 10, 11, 100, 101]) |
14 | | -@pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000]) |
15 | | -@pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000]) |
16 | | -def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, end_frame): |
17 | | - rec = generate_recording() |
18 | | - |
19 | | - segment_num_samps = [101 + i for i in range(N_segments)] |
20 | | - |
| 11 | +@pytest.mark.parametrize("num_segments", [1, 2]) |
| 12 | +@pytest.mark.parametrize("decimation_offset", [0, 1, 5, 21, 101]) |
| 13 | +@pytest.mark.parametrize("decimation_factor", [1, 7, 50]) |
| 14 | +def test_decimate(num_segments, decimation_offset, decimation_factor): |
| 15 | + segment_num_samps = [20000, 40000] |
21 | 16 | rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1) |
22 | 17 |
|
23 | | - parent_traces = [rec.get_traces(i) for i in range(N_segments)] |
| 18 | + parent_traces = [rec.get_traces(i) for i in range(num_segments)] |
24 | 19 |
|
25 | 20 | if decimation_offset >= min(segment_num_samps) or decimation_offset >= decimation_factor: |
26 | 21 | with pytest.raises(ValueError): |
27 | 22 | decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) |
28 | 23 | return |
29 | 24 |
|
30 | 25 | decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) |
31 | | - decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(N_segments)] |
| 26 | + decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)] |
32 | 27 |
|
33 | | - if start_frame is None: |
34 | | - start_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) |
35 | | - if end_frame is None: |
36 | | - end_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) |
| 28 | + for start_frame in [0, 1, 5, None, 1000]: |
| 29 | + for end_frame in [0, 1, 5, None, 1000]: |
| 30 | + if start_frame is None: |
| 31 | + start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) |
| 32 | + if end_frame is None: |
| 33 | + end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) |
37 | 34 |
|
38 | | - for i in range(N_segments): |
| 35 | + for i in range(num_segments): |
| 36 | + assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] |
| 37 | + assert np.all( |
| 38 | + decimated_rec.get_traces(i, start_frame, end_frame) |
| 39 | + == decimated_parent_traces[i][start_frame:end_frame] |
| 40 | + ) |
| 41 | + |
| 42 | + for i in range(num_segments): |
39 | 43 | assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] |
40 | 44 | assert np.all( |
41 | 45 | decimated_rec.get_traces(i, start_frame, end_frame) == decimated_parent_traces[i][start_frame:end_frame] |
42 | 46 | ) |
43 | 47 |
|
44 | 48 |
|
| 49 | +def test_decimate_with_times(): |
| 50 | + rec = generate_recording(durations=[5, 10]) |
| 51 | + |
| 52 | + # test with times |
| 53 | + times = [rec.get_times(0) + 10, rec.get_times(1) + 20] |
| 54 | + for i, t in enumerate(times): |
| 55 | + rec.set_times(t, i) |
| 56 | + |
| 57 | + decimation_factor = 2 |
| 58 | + decimation_offset = 1 |
| 59 | + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) |
| 60 | + |
| 61 | + for segment_index in range(rec.get_num_segments()): |
| 62 | + assert np.allclose( |
| 63 | + decimated_rec.get_times(segment_index), |
| 64 | + rec.get_times(segment_index)[decimation_offset::decimation_factor], |
| 65 | + ) |
| 66 | + |
| 67 | + # test with t_start |
| 68 | + rec = generate_recording(durations=[5, 10]) |
| 69 | + t_starts = [10, 20] |
| 70 | + for t_start, rec_segment in zip(t_starts, rec._recording_segments): |
| 71 | + rec_segment.t_start = t_start |
| 72 | + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) |
| 73 | + for segment_index in range(rec.get_num_segments()): |
| 74 | + assert np.allclose( |
| 75 | + decimated_rec.get_times(segment_index), |
| 76 | + rec.get_times(segment_index)[decimation_offset::decimation_factor], |
| 77 | + ) |
| 78 | + |
| 79 | + |
45 | 80 | if __name__ == "__main__": |
46 | 81 | test_decimate() |
0 commit comments