Skip to content

Commit 9a99c9c

Browse files
authored
Merge pull request #101 from Snuffy2/Linting/Formatting-of-utils-folder
Linting/Formatting of utils folder
2 parents 3667658 + 8189b60 commit 9a99c9c

File tree

4 files changed

+95
-53
lines changed

app/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1+
"""Utility modules for the mlx-openai-server."""

app/utils/dill.py

Lines changed: 57 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -15,69 +15,81 @@
1515
# limitations under the License.
1616
"""Extends `dill` to support pickling more types and produce more consistent dumps."""
1717

18-
import sys
18+
from collections.abc import Callable, Iterable
1919
from io import BytesIO
20+
import sys
2021
from types import FunctionType
21-
from typing import Any, Dict, List, Union
22+
from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar
2223

2324
import dill
2425
import xxhash
2526

27+
if TYPE_CHECKING:
28+
import spacy
29+
import torch
30+
2631

2732
class Hasher:
2833
"""Hasher that accepts python objects as inputs."""
2934

30-
dispatch: Dict = {}
35+
dispatch: ClassVar[dict[Any, Any]] = {}
3136

32-
def __init__(self):
37+
def __init__(self) -> None:
3338
self.m = xxhash.xxh64()
3439

3540
@classmethod
36-
def hash_bytes(cls, value: Union[bytes, List[bytes]]) -> str:
41+
def hash_bytes(cls, value: bytes | list[bytes]) -> str:
42+
"""Hash bytes or list of bytes using xxhash."""
3743
value = [value] if isinstance(value, bytes) else value
3844
m = xxhash.xxh64()
3945
for x in value:
4046
m.update(x)
4147
return m.hexdigest()
4248

4349
@classmethod
44-
def hash(cls, value: Any) -> str:
50+
def hash(cls, value: object) -> str:
51+
"""Hash a Python object by pickling it first."""
4552
return cls.hash_bytes(dumps(value))
4653

47-
def update(self, value: Any) -> None:
54+
def update(self, value: object) -> None:
55+
"""Update the hasher with a Python object."""
4856
header_for_update = f"=={type(value)}=="
4957
value_for_update = self.hash(value)
5058
self.m.update(header_for_update.encode("utf8"))
5159
self.m.update(value_for_update.encode("utf-8"))
5260

5361
def hexdigest(self) -> str:
62+
"""Return the hexadecimal digest of the hash."""
5463
return self.m.hexdigest()
5564

5665

5766
class Pickler(dill.Pickler):
58-
dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy())
67+
"""Custom Pickler that extends dill with additional type support."""
68+
69+
dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy()) # noqa: SLF001
5970
_legacy_no_dict_keys_sorting = False
6071

61-
def save(self, obj, save_persistent_id=True):
72+
def save(self, obj: Any, save_persistent_id: bool = True) -> None:
73+
"""Save an object to the pickle stream with custom handling."""
6274
obj_type = type(obj)
6375
if obj_type not in self.dispatch:
6476
if "regex" in sys.modules:
65-
import regex # type: ignore
77+
import regex # noqa: PLC0415
6678

6779
if obj_type is regex.Pattern:
6880
pklregister(obj_type)(_save_regexPattern)
6981
if "spacy" in sys.modules:
70-
import spacy # type: ignore
82+
import spacy # noqa: PLC0415
7183

7284
if issubclass(obj_type, spacy.Language):
7385
pklregister(obj_type)(_save_spacyLanguage)
7486
if "tiktoken" in sys.modules:
75-
import tiktoken # type: ignore
87+
import tiktoken # noqa: PLC0415
7688

7789
if obj_type is tiktoken.Encoding:
7890
pklregister(obj_type)(_save_tiktokenEncoding)
7991
if "torch" in sys.modules:
80-
import torch # type: ignore
92+
import torch # noqa: PLC0415
8193

8294
if issubclass(obj_type, torch.Tensor):
8395
pklregister(obj_type)(_save_torchTensor)
@@ -89,7 +101,7 @@ def save(self, obj, save_persistent_id=True):
89101
if issubclass(obj_type, torch.nn.Module):
90102
obj = getattr(obj, "_orig_mod", obj)
91103
if "transformers" in sys.modules:
92-
import transformers # type: ignore
104+
import transformers # noqa: PLC0415
93105

94106
if issubclass(obj_type, transformers.PreTrainedTokenizerBase):
95107
pklregister(obj_type)(_save_transformersPreTrainedTokenizerBase)
@@ -99,90 +111,92 @@ def save(self, obj, save_persistent_id=True):
99111
obj = getattr(obj, "_torchdynamo_orig_callable", obj)
100112
dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id)
101113

102-
def _batch_setitems(self, items):
114+
def _batch_setitems(self, items: Iterable[tuple[Any, Any]]) -> None:
103115
if self._legacy_no_dict_keys_sorting:
104-
return super()._batch_setitems(items)
116+
super()._batch_setitems(items)
117+
return
105118
# Ignore the order of keys in a dict
106119
try:
107120
# Faster, but fails for unorderable elements
108-
items = sorted(items)
121+
sorted_items = sorted(items)
109122
except Exception: # TypeError, decimal.InvalidOperation, etc.
110-
items = sorted(items, key=lambda x: Hasher.hash(x[0]))
111-
dill.Pickler._batch_setitems(self, items)
123+
sorted_items = sorted(items, key=lambda x: Hasher.hash(x[0]))
124+
super()._batch_setitems(sorted_items)
112125

113-
def memoize(self, obj):
126+
def memoize(self, obj: Any) -> None:
127+
"""Memoize an object, skipping strings to avoid id issues."""
114128
# Don't memoize strings since two identical strings can have different Python ids
115129
if type(obj) is not str: # noqa: E721
116130
dill.Pickler.memoize(self, obj)
117131

118132

119-
def pklregister(t):
133+
def pklregister(t: Any) -> Callable[[Any], Any]:
120134
"""Register a custom reducer for the type."""
121135

122-
def proxy(func):
136+
def proxy(func: Any) -> Any:
123137
Pickler.dispatch[t] = func
124138
return func
125139

126140
return proxy
127141

128142

129-
def dump(obj, file):
143+
def dump(obj: Any, file: BinaryIO) -> None:
130144
"""Pickle an object to a file."""
131145
Pickler(file, recurse=True).dump(obj)
132146

133147

134-
def dumps(obj):
148+
def dumps(obj: Any) -> bytes:
135149
"""Pickle an object to a string."""
136150
file = BytesIO()
137151
dump(obj, file)
138152
return file.getvalue()
139153

140154

141-
def log(pickler, msg):
142-
pass
155+
def log(pickler: Any, msg: Any) -> None:
156+
"""Log a message from the pickler (no-op)."""
143157

144158

145-
def _save_regexPattern(pickler, obj):
146-
import regex # type: ignore
159+
def _save_regexPattern(pickler: Any, obj: Any) -> None:
160+
import regex # noqa: PLC0415
147161

148162
log(pickler, f"Re: {obj}")
149163
args = (obj.pattern, obj.flags)
150164
pickler.save_reduce(regex.compile, args, obj=obj)
151165
log(pickler, "# Re")
152166

153167

154-
def _save_tiktokenEncoding(pickler, obj):
155-
import tiktoken # type: ignore
168+
def _save_tiktokenEncoding(pickler: Any, obj: Any) -> None:
169+
import tiktoken # noqa: PLC0415
156170

157171
log(pickler, f"Enc: {obj}")
158-
args = (obj.name, obj._pat_str, obj._mergeable_ranks, obj._special_tokens)
172+
args = (obj.name, obj._pat_str, obj._mergeable_ranks, obj._special_tokens) # noqa: SLF001
159173
pickler.save_reduce(tiktoken.Encoding, args, obj=obj)
160174
log(pickler, "# Enc")
161175

162176

163-
def _save_torchTensor(pickler, obj):
164-
import torch # type: ignore
177+
def _save_torchTensor(pickler: Any, obj: Any) -> None:
178+
import torch # noqa: PLC0415
165179

166180
# `torch.from_numpy` is not picklable in `torch>=1.11.0`
167-
def create_torchTensor(np_array, dtype=None):
181+
def create_torchTensor(np_array: Any, dtype: Any = None) -> "torch.Tensor":
168182
tensor = torch.from_numpy(np_array)
169183
if dtype:
170184
tensor = tensor.type(dtype)
171185
return tensor
172186

173187
log(pickler, f"To: {obj}")
174188
if obj.dtype == torch.bfloat16:
175-
args = (obj.detach().to(torch.float).cpu().numpy(), torch.bfloat16)
189+
args: tuple[Any, ...] = (obj.detach().to(torch.float).cpu().numpy(), torch.bfloat16)
176190
else:
177191
args = (obj.detach().cpu().numpy(),)
178192
pickler.save_reduce(create_torchTensor, args, obj=obj)
179193
log(pickler, "# To")
180194

181195

182-
def _save_torchGenerator(pickler, obj):
183-
import torch # type: ignore
196+
def _save_torchGenerator(pickler: Any, obj: Any) -> None:
197+
import torch # noqa: PLC0415
184198

185-
def create_torchGenerator(state):
199+
def create_torchGenerator(state: Any) -> "torch.Generator":
186200
generator = torch.Generator()
187201
generator.set_state(state)
188202
return generator
@@ -193,10 +207,10 @@ def create_torchGenerator(state):
193207
log(pickler, "# Ge")
194208

195209

196-
def _save_spacyLanguage(pickler, obj):
197-
import spacy # type: ignore
210+
def _save_spacyLanguage(pickler: Any, obj: Any) -> None:
211+
import spacy # noqa: PLC0415
198212

199-
def create_spacyLanguage(config, bytes):
213+
def create_spacyLanguage(config: Any, bytes: Any) -> "spacy.Language":
200214
lang_cls = spacy.util.get_lang_class(config["nlp"]["lang"])
201215
lang_inst = lang_cls.from_config(config)
202216
return lang_inst.from_bytes(bytes)
@@ -207,11 +221,11 @@ def create_spacyLanguage(config, bytes):
207221
log(pickler, "# Sp")
208222

209223

210-
def _save_transformersPreTrainedTokenizerBase(pickler, obj):
224+
def _save_transformersPreTrainedTokenizerBase(pickler: Any, obj: Any) -> None:
211225
log(pickler, f"Tok: {obj}")
212226
# Ignore the `cache` attribute
213227
state = obj.__dict__
214228
if "cache" in state and isinstance(state["cache"], dict):
215229
state["cache"] = {}
216230
pickler.save_reduce(type(obj), (), state=state, obj=obj)
217-
log(pickler, "# Tok")
231+
log(pickler, "# Tok")

app/utils/errors.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,43 @@
1+
"""Utilities for creating error responses."""
2+
13
from http import HTTPStatus
2-
from typing import Union
4+
35

46
def create_error_response(
57
message: str,
68
err_type: str = "internal_error",
7-
status_code: Union[int, HTTPStatus] = HTTPStatus.INTERNAL_SERVER_ERROR,
8-
param: str = None,
9-
code: str = None
10-
):
9+
status_code: int | HTTPStatus = HTTPStatus.INTERNAL_SERVER_ERROR,
10+
param: str | None = None,
11+
code: str | None = None,
12+
) -> dict[str, object]:
13+
"""
14+
Create a standardized error response dictionary.
15+
16+
Parameters
17+
----------
18+
message : str
19+
The error message to include in the response.
20+
err_type : str, optional
21+
The type of error, by default "internal_error".
22+
status_code : int or HTTPStatus, optional
23+
The HTTP status code, by default HTTPStatus.INTERNAL_SERVER_ERROR.
24+
param : str or None, optional
25+
The parameter that caused the error, by default None.
26+
code : str or None, optional
27+
The error code, by default None.
28+
29+
Returns
30+
-------
31+
dict[str, object]
32+
A dictionary containing the error response structure.
33+
"""
1134
return {
1235
"error": {
1336
"message": message,
1437
"type": err_type,
1538
"param": param,
16-
"code": str(code or (status_code.value if isinstance(status_code, HTTPStatus) else status_code))
39+
"code": str(
40+
code or (status_code.value if isinstance(status_code, HTTPStatus) else status_code)
41+
),
1742
}
18-
}
43+
}
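(file path not captured in this extract; a module in app/utils/ that defines OutlinesTransformerTokenizer)

Lines changed: 5 additions & 2 deletions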
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
1+
"""Custom tokenizer class for outlines integration."""
2+
13
from outlines.models.transformers import TransformerTokenizer
24

35
from .dill import Hasher
46

57

68
class OutlinesTransformerTokenizer(TransformerTokenizer):
79
"""
8-
Update the outlines TransformerTokenizer to use our own Hasher class, so that we don't need the datasets dependency
10+
Update the outlines TransformerTokenizer to use our own Hasher class, so that we don't need the datasets dependency.
911
1012
This class and the external dependency can be removed when the following import is deleted
1113
https://github.com/dottxt-ai/outlines/blob/69418d/outlines/models/transformers.py#L117
1214
"""
1315

14-
def __hash__(self):
16+
def __hash__(self) -> int:
17+
"""Return a hash based on the tokenizer using custom hasher."""
1518
return hash(Hasher.hash(self.tokenizer))

0 commit comments

Comments
 (0)