1515config1 = (3 , [(2 , 0 ), (5 , 10 )], True , [0.0 , 0.0 , 3.3333333333333335 ])
1616expected_hist2 = [0.0 ] * 10 + [float (i ) for i in range (1 , 11 )] + [10.0 ] * 10
1717config2 = (30 , [(10 , 0 ), (20 , 10 )], True , expected_hist2 )
18+ config3 = (
19+ PiecewiseLinearStateScheduler ,
20+ {"param_name" : "linear_scheduled_param" , "milestones_values" : [(3 , 12 ), (5 , 10 )]},
21+ )
22+ config4 = (ExpStateScheduler , {"param_name" : "exp_scheduled_param" , "initial_value" : 10 , "gamma" : 0.99 })
23+ config5 = (
24+ MultiStepStateScheduler ,
25+ {"param_name" : "multistep_scheduled_param" , "initial_value" : 10 , "gamma" : 0.99 , "milestones" : [3 , 6 ]},
26+ )
27+
28+
29+ class LambdaState :
30+ def __init__ (self , initial_value , gamma ):
31+ self .initial_value = initial_value
32+ self .gamma = gamma
33+
34+ def __call__ (self , event_index ):
35+ return self .initial_value * self .gamma ** (event_index % 9 )
36+
37+
38+ config6 = (
39+ LambdaStateScheduler ,
40+ {"param_name" : "custom_scheduled_param" , "lambda_obj" : LambdaState (initial_value = 10 , gamma = 0.99 )},
41+ )
1842
1943
2044@pytest .mark .parametrize (
@@ -216,33 +240,7 @@ def __init__(self, initial_value, gamma):
216240 )
217241
218242
219- config1 = (
220- PiecewiseLinearStateScheduler ,
221- {"param_name" : "linear_scheduled_param" , "milestones_values" : [(3 , 12 ), (5 , 10 )]},
222- )
223- config2 = (ExpStateScheduler , {"param_name" : "exp_scheduled_param" , "initial_value" : 10 , "gamma" : 0.99 })
224- config3 = (
225- MultiStepStateScheduler ,
226- {"param_name" : "multistep_scheduled_param" , "initial_value" : 10 , "gamma" : 0.99 , "milestones" : [3 , 6 ]},
227- )
228-
229-
230- class LambdaState :
231- def __init__ (self , initial_value , gamma ):
232- self .initial_value = initial_value
233- self .gamma = gamma
234-
235- def __call__ (self , event_index ):
236- return self .initial_value * self .gamma ** (event_index % 9 )
237-
238-
239- config4 = (
240- LambdaStateScheduler ,
241- {"param_name" : "custom_scheduled_param" , "lambda_obj" : LambdaState (initial_value = 10 , gamma = 0.99 )},
242- )
243-
244-
245- @pytest .mark .parametrize ("scheduler_cls,scheduler_kwargs" , [config1 , config2 , config3 , config4 ])
243+ @pytest .mark .parametrize ("scheduler_cls,scheduler_kwargs" , [config3 , config4 , config5 , config6 ])
246244def test_simulate_and_plot_values (scheduler_cls , scheduler_kwargs ):
247245
248246 import matplotlib
@@ -265,7 +263,20 @@ def _test(scheduler_cls, scheduler_kwargs):
265263 _test (scheduler_cls , scheduler_kwargs )
266264
267265
268- @pytest .mark .parametrize ("scheduler_cls,scheduler_kwargs" , [config1 , config2 , config3 , config4 ])
266+ @pytest .mark .parametrize ("scheduler_cls,scheduler_kwargs" , [config3 , config4 , config5 , config6 ])
267+ def test_simulate_values (scheduler_cls , scheduler_kwargs ):
268+ def _test (scheduler_cls , scheduler_kwargs ):
269+ max_epochs = 2
270+ data = [0 ] * 10
271+ scheduler_cls .simulate_values (num_events = len (data ) * max_epochs , ** scheduler_kwargs )
272+
273+ assert "save_history" not in scheduler_kwargs
274+ _test (scheduler_cls , scheduler_kwargs )
275+ scheduler_kwargs ["save_history" ] = True
276+ _test (scheduler_cls , scheduler_kwargs )
277+
278+
279+ @pytest .mark .parametrize ("scheduler_cls,scheduler_kwargs" , [config3 , config4 , config5 , config6 ])
269280def test_state_param_asserts (scheduler_cls , scheduler_kwargs ):
270281 import re
271282
0 commit comments