Skip to content

Commit 802cef4

Browse files
authored
convert : parse safetensors directly (#15667)
* convert : parse safetensors directly
* gguf-py : order safetensors tensors by name. Applies to both local and remote safetensors custom parsing. This matches the behavior of the official safetensors implementation.
* convert : rename from_safetensors_meta to from_local_tensor, for consistency with from_remote_tensor
* convert : fix no-lazy dtypes from direct safetensors
1 parent 1c07c0c commit 802cef4

File tree

2 files changed

+98
-9
lines changed

2 files changed

+98
-9
lines changed

convert_hf_to_gguf.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,7 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
218218
logger.info(f"gguf: indexing model part '{part_name}'")
219219
ctx: ContextManager[Any]
220220
if is_safetensors:
221-
from safetensors import safe_open
222-
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
221+
ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
223222
else:
224223
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
225224

@@ -228,18 +227,18 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
228227

229228
for name in model_part.keys():
230229
if is_safetensors:
230+
data: gguf.utility.LocalTensor = model_part[name]
231231
if self.lazy:
232-
data = model_part.get_slice(name)
233-
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731
232+
data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data) # noqa: E731
234233
else:
235-
data = model_part.get_tensor(name)
236-
data_gen = lambda data=data: data # noqa: E731
234+
dtype = LazyTorchTensor._dtype_str_map[data.dtype]
235+
data_gen = lambda data=data, dtype=dtype: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape) # noqa: E731
237236
else:
238-
data = model_part[name]
237+
data_torch: Tensor = model_part[name]
239238
if self.lazy:
240-
data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731
239+
data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data) # noqa: E731
241240
else:
242-
data_gen = lambda data=data: data # noqa: E731
241+
data_gen = lambda data=data_torch: data # noqa: E731
243242
tensors[name] = data_gen
244243

245244
# verify tensor name presence and identify potentially missing files
@@ -10079,6 +10078,16 @@ def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
1007910078
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] if len(s.get_shape()) == 0 else s[:])
1008010079
return cast(torch.Tensor, lazy)
1008110080

10081+
@classmethod
def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
    """Build a lazy tensor describing ``t`` without reading its data yet.

    The lazy object is created with metadata derived from ``t.dtype`` (looked
    up in ``cls._dtype_str_map``) and ``t.shape``; ``t`` itself is stored as
    the sole argument of the loader function.
    """
    def materialize(local: gguf.utility.LocalTensor) -> Tensor:
        # Reinterpret the raw bytes as the declared dtype, then restore the declared shape.
        raw = torch.from_numpy(local.mmap_bytes())
        return raw.view(cls._dtype_str_map[local.dtype]).reshape(local.shape)

    meta = cls.meta_with_dtype_and_shape(cls._dtype_str_map[t.dtype], t.shape)
    # NOTE(review): presumably `func` is called with `args` when the tensor is materialized — confirm in the lazy base class.
    return cast(torch.Tensor, cls(meta=meta, args=(t,), func=materialize))
10090+
1008210091
@classmethod
1008310092
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
1008410093
dtype = cls._dtype_str_map[remote_tensor.dtype]

gguf-py/gguf/utility.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass
4+
from pathlib import Path
45
from typing import Literal
56

67
import os
78
import json
9+
import numpy as np
810

911

1012
def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -177,6 +179,10 @@ def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
177179
except KeyError as e:
178180
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
179181

182+
# order by name (same as default safetensors behavior)
183+
# ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
184+
res = dict(sorted(res.items(), key=lambda t: t[0]))
185+
180186
return res
181187

182188
@classmethod
@@ -266,3 +272,77 @@ def _get_request_headers(cls) -> dict[str, str]:
266272
if os.environ.get("HF_TOKEN"):
267273
headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
268274
return headers
275+
276+
277+
@dataclass
class LocalTensorRange:
    """Location of a contiguous run of bytes inside a file on the local filesystem."""
    # File that holds the bytes.
    filename: Path
    # Position of the first byte, counted from the start of the file.
    offset: int
    # Number of bytes in the range.
    size: int
