@@ -125,9 +125,9 @@ def test_epoch_unbound():
125125 batch_size = 10
126126 n_classes = 10
127127 data = list (range (n_iters ))
128- loss_values = iter (range (n_epochs * n_iters ))
129- y_true_batch_values = iter (np .random .randint (0 , n_classes , size = (n_epochs * n_iters , batch_size )))
130- y_pred_batch_values = iter (np .random .rand (n_epochs * n_iters , batch_size , n_classes ))
128+ loss_values = iter (range (2 * n_epochs * n_iters ))
129+ y_true_batch_values = iter (np .random .randint (0 , n_classes , size = (2 * n_epochs * n_iters , batch_size )))
130+ y_pred_batch_values = iter (np .random .rand (2 * n_epochs * n_iters , batch_size , n_classes ))
131131
132132 def update_fn (engine , batch ):
133133 loss_value = next (loss_values )
@@ -146,9 +146,7 @@ def update_fn(engine, batch):
146146
147147 running_avg_acc = [None ]
148148
149- @trainer .on (Events .STARTED )
150- def running_avg_output_init (engine ):
151- engine .state .running_avg_output = None
149+ trainer .state .running_avg_output = None
152150
153151 @trainer .on (Events .ITERATION_COMPLETED , running_avg_acc )
154152 def manual_running_avg_acc (engine , running_avg_acc ):
@@ -187,6 +185,10 @@ def assert_equal_running_avg_output_values(engine):
187185
188186 trainer .run (data , max_epochs = 3 )
189187
188+ running_avg_acc [0 ] = None
189+ trainer .state .running_avg_output = None
190+ trainer .run (data , max_epochs = 3 )
191+
190192
191193def test_multiple_attach ():
192194 n_iters = 100
0 commit comments