from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation
from llmcompressor.modeling.granite4 import replace_granite_moe_with_linear_experts, pack_3d_experts

from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "ibm-granite/granite-4.0-h-small"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

model = replace_granite_moe_with_linear_experts(model)

ignore_lay = ["lm_head"]
ignore_lay += ["re:.*block_sparse_moe.router"]
ignore_lay += ["re:.*mamba.in_proj"]
ignore_lay += ["re:.*shared_mlp.input_linear"]

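# Quantize every Linear module with block-wise FP8 weights; modules matching the
# ignore list above are left unquantized.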
recipe = QuantizationModifier(
    targets=["Linear"],
    scheme="FP8_BLOCK",
    ignore=ignore_lay,
)

oneshot(model=model, recipe=recipe)
dispatch_for_generation(model)

print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer(
    "Describe Large Language Model", return_tensors="pt"
).input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=35)
print(tokenizer.decode(output[0]))
print("==========================================")

SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-block"
print(f"Saving to {SAVE_DIR}")

model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)

pack_3d_experts(SAVE_DIR)
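
# Optional sanity check (illustrative sketch, not part of the original example):
# verify that pack_3d_experts left stacked 3D expert tensors in the repacked
# checkpoint. The "input_linear" name follows the Granite MoE layout described
# in pack_3d_experts; adjust if your tensor names differ.
import json
from pathlib import Path
from safetensors.torch import load_file

index = json.loads((Path(SAVE_DIR) / "model.safetensors.index.json").read_text())
name = next(n for n in index["weight_map"] if n.endswith("block_sparse_moe.input_linear.weight"))
shard = load_file(str(Path(SAVE_DIR) / index["weight_map"][name]))
print(name, tuple(shard[name].shape))  # expected: (num_experts, output_size, input_size)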
src/llmcompressor/modeling/granite4.py (243 additions, 0 deletions)
import torch
import torch.nn as nn
import json
import os
import shutil
Author (review comment): @dsikka This saving logic is specific to the granite4 small model and ensures the FP8 block-quantized model is compatible with vLLM.
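For illustration (hypothetical layer and expert indices): pack_3d_experts collapses per-expert entries such as model.layers.0.block_sparse_moe.input_linear.experts.7.weight into one stacked tensor model.layers.0.block_sparse_moe.input_linear.weight of shape [num_experts, output_size, input_size], matching the 3D layout of the original GraniteMoeHybridParallelExperts weight that vLLM expects.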

from pathlib import Path
from collections import defaultdict
from safetensors.torch import load_file, save_file

from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
    GraniteMoeHybridParallelExperts,
)


# For FP8 block quantization
def replace_granite_moe_with_linear_experts(model):
    """
    Convert GraniteMoeHybridParallelExperts modules into individual expert layers.
    Each expert will be stored as a separate nn.Linear module.
    """

    class SeparatedExperts(nn.Module):
        """Replacement module with individual expert linear layers"""
        def __init__(self, num_experts, input_size, output_size, original_weight):
            super().__init__()
            self.num_experts = num_experts
            self.input_size = input_size
            self.output_size = output_size

            # Create individual linear layers for each expert
            self.experts = nn.ModuleList([
                nn.Linear(input_size, output_size, bias=False)
                for _ in range(num_experts)
            ])

            # Copy weights from the original 3D tensor
            # Original format: [num_experts, output_size, input_size]
            for i in range(num_experts):
                self.experts[i].weight.data = original_weight[i].clone()

        def forward(self, inputs, expert_size):
            """Forward pass using individual expert layers"""
            input_list = inputs.split(expert_size, dim=0)
            output_list = []
            for i in range(self.num_experts):
                output_list.append(self.experts[i](input_list[i]))
            results = torch.cat(output_list, dim=0)
            return results

    # Find and replace all GraniteMoeHybridParallelExperts modules
    def replace_parallel_experts(module, name=''):
        for child_name, child in module.named_children():
            full_name = f"{name}.{child_name}" if name else child_name

            if child.__class__.__name__ == 'GraniteMoeHybridParallelExperts':
                # Create replacement module with separated experts
                separated = SeparatedExperts(
                    num_experts=child.num_experts,
                    input_size=child.input_size,
                    output_size=child.output_size,
                    original_weight=child.weight.data
                )
                # Replace the module
                setattr(module, child_name, separated)
                print(f"Replaced {full_name}: {child.num_experts} experts, "
                      f"input_size={child.input_size}, output_size={child.output_size}")
            else:
                # Recursively process children
                replace_parallel_experts(child, full_name)

    replace_parallel_experts(model)
    return model

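# Illustrative note: the original GraniteMoeHybridParallelExperts module holds one
# 3D parameter of shape [num_experts, output_size, input_size]; after replacement,
# expert i is an nn.Linear whose weight is original_weight[i] (shape
# [output_size, input_size]), so per-expert outputs are numerically unchanged.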


def pack_3d_experts(source_dir):
    """
    Transform a Granite MoE checkpoint from per-expert storage to stacked 3D tensor storage.

    From: model.layers.{L}.block_sparse_moe.{linear_type}.experts.{E}.{param}
    To:   model.layers.{L}.block_sparse_moe.{linear_type}.{param}
    """
    source_dir = Path(source_dir)

    # Load the index file
    index_file = source_dir / "model.safetensors.index.json"
    with open(index_file, "r") as f:
        index_data = json.load(f)

    weight_map = index_data["weight_map"]

    # Group tensors by layer, linear type, and parameter
    # Structure: {(layer_num, linear_type, param): {expert_num: (tensor_name, file_name)}}
    grouped_tensors = defaultdict(dict)
    other_tensors = {}  # Non-expert tensors (router, embeddings, etc.)

    for tensor_name, file_name in weight_map.items():
        # Check if this is an expert tensor
        # Pattern: model.layers.{L}.block_sparse_moe.{linear_type}.experts.{E}.{param}
        if ".block_sparse_moe." in tensor_name and ".experts." in tensor_name:
            parts = tensor_name.split(".")

            try:
                # Find the indices of key parts
                layers_idx = parts.index("layers")
                layer_num = int(parts[layers_idx + 1])

                experts_idx = parts.index("experts")
                expert_num = int(parts[experts_idx + 1])

                # The linear type is right before "experts"
                # e.g., "input_linear" or "output_linear"
                linear_type = parts[experts_idx - 1]

                # The parameter comes after the expert number
                # e.g., "weight" or "weight_scale"
                param = ".".join(parts[experts_idx + 2:])

                # Create grouping key
                group_key = (layer_num, linear_type, param)
                grouped_tensors[group_key][expert_num] = (tensor_name, file_name)

            except (ValueError, IndexError):
                # If parsing fails, treat as other tensor
                print(f" Warning: Could not parse expert tensor: {tensor_name}")
                other_tensors[tensor_name] = file_name
        else:
            other_tensors[tensor_name] = file_name

    # Load all safetensors files
    print("Loading source safetensors files...")
    loaded_tensors = {}
    unique_files = set(weight_map.values())
    old_files = list(unique_files)  # Store list of old files to delete later

    for file_name in unique_files:
        file_path = source_dir / file_name
        print(f" Loading {file_name}...")
        loaded_tensors[file_name] = load_file(str(file_path))

    # Create new tensors by stacking experts
    print("\nStacking expert tensors...")
    new_tensors = {}

    # Process each grouped tensor
    for (layer_num, linear_type, param), experts_dict in sorted(grouped_tensors.items()):
        print(f" Processing layer {layer_num}, {linear_type}.{param}...")

        # Get all expert tensors for this group
        expert_nums = sorted(experts_dict.keys())
        expert_tensors = []

        for expert_num in expert_nums:
            tensor_name, file_name = experts_dict[expert_num]
            tensor = loaded_tensors[file_name][tensor_name]
            expert_tensors.append(tensor)

        # Stack along first dimension to create 3D tensor
        stacked_tensor = torch.stack(expert_tensors, dim=0)

        # Create new tensor name (remove .experts.{E} part)
        new_tensor_name = f"model.layers.{layer_num}.block_sparse_moe.{linear_type}.{param}"
        new_tensors[new_tensor_name] = stacked_tensor

        print(f" {new_tensor_name}: {list(stacked_tensor.shape)} (stacked {len(expert_tensors)} experts)")

    # Copy non-expert tensors (router, embeddings, etc.)
    print("\nCopying non-expert tensors...")
    for tensor_name, file_name in other_tensors.items():
        tensor = loaded_tensors[file_name][tensor_name]
        new_tensors[tensor_name] = tensor
        print(f" Copied: {tensor_name}")

    # Distribute the new tensors across the same number of output files,
    # greedily assigning each tensor to the currently smallest file
    num_output_files = len(unique_files)
    tensors_list = list(new_tensors.items())

    print(f"\nDistributing tensors across {num_output_files} files...")
    file_tensors = [{} for _ in range(num_output_files)]
    file_sizes = [0] * num_output_files
    new_weight_map = {}

    for tensor_name, tensor in tensors_list:
        # Find file with smallest current size
        min_idx = file_sizes.index(min(file_sizes))
        file_tensors[min_idx][tensor_name] = tensor
        file_sizes[min_idx] += tensor.numel() * tensor.element_size()

        # Update weight map
        file_name = f"model-{min_idx+1:05d}-of-{num_output_files:05d}.safetensors"
        new_weight_map[tensor_name] = file_name

    # Save new safetensors files with temporary names
    print("\nSaving new safetensors files (temporary)...")
    temp_files = []
    for i, tensors_dict in enumerate(file_tensors):
        if tensors_dict:  # Only save if not empty
            file_name = f"model-{i+1:05d}-of-{num_output_files:05d}.safetensors"
            temp_file_name = f"model-{i+1:05d}-of-{num_output_files:05d}.safetensors.tmp"
            output_path = source_dir / temp_file_name
            print(f" Saving {temp_file_name} ({len(tensors_dict)} tensors)...")
            save_file(tensors_dict, str(output_path))
            temp_files.append((temp_file_name, file_name))

    # Save updated index file with temporary name
    print("\nSaving updated index file (temporary)...")
    new_index_data = {
        "metadata": index_data.get("metadata", {}),
        "weight_map": new_weight_map
    }

    temp_index_file = source_dir / "model.safetensors.index.json.tmp"
    with open(temp_index_file, "w") as f:
        json.dump(new_index_data, f, indent=2)

    # Now delete old files
    print("\nDeleting old safetensors files...")
    for old_file in old_files:
        old_file_path = source_dir / old_file
        if old_file_path.exists():
            old_file_path.unlink()
            print(f" Deleted {old_file}")

    # Delete old index file
    if index_file.exists():
        index_file.unlink()
        print(" Deleted model.safetensors.index.json")

    # Rename temporary files to final names
    print("\nRenaming temporary files to final names...")
    for temp_name, final_name in temp_files:
        temp_path = source_dir / temp_name
        final_path = source_dir / final_name
        temp_path.rename(final_path)
        print(f" Renamed {temp_name} -> {final_name}")

    # Rename temporary index file
    temp_index_file.rename(index_file)
    print(" Renamed model.safetensors.index.json.tmp -> model.safetensors.index.json")
    print("\nCheckpoint updated for vLLM compatibility")



class GraniteMoeHybridParallelExpertsLinear(torch.nn.Linear):
    def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
        """Use a real Linear so that llmcompressor and vLLM can handle it more easily.
Expand Down