Skip to content

Commit 4bbf71d

Browse files
authored
add tests for simulate_values function (#2292)
* add tests for simulate_values function * remove test configs definition duplicates
1 parent 4a37e35 commit 4bbf71d

File tree

1 file changed

+39
-28
lines changed

1 file changed

+39
-28
lines changed

tests/ignite/handlers/test_state_param_scheduler.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,30 @@
1515
config1 = (3, [(2, 0), (5, 10)], True, [0.0, 0.0, 3.3333333333333335])
1616
expected_hist2 = [0.0] * 10 + [float(i) for i in range(1, 11)] + [10.0] * 10
1717
config2 = (30, [(10, 0), (20, 10)], True, expected_hist2)
18+
config3 = (
19+
PiecewiseLinearStateScheduler,
20+
{"param_name": "linear_scheduled_param", "milestones_values": [(3, 12), (5, 10)]},
21+
)
22+
config4 = (ExpStateScheduler, {"param_name": "exp_scheduled_param", "initial_value": 10, "gamma": 0.99})
23+
config5 = (
24+
MultiStepStateScheduler,
25+
{"param_name": "multistep_scheduled_param", "initial_value": 10, "gamma": 0.99, "milestones": [3, 6]},
26+
)
27+
28+
29+
class LambdaState:
30+
def __init__(self, initial_value, gamma):
31+
self.initial_value = initial_value
32+
self.gamma = gamma
33+
34+
def __call__(self, event_index):
35+
return self.initial_value * self.gamma ** (event_index % 9)
36+
37+
38+
config6 = (
39+
LambdaStateScheduler,
40+
{"param_name": "custom_scheduled_param", "lambda_obj": LambdaState(initial_value=10, gamma=0.99)},
41+
)
1842

1943

2044
@pytest.mark.parametrize(
@@ -216,33 +240,7 @@ def __init__(self, initial_value, gamma):
216240
)
217241

218242

219-
config1 = (
220-
PiecewiseLinearStateScheduler,
221-
{"param_name": "linear_scheduled_param", "milestones_values": [(3, 12), (5, 10)]},
222-
)
223-
config2 = (ExpStateScheduler, {"param_name": "exp_scheduled_param", "initial_value": 10, "gamma": 0.99})
224-
config3 = (
225-
MultiStepStateScheduler,
226-
{"param_name": "multistep_scheduled_param", "initial_value": 10, "gamma": 0.99, "milestones": [3, 6]},
227-
)
228-
229-
230-
class LambdaState:
231-
def __init__(self, initial_value, gamma):
232-
self.initial_value = initial_value
233-
self.gamma = gamma
234-
235-
def __call__(self, event_index):
236-
return self.initial_value * self.gamma ** (event_index % 9)
237-
238-
239-
config4 = (
240-
LambdaStateScheduler,
241-
{"param_name": "custom_scheduled_param", "lambda_obj": LambdaState(initial_value=10, gamma=0.99)},
242-
)
243-
244-
245-
@pytest.mark.parametrize("scheduler_cls,scheduler_kwargs", [config1, config2, config3, config4])
243+
@pytest.mark.parametrize("scheduler_cls,scheduler_kwargs", [config3, config4, config5, config6])
246244
def test_simulate_and_plot_values(scheduler_cls, scheduler_kwargs):
247245

248246
import matplotlib
@@ -265,7 +263,20 @@ def _test(scheduler_cls, scheduler_kwargs):
265263
_test(scheduler_cls, scheduler_kwargs)
266264

267265

268-
@pytest.mark.parametrize("scheduler_cls,scheduler_kwargs", [config1, config2, config3, config4])
266+
@pytest.mark.parametrize("scheduler_cls,scheduler_kwargs", [config3, config4, config5, config6])
267+
def test_simulate_values(scheduler_cls, scheduler_kwargs):
268+
def _test(scheduler_cls, scheduler_kwargs):
269+
max_epochs = 2
270+
data = [0] * 10
271+
scheduler_cls.simulate_values(num_events=len(data) * max_epochs, **scheduler_kwargs)
272+
273+
assert "save_history" not in scheduler_kwargs
274+
_test(scheduler_cls, scheduler_kwargs)
275+
scheduler_kwargs["save_history"] = True
276+
_test(scheduler_cls, scheduler_kwargs)
277+
278+
279+
@pytest.mark.parametrize("scheduler_cls,scheduler_kwargs", [config3, config4, config5, config6])
269280
def test_state_param_asserts(scheduler_cls, scheduler_kwargs):
270281
import re
271282

0 commit comments

Comments
 (0)