282+
283+
284+
@dataclass
class LocalTensor:
    """Description of one tensor stored in a local file, with access to its raw bytes."""
    # dtype name exactly as it appears in the file's metadata
    dtype: str
    shape: tuple[int, ...]
    # where the tensor's bytes live on disk
    data_range: LocalTensorRange

    def mmap_bytes(self) -> np.ndarray:
        """Return the tensor's raw bytes as a 1-D ``uint8`` array backed by a memory map.

        The file is mapped copy-on-write (``mode='c'``): the returned array is
        writable, but writes only affect memory and never the file on disk.
        ``np.memmap`` defaults to ``mode='r+'``, which needs write permission on
        the file and would let in-place operations modify it.

        A zero-length range yields an empty in-memory array, because an empty
        region cannot always be memory-mapped.
        """
        byte_range = self.data_range
        if byte_range.size == 0:
            return np.empty(0, dtype=np.uint8)
        return np.memmap(
            byte_range.filename,
            dtype=np.uint8,
            mode='c',
            offset=byte_range.offset,
            shape=byte_range.size,
        )
292+
293+
294+
class SafetensorsLocal:
    """
    Read a safetensors file from the local filesystem.

    Custom parsing gives a bit more control over the memory usage.
    The official safetensors library doesn't expose file ranges.

    Used as a context manager, entering yields a dict mapping each tensor
    name to its LocalTensor, ordered by name.
    """
    # bytes; no longer applied to the data offset, kept so existing references still resolve
    ALIGNMENT = 8

    tensors: dict[str, LocalTensor]

    def __init__(self, filename: Path):
        """Parse the header of ``filename`` and index the byte range of every tensor.

        Only the length prefix and the JSON header are read here; no tensor data is loaded.

        Raises:
            ValueError: if the header is truncated, is not valid JSON, lacks a
                required key for some tensor, or describes a byte range that
                does not fit inside the file.
        """
        with open(filename, "rb") as f:
            file_size = os.fstat(f.fileno()).st_size
            length_prefix = f.read(8)
            if len(length_prefix) < 8:
                raise ValueError(f"Could not read complete metadata. Need 8 bytes, got {file_size}")
            metadata_length = int.from_bytes(length_prefix, byteorder='little')
            if file_size < 8 + metadata_length:
                raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")

            metadata_str = f.read(metadata_length).decode('utf-8')

        try:
            metadata = json.loads(metadata_str)
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}") from e
        if not isinstance(metadata, dict):
            raise ValueError(f"Expected safetensors metadata to be a JSON object, got {type(metadata).__name__}")

        # The byte buffer starts right after the header, and tensor offsets are
        # relative to it. Writers that want alignment pad the header itself, so
        # rounding this offset up would shift the data of files whose header
        # length is not a multiple of 8.
        data_start_offset = 8 + metadata_length
        data_size = file_size - data_start_offset

        tensors: dict[str, LocalTensor] = {}
        for name, meta in metadata.items():
            if name == "__metadata__":
                # ignore metadata, it's not a tensor
                continue

            try:
                dtype = meta["dtype"]
                shape = tuple(meta["shape"])
                offsets = meta["data_offsets"]
            except KeyError as e:
                raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") from e

            begin, end = offsets[0], offsets[1]
            if not 0 <= begin <= end <= data_size:
                raise ValueError(f"Invalid data offsets for tensor '{name}': [{begin}, {end}), data size = {data_size}")

            tensors[name] = LocalTensor(
                dtype=dtype,
                shape=shape,
                data_range=LocalTensorRange(filename, data_start_offset + begin, end - begin),
            )

        # order by name (same as default safetensors behavior)
        # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
        self.tensors = {name: tensors[name] for name in sorted(tensors)}

    def __enter__(self, *args, **kwargs):
        """Return the name-ordered tensor index."""
        del args, kwargs  # unused
        return self.tensors

    def __exit__(self, *args, **kwargs):
        """Nothing to release: the file is closed at the end of ``__init__``."""
        del args, kwargs  # unused

0 commit comments

Comments
 (0)