Commit 3f1ef20

update mixins and engine
1 parent e414978 commit 3f1ef20

File tree

ignite/base/mixins.py
ignite/engine/engine.py
tests/ignite/base/test_mixins.py
tests/ignite/engine/test_engine.py

4 files changed: +152 -71 lines changed

ignite/base/mixins.py

Lines changed: 20 additions & 6 deletions
@@ -1,11 +1,18 @@
 from collections import OrderedDict
 from collections.abc import Mapping
-from typing import Tuple
+from typing import List, Tuple


 class Serializable:
-    _state_dict_all_req_keys: Tuple = ()
-    _state_dict_one_of_opt_keys: Tuple = ()
+    _state_dict_all_req_keys: Tuple[str, ...] = ()
+    _state_dict_one_of_opt_keys: Tuple[Tuple[str, ...], ...] = ((),)
+
+    def __init__(self) -> None:
+        self._state_dict_user_keys: List[str] = []
+
+    @property
+    def state_dict_user_keys(self) -> List:
+        return self._state_dict_user_keys

     def state_dict(self) -> OrderedDict:
         raise NotImplementedError
@@ -19,6 +26,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
                 raise ValueError(
                     f"Required state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
                 )
-        opts = [k in state_dict for k in self._state_dict_one_of_opt_keys]
-        if len(opts) > 0 and ((not any(opts)) or (all(opts))):
-            raise ValueError(f"state_dict should contain only one of '{self._state_dict_one_of_opt_keys}' keys")
+        for one_of_opt_keys in self._state_dict_one_of_opt_keys:
+            opts = [k in state_dict for k in one_of_opt_keys]
+            if len(opts) > 0 and (not any(opts)) or (all(opts)):
+                raise ValueError(f"state_dict should contain only one of '{one_of_opt_keys}' keys")
+
+        for k in self._state_dict_user_keys:
+            if k not in state_dict:
+                raise ValueError(
+                    f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
+                )
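
For context, the reworked mixin validates a loaded state dict in three passes: every key in _state_dict_all_req_keys must be present, exactly one key per group in _state_dict_one_of_opt_keys must be present, and every user-registered key must be present. A minimal sketch of a subclass (not part of this commit; the Counter class and its keys are invented for illustration):

from collections import OrderedDict

from ignite.base.mixins import Serializable


class Counter(Serializable):
    # All of these keys must appear in a loaded state dict.
    _state_dict_all_req_keys = ("total",)
    # Each inner tuple is a group: exactly one key per group must appear.
    _state_dict_one_of_opt_keys = (("step", "round"),)

    def __init__(self) -> None:
        super().__init__()  # initializes _state_dict_user_keys
        self.total = 0
        self.step = 0

    def state_dict(self) -> OrderedDict:
        return OrderedDict([("total", self.total), ("step", self.step)])

    def load_state_dict(self, state_dict) -> None:
        super().load_state_dict(state_dict)  # runs the key checks shown above
        self.total = state_dict["total"]
        self.step = state_dict.get("step", state_dict.get("round", 0))


c = Counter()
c.load_state_dict({"total": 10, "step": 3})  # ok: one key of the ("step", "round") group
# c.load_state_dict({"total": 10, "step": 3, "round": 1})  # ValueError: both group keys present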

ignite/engine/engine.py

Lines changed: 122 additions & 4 deletions
@@ -129,7 +129,16 @@ def compute_mean_std(engine, batch):
     """

     _state_dict_all_req_keys = ("epoch_length", "max_epochs")
-    _state_dict_one_of_opt_keys = ("iteration", "epoch")
+    _state_dict_one_of_opt_keys = (
+        (
+            "iteration",
+            "epoch",
+        ),
+        (
+            "max_epochs",
+            "max_iters",
+        ),
+    )

     # Flag to disable engine._internal_run as generator feature for BC
     interrupt_resume_enabled = True
@@ -310,6 +319,7 @@ def execute_something():
             for e in event_name:
                 self.add_event_handler(e, handler, *args, **kwargs)
             return RemovableEventHandle(event_name, handler, self)
+
         if isinstance(event_name, CallableEventWithFilter) and event_name.filter is not None:
             event_filter = event_name.filter
             handler = self._handler_wrapper(handler, event_name, event_filter)
@@ -332,6 +342,16 @@ def execute_something():

         return RemovableEventHandle(event_name, handler, self)

+    @staticmethod
+    def _assert_non_filtered_event(event_name: Any) -> None:
+        if (
+            isinstance(event_name, CallableEventWithFilter)
+            and event_name.filter != CallableEventWithFilter.default_event_filter
+        ):
+            raise TypeError(
+                "Argument event_name should not be a filtered event, please use event without any event filtering"
+            )
+
     def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None) -> bool:
         """Check if the specified event has the specified handler.
@@ -675,7 +695,12 @@ def save_engine(_):
             a dictionary containing engine's state

         """
-        keys: Tuple[str, ...] = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],)
+        keys: Tuple[str, ...] = self._state_dict_all_req_keys
+        keys += ("iteration",)
+        if self.state.max_epochs is not None:
+            keys += ("max_epochs",)
+        else:
+            keys += ("max_iters",)
         keys += tuple(self._state_dict_user_keys)
         return OrderedDict([(k, getattr(self.state, k)) for k in keys])
@@ -728,6 +753,8 @@ def load_state_dict(self, state_dict: Mapping) -> None:
                     f"Input state_dict: {state_dict}"
                 )
             self.state.iteration = self.state.epoch_length * self.state.epoch
+        self._check_and_set_max_epochs(state_dict.get("max_epochs", None))
+        self._check_and_set_max_iters(state_dict.get("max_iters", None))

     @staticmethod
     def _is_done(state: State) -> bool:
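
Taken together with the state_dict hunk above, checkpoints now round-trip max_epochs or max_iters: state_dict() emits whichever of the two drives the run (plus iteration), and load_state_dict() restores whatever it finds via the new _check_and_set_* helpers. A rough sketch of the intended round trip (dummy process function; shown values assume an epoch_length of 10):

from ignite.engine import Engine

trainer = Engine(lambda engine, batch: None)
trainer.run(range(10), max_epochs=2)

sd = trainer.state_dict()
# e.g. OrderedDict([('epoch_length', 10), ('max_epochs', 2), ('iteration', 20)])

resumed = Engine(lambda engine, batch: None)
resumed.load_state_dict(sd)  # restores counters, then max_epochs/max_iters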
@@ -864,12 +891,26 @@ def switch_batch(engine):

         epoch_length = self._get_data_length(data)
         if epoch_length is not None and epoch_length < 1:
-            raise ValueError("Input data has zero size. Please provide non-empty data")
+            raise ValueError(
+                "Argument epoch_length is invalid. Please, either set a"
+                " correct epoch_length value or check if input data has"
+                " non-zero size."
+            )

         if max_iters is None:
             if max_epochs is None:
                 max_epochs = 1
         else:
+            if max_iters < 1:
+                raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value")
+            if (self.state.max_iters is not None) and max_iters <= self.state.iteration:
+                raise ValueError(
+                    "Argument max_iters should be larger than the current iteration "
+                    f"defined in the state: {max_iters} vs {self.state.iteration}. "
+                    "Please, set engine.state.max_iters = None "
+                    "before calling engine.run() in order to restart the training from the beginning."
+                )
+            self.state.max_iters = max_iters
             if max_epochs is not None:
                 raise ValueError(
                     "Arguments max_iters and max_epochs are mutually exclusive."
@@ -932,6 +973,53 @@ def _setup_dataloader_iter(self) -> None:
         else:
             self._dataloader_iter = iter(self.state.dataloader)

+    def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None:
+        if max_epochs is not None:
+            if max_epochs < 1:
+                raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value")
+            if self.state.max_epochs is not None and max_epochs <= self.state.epoch:
+                raise ValueError(
+                    "Argument max_epochs should be larger than the current epoch "
+                    f"defined in the state: {max_epochs} vs {self.state.epoch}. "
+                    "Please, set engine.state.max_epochs = None "
+                    "before calling engine.run() in order to restart the training from the beginning."
+                )
+            self.state.max_epochs = max_epochs
+
+    def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None:
+        if max_iters is not None:
+            if max_iters < 1:
+                raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value")
+            if (self.state.max_iters is not None) and max_iters <= self.state.iteration:
+                raise ValueError(
+                    "Argument max_iters should be larger than the current iteration "
+                    f"defined in the state: {max_iters} vs {self.state.iteration}. "
+                    "Please, set engine.state.max_iters = None "
+                    "before calling engine.run() in order to restart the training from the beginning."
+                )
+            self.state.max_iters = max_iters
+
+    def _check_and_set_epoch_length(self, data: Iterable, epoch_length: Optional[int] = None) -> None:
+        # Can't we accept a redefinition ?
+        if self.state.epoch_length is not None:
+            if epoch_length is not None:
+                if epoch_length != self.state.epoch_length:
+                    raise ValueError(
+                        "Argument epoch_length should be same as in the state, "
+                        f"but given {epoch_length} vs {self.state.epoch_length}"
+                    )
+        else:
+            if epoch_length is None:
+                epoch_length = self._get_data_length(data)
+
+            if epoch_length is not None and epoch_length < 1:
+                raise ValueError(
+                    "Argument epoch_length is invalid. Please, either set a correct epoch_length value or "
+                    "check if input data has non-zero size."
+                )
+
+            self.state.epoch_length = epoch_length
+
     def _setup_engine(self) -> None:
         self._setup_dataloader_iter()
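
The _check_and_set_* helpers enforce the restart rule spelled out in their error messages: a new max_epochs (or max_iters) must exceed the progress already in the state, unless the corresponding state attribute is reset first. Sketched (behavior as described by the error messages above; dummy process function):

from ignite.engine import Engine

trainer = Engine(lambda engine, batch: None)
trainer.run(range(10), max_epochs=5)  # finishes at epoch 5

# trainer.run(range(10), max_epochs=3)  # ValueError: 3 vs 5

trainer.state.max_epochs = None  # what the error message suggests
trainer.run(range(10), max_epochs=3)  # restarts training from the beginning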

@@ -1064,6 +1152,13 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
                         self.state.epoch_length = iter_counter
                         if self.state.max_iters is not None:
                             self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
+                            # Warn but will continue until max iters is reached
+                            warnings.warn(
+                                "Data iterator can not provide data anymore but required total number of "
+                                "iterations to run is not reached. "
+                                f"Current iteration: {self.state.iteration} vs Total iterations to run :"
+                                f" {self.state.max_iters}"
+                            )
                         break

                 # Should exit while loop if we can not iterate
@@ -1106,7 +1201,13 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:

                 if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
                     self.should_terminate = True
-                    raise _EngineTerminateException()
+                    warnings.warn(
+                        "Data iterator can not provide data anymore but required total number of "
+                        "iterations to run is not reached. "
+                        f"Current iteration: {self.state.iteration} vs Total iterations to run : ? total_iters"
+                    )
+                    break
+                    # raise _EngineTerminateException()

             except _EngineTerminateSingleEpochException:
                 self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
@@ -1231,6 +1332,13 @@ def _run_once_on_dataset_legacy(self) -> float:
                         self.state.epoch_length = iter_counter
                         if self.state.max_iters is not None:
                             self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
+                            # Warn but will continue until max iters is reached
+                            warnings.warn(
+                                "Data iterator can not provide data anymore but required total number of "
+                                "iterations to run is not reached. "
+                                f"Current iteration: {self.state.iteration} vs Total iterations to run :"
+                                f" {self.state.max_iters}"
+                            )
                         break

                 # Should exit while loop if we can not iterate
@@ -1291,6 +1399,16 @@

         return time.time() - start_time

+    def debug(self, enabled: bool = True) -> None:
+        """Enables/disables engine's logging debug mode"""
+        from ignite.utils import setup_logger
+
+        if enabled:
+            setattr(self, "_stored_logger", self.logger)
+            self.logger = setup_logger(level=logging.DEBUG)
+        elif hasattr(self, "_stored_logger"):
+            self.logger = getattr(self, "_stored_logger")
+

 def _get_none_data_iter(size: int) -> Iterator:
     # Sized iterator for data as None
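
The new debug() toggle swaps the engine's logger for a DEBUG-level one built by ignite.utils.setup_logger and stashes the original so it can be restored. Typical usage (sketch; dummy process function):

from ignite.engine import Engine

trainer = Engine(lambda engine, batch: None)

trainer.debug()  # engine now logs at DEBUG level
trainer.run(range(10), max_epochs=1)
trainer.debug(False)  # restores the previously stored logger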

tests/ignite/base/test_mixins.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ def test_state_dict():
     with pytest.raises(NotImplementedError):
         s.state_dict()

+
 def test_load_state_dict():

     s = ExampleSerializable()

tests/ignite/engine/test_engine.py

Lines changed: 9 additions & 61 deletions
@@ -11,7 +11,7 @@
 from ignite.engine import Engine, Events, State
 from ignite.engine.deterministic import keep_random_state
 from ignite.metrics import Average
-from tests.ignite.engine import BatchChecker, EpochCounter, get_iterable_dataset, IterationCounter
+from tests.ignite.engine import BatchChecker, EpochCounter, IterationCounter


 class RecordedEngine(Engine):
@@ -520,7 +520,7 @@ def test__setup_engine(self):
         data = list(range(100))
         engine.state.dataloader = data
         engine._setup_engine()
-        assert len(engine._init_iter) == 1 and engine._init_iter[0] == 10
+        assert engine._init_iter == 10

     def test_run_asserts(self):
         engine = Engine(lambda e, b: 1)
@@ -531,7 +531,7 @@ def test_run_asserts(self):
             r"value or check if input data has non-zero size.",
         ):
             engine.run([])
-        with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than"):
+        with pytest.raises(ValueError, match=r"Argument max_epochs should be greater than or equal to the start epoch"):
             engine.state.max_epochs = 5
             engine.state.epoch = 5
             engine.run([0, 1], max_epochs=3)
@@ -707,8 +707,8 @@ def infinite_data_iterator():
         kwargs = dict(exp_iter_stops=0, n_epoch_started=2, n_epoch_completed=1)
         self._test_check_triggered_events(infinite_data_iterator(), max_iters=30, epoch_length=20, **kwargs)

-        def limited_data_iterator():
-            for i in range(100):
+        def limited_data_iterator(length=100):
+            for i in range(length):
                 yield i

         self._test_check_triggered_events(limited_data_iterator(), max_epochs=1, epoch_length=100, exp_iter_stops=0)
@@ -730,7 +730,7 @@ def limited_data_iterator():
             n_batch_completed=20,
             n_terminate=1,
         )
-        self._test_check_triggered_events(limited_data_iterator(), max_epochs=3, epoch_length=20, **kwargs)
+        self._test_check_triggered_events(limited_data_iterator(length=20), max_epochs=3, epoch_length=20, **kwargs)

         with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"):
             kwargs = dict(
@@ -1019,7 +1019,7 @@ def restart_iter():
         _test(max_iters=known_size)
         _test(max_iters=known_size // 2)

-    def test_faq_inf_iterator_with_epoch_length():
+    def test_faq_inf_iterator_with_epoch_length(self):
         def _test(max_epochs, max_iters):
             # Code snippet from FAQ
             # import torch
@@ -1132,59 +1132,6 @@ def val_step(evaluator, batch):

         _test(max_epochs=None, max_iters=None)

-    def test_faq_fin_iterator(self):
-        # Code snippet from FAQ
-        # import torch
-
-        torch.manual_seed(12)
-
-        size = 11
-
-        def finite_size_data_iter(size):
-            for i in range(size):
-                yield i
-
-        def train_step(trainer, batch):
-            # ...
-            s = trainer.state
-            print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}")
-
-        trainer = Engine(train_step)
-
-        @trainer.on(Events.ITERATION_COMPLETED(every=size))
-        def restart_iter():
-            trainer.state.dataloader = finite_size_data_iter(size)
-
-        data_iter = finite_size_data_iter(size)
-        trainer.run(data_iter, max_epochs=5)
-
-        assert trainer.state.epoch == 5
-        assert trainer.state.iteration == 5 * size
-
-        # Code snippet from FAQ
-        # import torch
-
-        torch.manual_seed(12)
-
-        size = 11
-
-        def finite_size_data_iter(size):
-            for i in range(size):
-                yield i
-
-        def val_step(evaluator, batch):
-            # ...
-            s = evaluator.state
-            print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}")
-
-        evaluator = Engine(val_step)
-
-        data_iter = finite_size_data_iter(size)
-        evaluator.run(data_iter)
-
-        assert evaluator.state.epoch == 1
-        assert evaluator.state.iteration == size
-
     def test_faq_fin_iterator(self):
         def _test(max_epochs, max_iters):
             # Code snippet from FAQ
@@ -1475,7 +1422,8 @@ def test_restart_training(self):
         state = engine.run(data, max_epochs=5)
         with pytest.raises(
             ValueError,
-            match=r"Argument max_epochs should be larger than the current epoch defined in the state: 2 vs 5. "
+            match=r"Argument max_epochs should be greater than or equal to the start epoch"
+            " defined in the state: 2 vs 5. "
             r"Please, .+ "
             r"before calling engine.run\(\) in order to restart the training from the beginning.",
         ):
