Skip to content

Commit a99ea7f

Browse files
Improve RunningAverage reset when epoch_bound=False (#2950)
* Do the improvement * A few bug fix in test * two improvements in test
1 parent 1df9932 commit a99ea7f

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

ignite/metrics/running_average.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ def attach(self, engine: Engine, name: str, _usage: Union[str, MetricUsage] = Ep
143143
if self.epoch_bound:
144144
# restart average every epoch
145145
engine.add_event_handler(Events.EPOCH_STARTED, self.started)
146+
else:
147+
engine.add_event_handler(Events.STARTED, self.started)
146148
# compute metric
147149
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
148150
# apply running average

tests/ignite/metrics/test_running_average.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,9 @@ def test_epoch_unbound():
125125
batch_size = 10
126126
n_classes = 10
127127
data = list(range(n_iters))
128-
loss_values = iter(range(n_epochs * n_iters))
129-
y_true_batch_values = iter(np.random.randint(0, n_classes, size=(n_epochs * n_iters, batch_size)))
130-
y_pred_batch_values = iter(np.random.rand(n_epochs * n_iters, batch_size, n_classes))
128+
loss_values = iter(range(2 * n_epochs * n_iters))
129+
y_true_batch_values = iter(np.random.randint(0, n_classes, size=(2 * n_epochs * n_iters, batch_size)))
130+
y_pred_batch_values = iter(np.random.rand(2 * n_epochs * n_iters, batch_size, n_classes))
131131

132132
def update_fn(engine, batch):
133133
loss_value = next(loss_values)
@@ -146,9 +146,7 @@ def update_fn(engine, batch):
146146

147147
running_avg_acc = [None]
148148

149-
@trainer.on(Events.STARTED)
150-
def running_avg_output_init(engine):
151-
engine.state.running_avg_output = None
149+
trainer.state.running_avg_output = None
152150

153151
@trainer.on(Events.ITERATION_COMPLETED, running_avg_acc)
154152
def manual_running_avg_acc(engine, running_avg_acc):
@@ -187,6 +185,10 @@ def assert_equal_running_avg_output_values(engine):
187185

188186
trainer.run(data, max_epochs=3)
189187

188+
running_avg_acc[0] = None
189+
trainer.state.running_avg_output = None
190+
trainer.run(data, max_epochs=3)
191+
190192

191193
def test_multiple_attach():
192194
n_iters = 100

0 commit comments

Comments
 (0)