Skip to content

Commit af4f548

Browse files
committed
update mixins
1 parent f6b85ed commit af4f548

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

ignite/base/mixins.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
from collections import OrderedDict
22
from collections.abc import Mapping
3-
from typing import Tuple
3+
from typing import List, Tuple
44

55

66
class Serializable:
7-
_state_dict_all_req_keys: Tuple = ()
8-
_state_dict_one_of_opt_keys: Tuple = ()
7+
_state_dict_all_req_keys: Tuple[str, ...] = ()
8+
_state_dict_one_of_opt_keys: Tuple[Tuple[str, ...], ...] = ((),)
9+
10+
def __init__(self) -> None:
11+
self._state_dict_user_keys: List[str] = []
12+
13+
@property
14+
def state_dict_user_keys(self) -> List:
15+
return self._state_dict_user_keys
916

1017
def state_dict(self) -> OrderedDict:
1118
raise NotImplementedError
@@ -19,6 +26,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
1926
raise ValueError(
2027
f"Required state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
2128
)
22-
opts = [k in state_dict for k in self._state_dict_one_of_opt_keys]
23-
if len(opts) > 0 and ((not any(opts)) or (all(opts))):
24-
raise ValueError(f"state_dict should contain only one of '{self._state_dict_one_of_opt_keys}' keys")
29+
for one_of_opt_keys in self._state_dict_one_of_opt_keys:
30+
opts = [k in state_dict for k in one_of_opt_keys]
31+
if len(opts) > 0 and (not any(opts)) or (all(opts)):
32+
raise ValueError(f"state_dict should contain only one of '{one_of_opt_keys}' keys")
33+
34+
for k in self._state_dict_user_keys:
35+
if k not in state_dict:
36+
raise ValueError(
37+
f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
38+
)

0 commit comments

Comments
 (0)