1111@pytest .mark .parametrize ("num_segments" , [1 , 2 ])
1212@pytest .mark .parametrize ("decimation_offset" , [0 , 5 , 21 , 101 ])
1313@pytest .mark .parametrize ("decimation_factor" , [1 , 7 , 50 ])
14- @pytest .mark .parametrize ("start_frame" , [0 , 1 , 5 , None , 1000 ])
15- @pytest .mark .parametrize ("end_frame" , [0 , 1 , 5 , None , 1000 ])
16- def test_decimate (num_segments , decimation_offset , decimation_factor , start_frame , end_frame ):
17- rec = generate_recording ()
18-
19- segment_num_samps = [101 + i for i in range (num_segments )]
20-
14+ def test_decimate (num_segments , decimation_offset , decimation_factor ):
15+ segment_num_samps = [20000 , 40000 ]
2116 rec = NumpyRecording ([np .arange (2 * N ).reshape (N , 2 ) for N in segment_num_samps ], 1 )
2217
2318 parent_traces = [rec .get_traces (i ) for i in range (num_segments )]
@@ -30,10 +25,19 @@ def test_decimate(num_segments, decimation_offset, decimation_factor, start_fram
3025 decimated_rec = DecimateRecording (rec , decimation_factor , decimation_offset = decimation_offset )
3126 decimated_parent_traces = [parent_traces [i ][decimation_offset ::decimation_factor ] for i in range (num_segments )]
3227
33- if start_frame is None :
34- start_frame = max (decimated_rec .get_num_samples (i ) for i in range (num_segments ))
35- if end_frame is None :
36- end_frame = max (decimated_rec .get_num_samples (i ) for i in range (num_segments ))
28+ for start_frame in [0 , 1 , 5 , None , 1000 ]:
29+ for end_frame in [0 , 1 , 5 , None , 1000 ]:
30+ if start_frame is None :
31+ start_frame = max (decimated_rec .get_num_samples (i ) for i in range (num_segments ))
32+ if end_frame is None :
33+ end_frame = max (decimated_rec .get_num_samples (i ) for i in range (num_segments ))
34+
35+ for i in range (num_segments ):
36+ assert decimated_rec .get_num_samples (i ) == decimated_parent_traces [i ].shape [0 ]
37+ assert np .all (
38+ decimated_rec .get_traces (i , start_frame , end_frame )
39+ == decimated_parent_traces [i ][start_frame :end_frame ]
40+ )
3741
3842 for i in range (num_segments ):
3943 assert decimated_rec .get_num_samples (i ) == decimated_parent_traces [i ].shape [0 ]
@@ -42,5 +46,36 @@ def test_decimate(num_segments, decimation_offset, decimation_factor, start_fram
4246 )
4347
4448
49+ def test_decimate_with_times ():
50+ rec = generate_recording (durations = [5 , 10 ])
51+
52+ # test with times
53+ times = [rec .get_times (0 ) + 10 , rec .get_times (1 ) + 20 ]
54+ for i , t in enumerate (times ):
55+ rec .set_times (t , i )
56+
57+ decimation_factor = 2
58+ decimation_offset = 1
59+ decimated_rec = DecimateRecording (rec , decimation_factor , decimation_offset = decimation_offset )
60+
61+ for segment_index in range (rec .get_num_segments ()):
62+ assert np .allclose (
63+ decimated_rec .get_times (segment_index ),
64+ rec .get_times (segment_index )[decimation_offset ::decimation_factor ],
65+ )
66+
67+ # test with t_start
68+ rec = generate_recording (durations = [5 , 10 ])
69+ t_starts = [10 , 20 ]
70+ for t_start , rec_segment in zip (t_starts , rec ._recording_segments ):
71+ rec_segment .t_start = t_start
72+ decimated_rec = DecimateRecording (rec , decimation_factor , decimation_offset = decimation_offset )
73+ for segment_index in range (rec .get_num_segments ()):
74+ assert np .allclose (
75+ decimated_rec .get_times (segment_index ),
76+ rec .get_times (segment_index )[decimation_offset ::decimation_factor ],
77+ )
78+
79+
4580if __name__ == "__main__" :
4681 test_decimate ()
0 commit comments