|
8 | 8 | import numpy as np |
9 | 9 |
|
10 | 10 |
|
11 | | -@pytest.mark.parametrize("N_segments", [1, 2]) |
12 | | -@pytest.mark.parametrize("decimation_offset", [0, 1, 9, 10, 11, 100, 101]) |
13 | | -@pytest.mark.parametrize("decimation_factor", [1, 9, 10, 11, 100, 101]) |
| 11 | +@pytest.mark.parametrize("num_segments", [1, 2]) |
| 12 | +@pytest.mark.parametrize("decimation_offset", [0, 5, 21, 101]) |
| 13 | +@pytest.mark.parametrize("decimation_factor", [1, 7, 50]) |
14 | 14 | @pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000]) |
15 | 15 | @pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000]) |
16 | | -def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, end_frame): |
| 16 | +def test_decimate(num_segments, decimation_offset, decimation_factor, start_frame, end_frame): |
17 | 17 | rec = generate_recording() |
18 | 18 |
|
19 | | - segment_num_samps = [101 + i for i in range(N_segments)] |
| 19 | + segment_num_samps = [101 + i for i in range(num_segments)] |
20 | 20 |
|
21 | 21 | rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1) |
22 | 22 |
|
23 | | - parent_traces = [rec.get_traces(i) for i in range(N_segments)] |
| 23 | + parent_traces = [rec.get_traces(i) for i in range(num_segments)] |
24 | 24 |
|
25 | 25 | if decimation_offset >= min(segment_num_samps) or decimation_offset >= decimation_factor: |
26 | 26 | with pytest.raises(ValueError): |
27 | 27 | decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) |
28 | 28 | return |
29 | 29 |
|
30 | 30 | decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) |
31 | | - decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(N_segments)] |
| 31 | + decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)] |
32 | 32 |
|
33 | 33 | if start_frame is None: |
34 | | - start_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) |
| 34 | + start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) |
35 | 35 | if end_frame is None: |
36 | | - end_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) |
| 36 | + end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) |
37 | 37 |
|
38 | | - for i in range(N_segments): |
| 38 | + for i in range(num_segments): |
39 | 39 | assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] |
40 | 40 | assert np.all( |
41 | 41 | decimated_rec.get_traces(i, start_frame, end_frame) == decimated_parent_traces[i][start_frame:end_frame] |
|
0 commit comments