# limitations under the License.
"""Extends `dill` to support pickling more types and produce more consistent dumps."""

from collections.abc import Callable, Iterable
from io import BytesIO
import sys
from types import FunctionType
from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar

import dill
import xxhash

if TYPE_CHECKING:
    import spacy
    import torch


class Hasher:
    """Hasher that accepts Python objects as inputs."""

    dispatch: ClassVar[dict[Any, Any]] = {}

    def __init__(self) -> None:
        self.m = xxhash.xxh64()

    @classmethod
    def hash_bytes(cls, value: bytes | list[bytes]) -> str:
        """Hash bytes or a list of bytes using xxhash."""
        value = [value] if isinstance(value, bytes) else value
        m = xxhash.xxh64()
        for x in value:
            m.update(x)
        return m.hexdigest()

    @classmethod
    def hash(cls, value: object) -> str:
        """Hash a Python object by pickling it first."""
        return cls.hash_bytes(dumps(value))

    def update(self, value: object) -> None:
        """Update the hasher with a Python object."""
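        # Each value is first reduced to its own hex digest, then mixed into
        # the running state together with a `=={type}==` header, so updates
        # with equal-looking values of different types (1 vs "1") differ.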
        header_for_update = f"=={type(value)}=="
        value_for_update = self.hash(value)
        self.m.update(header_for_update.encode("utf8"))
        self.m.update(value_for_update.encode("utf-8"))

    def hexdigest(self) -> str:
        """Return the hexadecimal digest of the hash."""
        return self.m.hexdigest()

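
# Illustrative usage sketch (not executed in this module): digests are stable
# across runs and processes because values are reduced via the deterministic
# `dumps` below rather than Python's per-process, salted built-in `hash()`:
#
#   h = Hasher()
#   h.update({"a": 1, "b": 2})
#   h.update([3, 4])
#   digest = h.hexdigest()  # same hex string for the same sequence of updates
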
class Pickler(dill.Pickler):
    """Custom Pickler that extends dill with additional type support."""

    dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy())  # noqa: SLF001
    _legacy_no_dict_keys_sorting = False

    def save(self, obj: Any, save_persistent_id: bool = True) -> None:
        """Save an object to the pickle stream, lazily registering custom reducers."""
        obj_type = type(obj)
        if obj_type not in self.dispatch:
            if "regex" in sys.modules:
                import regex  # noqa: PLC0415

                if obj_type is regex.Pattern:
                    pklregister(obj_type)(_save_regexPattern)
            if "spacy" in sys.modules:
                import spacy  # noqa: PLC0415

                if issubclass(obj_type, spacy.Language):
                    pklregister(obj_type)(_save_spacyLanguage)
            if "tiktoken" in sys.modules:
                import tiktoken  # noqa: PLC0415

                if obj_type is tiktoken.Encoding:
                    pklregister(obj_type)(_save_tiktokenEncoding)
            if "torch" in sys.modules:
                import torch  # noqa: PLC0415

                if issubclass(obj_type, torch.Tensor):
                    pklregister(obj_type)(_save_torchTensor)

                if obj_type is torch.Generator:
                    pklregister(obj_type)(_save_torchGenerator)

                # Unwrap `torch.compile`-ed modules
                if issubclass(obj_type, torch.nn.Module):
                    obj = getattr(obj, "_orig_mod", obj)
            if "transformers" in sys.modules:
                import transformers  # noqa: PLC0415

                if issubclass(obj_type, transformers.PreTrainedTokenizerBase):
                    pklregister(obj_type)(_save_transformersPreTrainedTokenizerBase)

        # Unwrap `torch.compile`-ed functions
        if obj_type is FunctionType:
            obj = getattr(obj, "_torchdynamo_orig_callable", obj)
        dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id)

    def _batch_setitems(self, items: Iterable[tuple[Any, Any]]) -> None:
        if self._legacy_no_dict_keys_sorting:
            super()._batch_setitems(items)
            return
        # Ignore the order of keys in a dict
        try:
            # Faster, but fails for unorderable elements
            sorted_items = sorted(items)
        except Exception:  # TypeError, decimal.InvalidOperation, etc.
            sorted_items = sorted(items, key=lambda x: Hasher.hash(x[0]))
        super()._batch_setitems(sorted_items)

    def memoize(self, obj: Any) -> None:
        """Memoize an object, skipping strings to avoid id issues."""
        # Don't memoize strings since two identical strings can have different Python ids
        if type(obj) is not str:  # noqa: E721
            dill.Pickler.memoize(self, obj)

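
# A sketch of the determinism this buys (illustrative, not executed here):
# because `_batch_setitems` sorts dict items before writing them, dicts that
# differ only in key insertion order pickle to the same bytes and hash alike:
#
#   assert dumps({"a": 1, "b": 2}) == dumps({"b": 2, "a": 1})
#   assert Hasher.hash({"a": 1, "b": 2}) == Hasher.hash({"b": 2, "a": 1})
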
def pklregister(t: Any) -> Callable[[Any], Any]:
    """Register a custom reducer for the type."""

    def proxy(func: Any) -> Any:
        Pickler.dispatch[t] = func
        return func

    return proxy

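
# Illustrative sketch with a hypothetical `Point` class (not part of this
# module): a reducer tells the pickler how to rebuild an object, mirroring
# the `_save_*` helpers below:
#
#   @pklregister(Point)
#   def _save_point(pickler, obj):
#       pickler.save_reduce(Point, (obj.x, obj.y), obj=obj)
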
def dump(obj: Any, file: BinaryIO) -> None:
    """Pickle an object to a file."""
    Pickler(file, recurse=True).dump(obj)


def dumps(obj: Any) -> bytes:
    """Pickle an object to a bytes string."""
    file = BytesIO()
    dump(obj, file)
    return file.getvalue()

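
# Illustrative note: because serialization goes through `dill`, objects that
# the stdlib `pickle` rejects, such as lambdas and locally defined functions,
# can still be dumped and hashed:
#
#   payload = dumps(lambda x: x + 1)  # stdlib `pickle.dumps` would raise here
#   digest = Hasher.hash_bytes(payload)
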
def log(pickler: Any, msg: Any) -> None:
    """Log a message from the pickler (no-op)."""


def _save_regexPattern(pickler: Any, obj: Any) -> None:
    import regex  # noqa: PLC0415

    log(pickler, f"Re: {obj}")
    args = (obj.pattern, obj.flags)
    pickler.save_reduce(regex.compile, args, obj=obj)
    log(pickler, "# Re")


def _save_tiktokenEncoding(pickler: Any, obj: Any) -> None:
    import tiktoken  # noqa: PLC0415

    log(pickler, f"Enc: {obj}")
    args = (obj.name, obj._pat_str, obj._mergeable_ranks, obj._special_tokens)  # noqa: SLF001
    pickler.save_reduce(tiktoken.Encoding, args, obj=obj)
    log(pickler, "# Enc")


def _save_torchTensor(pickler: Any, obj: Any) -> None:
    import torch  # noqa: PLC0415

    # `torch.from_numpy` is not picklable in `torch>=1.11.0`
    def create_torchTensor(np_array: Any, dtype: Any = None) -> "torch.Tensor":
        tensor = torch.from_numpy(np_array)
        if dtype:
            tensor = tensor.type(dtype)
        return tensor

    log(pickler, f"To: {obj}")
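    # NumPy has no bfloat16 dtype, so bfloat16 tensors are upcast to float32
    # for the numpy round-trip; `create_torchTensor` restores the dtype on load.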
    if obj.dtype == torch.bfloat16:
        args: tuple[Any, ...] = (obj.detach().to(torch.float).cpu().numpy(), torch.bfloat16)
    else:
        args = (obj.detach().cpu().numpy(),)
    pickler.save_reduce(create_torchTensor, args, obj=obj)
    log(pickler, "# To")


def _save_torchGenerator(pickler: Any, obj: Any) -> None:
    import torch  # noqa: PLC0415

    def create_torchGenerator(state: Any) -> "torch.Generator":
        generator = torch.Generator()
        generator.set_state(state)
        return generator

    log(pickler, f"Ge: {obj}")
    args = (obj.get_state(),)
    pickler.save_reduce(create_torchGenerator, args, obj=obj)
    log(pickler, "# Ge")


def _save_spacyLanguage(pickler: Any, obj: Any) -> None:
    import spacy  # noqa: PLC0415

    def create_spacyLanguage(config: Any, bytes: Any) -> "spacy.Language":
        lang_cls = spacy.util.get_lang_class(config["nlp"]["lang"])
        lang_inst = lang_cls.from_config(config)
        return lang_inst.from_bytes(bytes)

    log(pickler, f"Sp: {obj}")
    args = (obj.config, obj.to_bytes())
    pickler.save_reduce(create_spacyLanguage, args, obj=obj)
    log(pickler, "# Sp")


def _save_transformersPreTrainedTokenizerBase(pickler: Any, obj: Any) -> None:
    log(pickler, f"Tok: {obj}")
    # Ignore the `cache` attribute
    state = obj.__dict__
    if "cache" in state and isinstance(state["cache"], dict):
        state["cache"] = {}
    pickler.save_reduce(type(obj), (), state=state, obj=obj)
    log(pickler, "# Tok")