11from collections import OrderedDict
22from collections .abc import Mapping
3- from typing import Tuple
3+ from typing import List , Tuple
44
55
66class Serializable :
7- _state_dict_all_req_keys : Tuple = ()
8- _state_dict_one_of_opt_keys : Tuple = ()
7+ _state_dict_all_req_keys : Tuple [str , ...] = ()
8+ _state_dict_one_of_opt_keys : Tuple [Tuple [str , ...], ...] = ((),)
9+
10+ def __init__ (self ) -> None :
11+ self ._state_dict_user_keys : List [str ] = []
12+
13+ @property
14+ def state_dict_user_keys (self ) -> List :
15+ return self ._state_dict_user_keys
916
1017 def state_dict (self ) -> OrderedDict :
1118 raise NotImplementedError
@@ -19,6 +26,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
1926 raise ValueError (
2027 f"Required state attribute '{ k } ' is absent in provided state_dict '{ state_dict .keys ()} '"
2128 )
22- opts = [k in state_dict for k in self ._state_dict_one_of_opt_keys ]
23- if len (opts ) > 0 and ((not any (opts )) or (all (opts ))):
24- raise ValueError (f"state_dict should contain only one of '{ self ._state_dict_one_of_opt_keys } ' keys" )
29+ for one_of_opt_keys in self ._state_dict_one_of_opt_keys :
30+ opts = [k in state_dict for k in one_of_opt_keys ]
31+ if len (opts ) > 0 and (not any (opts )) or (all (opts )):
32+ raise ValueError (f"state_dict should contain only one of '{ one_of_opt_keys } ' keys" )
33+
34+ for k in self ._state_dict_user_keys :
35+ if k not in state_dict :
36+ raise ValueError (
37+ f"Required user state attribute '{ k } ' is absent in provided state_dict '{ state_dict .keys ()} '"
38+ )
0 commit comments