convert_hf_to_gguf.py: 100 changes (92 additions, 8 deletions)
@@ -333,6 +333,38 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)

return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T

# ref: https://github.com/vllm-project/compressed-tensors/blob/52792be02ec09e59f3517104e755a02d0e003fbb/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
def dequant_compressed_tensor(weight: Tensor, scale: Tensor) -> Tensor:
weights_config = quant_config["config_groups"]["group_0"]["weights"]
group_size = weights_config["group_size"]
num_bits = weights_config["num_bits"]
# only tested with https://huggingface.co/moonshotai/Kimi-K2-Thinking/blob/main/config.json
# TODO: extend this if other configurations are needed
assert(group_size == 32)
assert(num_bits == 4)
assert(quant_config["format"] == "pack-quantized")

pack_factor = group_size // num_bits
mask = (1 << num_bits) - 1
unpacked = torch.zeros(
(weight.shape[0], weight.shape[1] * pack_factor),
dtype=torch.int32,
)
if self.lazy:
unpacked = LazyTorchTensor.from_eager(unpacked)
else:
unpacked = unpacked.to(weight.device) # is this needed?
for i in range(pack_factor):
unpacked[:, i::pack_factor] = (weight >> (num_bits * i)) & mask
# TODO: may need to unpad
unpacked = unpacked - (mask + 1) // 2 # convert uint4 to int4 (shift scale)
scale = scale.to(torch.float32)
scale = scale.unsqueeze(2)
unpacked = unpacked.to(torch.float32)
unpacked = unpacked.reshape(-1, unpacked.shape[1] // group_size, group_size)
dequantized = (unpacked * scale).reshape(-1, unpacked.shape[1] * group_size)
return dequantized
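
For readers unfamiliar with the pack-quantized layout, the following standalone sketch (editorial, not part of the patch; the values are made up) packs eight signed int4 values into one int32 exactly the way the loop above unpacks them: low nibble first, stored unsigned with a +8 offset that dequant_compressed_tensor later removes. In the checkpoint this PR targets, one scale covers a group of 32 such values, i.e. four int32 words.

import torch

int4_vals = torch.tensor([[-8, -3, 0, 1, 2, 5, 7, -1]], dtype=torch.int32)  # one row of signed int4 values
packed = torch.zeros((1, 1), dtype=torch.int32)
for i in range(8):
    packed |= ((int4_vals[:, i] + 8) & 0xF) << (4 * i)  # store as uint4, nibble i at bits 4*i

# unpack with the same indexing as the loop above
unpacked = torch.zeros((1, 8), dtype=torch.int32)
for i in range(8):
    unpacked[:, i::8] = (packed >> (4 * i)) & 0xF
unpacked -= 8  # uint4 -> int4, matching `unpacked - (mask + 1) // 2`
assert torch.equal(unpacked, int4_vals)

# with a per-group scale, dequantization is just a multiply
scale = torch.tensor(0.25)
print(unpacked.to(torch.float32) * scale)  # -2.0, -0.75, 0.0, 0.25, 0.5, 1.25, 1.75, -0.25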

if quant_method == "bitnet":
for name in self.model_tensors.keys():
if name.endswith(".weight_scale"):
@@ -371,6 +403,22 @@ def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor)
".scales",
)
]
elif quant_method == "compressed-tensors":
for name in self.model_tensors.keys():
if name.endswith("_packed"):
base_name = name.removesuffix("_packed")
packed = self.model_tensors[base_name + "_packed"]
scale = self.model_tensors[base_name + "_scale"]
# TODO: use _shape for unpadding if necessary
new_tensors[base_name] = lambda p=packed, s=scale: dequant_compressed_tensor(p(), s())
tensors_to_remove += [
base_name + n
for n in (
"_packed",
"_scale",
"_shape",
)
]
else:
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
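
One detail worth calling out (this note and the snippet are editorial, not part of the patch): the `lambda p=packed, s=scale:` entries above bind each iteration's tensors through default arguments. A plain closure over the loop variables would late-bind, so every entry would end up dequantizing the last tensor seen in the loop, as this minimal illustration shows:

late_bound  = [lambda: i for i in range(3)]
early_bound = [lambda i=i: i for i in range(3)]
assert [f() for f in late_bound]  == [2, 2, 2]  # every closure sees the final value of i
assert [f() for f in early_bound] == [0, 1, 2]  # a default argument captures i at definition time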

@@ -441,7 +489,7 @@ def prepare_tensors(self):
old_dtype = data_torch.dtype

# convert any unsupported data types to float32
-if data_torch.dtype not in (torch.float16, torch.float32):
if data_torch.dtype not in (torch.float16, torch.float32, torch.int32):
data_torch = data_torch.to(torch.float32)

# use the first number-like part of the tensor name as the block id
@@ -7045,6 +7093,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])

_experts: list[dict[str, Tensor]] | None = None
_experts_s: list[dict[str, Tensor]] | None = None # scale (for quantized experts)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip vision tensors and remove "language_model." for Kimi-VL
@@ -7072,28 +7121,42 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

-self._experts[bid][name] = data_torch
if self._experts_s is None:
self._experts_s = [{} for _ in range(self.block_count)]

-if len(self._experts[bid]) >= n_experts * 3:
-tensors: list[tuple[str, Tensor]] = []
if name.endswith(".weight_packed"):
self._experts[bid][name] = data_torch

if name.endswith(".weight_scale"):
self._experts_s[bid][name] = data_torch

# TODO @ngxson : this is demo code; not yet compatible with other models
if len(self._experts[bid]) + len(self._experts_s[bid]) >= n_experts * 3 * 2:
# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []
datas_s: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight_packed"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

-data_torch = torch.stack(datas, dim=0)
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight_scale"
datas_s.append(self._experts_s[bid][ename])
del self._experts_s[bid][ename]

data_packed = torch.stack(datas, dim=0)
data_scale = torch.stack(datas_s, dim=0)

merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)

-tensors.append((new_name, data_torch))
-return tensors
target_shape = (n_experts, data_packed.shape[1], data_packed.shape[2] * 8)
self.repack_compressed_tensor(new_name, data_packed, data_scale, target_shape)
#tensors.append((new_name, data_torch))
return []
else:
return []
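
As a quick sanity check on the shape bookkeeping (editorial, illustrative only; the width below is hypothetical, not read from the model config): each int32 in weight_packed carries eight 4-bit weights, which is why target_shape above multiplies the packed column count by 8, and there is one scale per 32 logical weights, i.e. one per four packed words.

n_cols      = 7168           # hypothetical logical row width of one expert matrix
packed_cols = n_cols // 8    # int32 words per row in weight_packed  -> 896
groups      = n_cols // 32   # scales per row in weight_scale        -> 224
assert packed_cols * 8 == n_cols   # the `* 8` used in target_shape
assert packed_cols == groups * 4   # four packed int32 words per quantization group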

@@ -7128,6 +7191,27 @@ def prepare_tensors(self):
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")

def repack_compressed_tensor(self, new_name: str, blocks: Tensor, scales: Tensor, shape: Sequence[int]):
assert blocks.dtype == torch.int32
assert len(blocks.shape) == 3
assert len(scales.shape) == 3
logger.info(f"Repacking compressed_tensor {new_name} with shape {shape}")
# flatten the first two dimensions
blocks = blocks.reshape(-1, blocks.shape[2])
scales = scales.reshape(-1, scales.shape[2])
# TODO: for kimi-k2 this casts the bf16 scales to f16, which may slightly reduce accuracy;
# we have to do this because Q4_0 in GGUF only supports f16 scales
scales = scales.to(torch.float16).view(torch.uint16).reshape(-1, 1)
repacked = blocks.reshape((blocks.shape[0] * blocks.shape[1]) // 4, 4)
repacked = repacked.view(torch.uint16)
assert repacked.shape[0] == scales.shape[0] # should have the same number of blocks
repacked = torch.concat([scales, repacked], dim=1)
repacked = repacked.view(torch.uint8)
shape_list = list(shape)
shape_list[-1] = (shape_list[-1] // 32) * 18 # block * 18 bytes for Q4_0 block size
self.gguf_writer.add_tensor(new_name, repacked.numpy(), raw_dtype=gguf.GGMLQuantizationType.Q4_0, raw_shape=shape_list)
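
To make the byte accounting concrete, a small editorial sketch (assuming a torch build that supports uint16 views, which the code above already relies on): the f16 scale is bit-cast to uint16 so it can be concatenated with the integer quant data, and every block of 32 weights becomes 18 bytes, a 2-byte f16 scale followed by 16 bytes holding the 32 four-bit quants.

import torch

scale_f16 = torch.tensor([0.5], dtype=torch.float16)
scale_u16 = scale_f16.view(torch.uint16)            # same bits, integer dtype
assert scale_u16.view(torch.float16).item() == 0.5  # lossless round trip

bytes_per_block = 2 + 32 // 2                       # f16 scale + 32 packed nibbles
assert bytes_per_block == 18                        # matches the `* 18` in shape_list above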


@ModelBase.register("MiniMaxM2ForCausalLM")
class MiniMaxM2Model(TextModel):