Merged
2 changes: 1 addition & 1 deletion examples/multimodal_vision/llama4_example.py
@@ -19,7 +19,7 @@
# NOTE: This restructuring is specifically required for vLLM compatibility.
# To define custom calibration logic, create a new calibration module in
# modeling/llama4.py that inherits from `MoECalibrationModule`, and register
# it using the `@register_moe_calibration` decorator with the appropriate
# it using the `@MoECalibrationModule.register` decorator with the appropriate
# module class name (e.g., "Llama4TextMoe").

DATASET_ID = "neuralmagic/calibration"
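For anyone wiring up a new architecture, a registration sketch under the new API follows (not part of this PR): `MyModelTextMoe` and `CalibrationMyModelTextMoe` are hypothetical names, and the constructor signature mirrors the `MoECalibrationModule.load_from_registry(...)` call made inside `moe_calibration_context`.

import torch

from llmcompressor.modeling.moe_context import MoECalibrationModule


@MoECalibrationModule.register("MyModelTextMoe")  # hypothetical original module class name
class CalibrationMyModelTextMoe(MoECalibrationModule):
    """Illustrative calibration wrapper that routes calibration tokens to all experts."""

    is_permanent = False  # swapped back when the calibration context exits

    def __init__(self, original: torch.nn.Module, config, calibrate_all_experts: bool = True):
        super().__init__()
        self.original = original
        self.calibrate_all_experts = calibrate_all_experts

    def forward(self, hidden_states: torch.Tensor):
        # Sketch only: a real implementation would run every expert on every
        # token when calibrate_all_experts is True.
        return self.original(hidden_states)

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        # Non-permanent modules hand the original module back on context exit.
        return original

The new module would also be imported in src/llmcompressor/modeling/__init__.py so that registration runs at import time.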
8 changes: 8 additions & 0 deletions src/llmcompressor/modeling/__init__.py
@@ -9,5 +9,13 @@
needed for efficient compression.
"""

# trigger registration
from .deepseek_v3 import CalibrationDeepseekV3MoE # noqa: F401
from .llama4 import SequentialLlama4TextMoe # noqa: F401
from .qwen3_moe import CalibrationQwen3MoeSparseMoeBlock # noqa: F401
from .qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock # noqa: F401
from .qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock # noqa: F401
# TODO: add granite4

from .fuse import *
from .prepare import *
7 changes: 2 additions & 5 deletions src/llmcompressor/modeling/deepseek_v3.py
@@ -4,13 +4,10 @@
DeepseekV3MoE as OriginalDeepseekV3MoE,
)

from llmcompressor.modeling.moe_context import (
MoECalibrationModule,
register_moe_calibration,
)
from llmcompressor.modeling.moe_context import MoECalibrationModule


@register_moe_calibration("DeepseekV3MoE")
@MoECalibrationModule.register("DeepseekV3MoE")
class CalibrationDeepseekV3MoE(MoECalibrationModule):
"""
Calibration version of DeepseekV3MoE that sends all tokens to all experts.
13 changes: 4 additions & 9 deletions src/llmcompressor/modeling/llama4.py
@@ -11,14 +11,11 @@
Llama4TextMoe,
)

from llmcompressor.modeling.moe_context import (
MoECalibrationModule,
register_moe_calibration,
)
from llmcompressor.modeling.moe_context import MoECalibrationModule
from llmcompressor.utils.dev import skip_weights_initialize


@register_moe_calibration("Llama4TextMoe")
@MoECalibrationModule.register("Llama4TextMoe")
class SequentialLlama4TextMoe(MoECalibrationModule):
"""
Calibration version of Llama4TextMoe that unpacks experts for sequential processing.
@@ -38,10 +35,8 @@ def __init__(
calibrate_all_experts: bool = True,
):
super().__init__()
# Extract text config from multimodal config if needed
text_config = (
config.get_text_config() if hasattr(config, "get_text_config") else config
)
# Extract text config from multimodal config
text_config: Llama4TextConfig = config.get_text_config()
self.top_k = text_config.num_experts_per_tok
self.hidden_dim = text_config.hidden_size
self.num_experts = text_config.num_local_experts
46 changes: 11 additions & 35 deletions src/llmcompressor/modeling/moe_context.py
@@ -8,28 +8,25 @@

Key components:
- MoECalibrationModule: Abstract base class for calibration modules
- MOE_CALIBRATION_MODULES: Registry mapping module class names to calibration classes
- moe_calibration_context: Context manager that applies calibration to a model
"""

import contextlib
from abc import ABC
from typing import Dict, Type

import torch
from compressed_tensors.registry import RegistryMixin, standardize_lookup_name
from loguru import logger
from tqdm import tqdm
from transformers import PreTrainedModel

__all__ = [
"MoECalibrationModule",
"MOE_CALIBRATION_MODULES",
"register_moe_calibration",
"moe_calibration_context",
]


class MoECalibrationModule(ABC, torch.nn.Module):
class MoECalibrationModule(ABC, torch.nn.Module, RegistryMixin):
"""
Abstract base class for MoE calibration modules.

@@ -62,32 +59,6 @@ def restore(self, original: torch.nn.Module) -> torch.nn.Module:
)


# Registry: module class name -> calibration module class
MOE_CALIBRATION_MODULES: Dict[str, Type[MoECalibrationModule]] = {}


def register_moe_calibration(module_class_name: str):
"""
Decorator to register a MoE calibration module.

Usage:
@register_moe_calibration("DeepseekV3MoE")
class CalibrationDeepseekV3MoE(MoECalibrationModule):
...

Args:
module_class_name: The class name of the original module to replace
"""

def decorator(cls: Type[MoECalibrationModule]) -> Type[MoECalibrationModule]:
if not issubclass(cls, MoECalibrationModule):
raise TypeError(f"{cls.__name__} must inherit from MoECalibrationModule")
MOE_CALIBRATION_MODULES[module_class_name] = cls
return cls

return decorator


@contextlib.contextmanager
def moe_calibration_context(
model: PreTrainedModel,
@@ -115,14 +86,15 @@ def moe_calibration_context(
model(**batch)
# Model is now restored (unless permanent)
"""

replaced = {}

# Step 1: Collect all MoE modules that need replacement
logger.info("Entering MoE calibration context")
logger.debug("Entering MoE calibration context")
modules_to_replace = []
for name, module in model.named_modules():
class_name = module.__class__.__name__
if class_name in MOE_CALIBRATION_MODULES:
if _is_registered(class_name, MoECalibrationModule):
modules_to_replace.append((name, module, class_name))

# Step 2: Replace modules with progress bar
@@ -131,8 +103,8 @@ def moe_calibration_context(
for name, module, class_name in tqdm(
modules_to_replace, desc="Replacing MoE modules for calibration"
):
calibration_cls = MOE_CALIBRATION_MODULES[class_name]
replacement = calibration_cls(
replacement = MoECalibrationModule.load_from_registry(
class_name,
module,
model.config,
calibrate_all_experts=calibrate_all_experts,
@@ -165,3 +137,7 @@ def moe_calibration_context(
if not replacement.is_permanent:
restored = replacement.restore(original)
model.set_submodule(name, restored)


def _is_registered(name: str, subclass: RegistryMixin):
return standardize_lookup_name(name) in subclass.registered_names()
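To make the new flow concrete, here is a minimal usage sketch of `moe_calibration_context` based on the docstring above; the model id and calibration text are placeholder assumptions, not part of this PR.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modeling.moe_context import moe_calibration_context

model_id = "Qwen/Qwen3-30B-A3B"  # placeholder: any architecture with a registered MoE block
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
batch = tokenizer("Sample calibration text", return_tensors="pt")

with moe_calibration_context(model, calibrate_all_experts=True):
    # Registered MoE blocks are replaced by their calibration versions here,
    # so every expert sees the calibration tokens.
    with torch.no_grad():
        model(**batch)

# On exit, non-permanent modules are restored; modules with is_permanent = True
# (e.g. the Qwen3-VL MoE block later in this diff) stay replaced.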
31 changes: 5 additions & 26 deletions src/llmcompressor/modeling/prepare.py
@@ -10,33 +10,12 @@
from compressed_tensors.utils import deprecated, replace_module
from transformers import PreTrainedModel

# Import MoE calibration modules to trigger registration
from llmcompressor.modeling.deepseek_v3 import ( # noqa: F401
CalibrationDeepseekV3MoE,
)
from llmcompressor.modeling.deepseek_v3 import (
replace as replace_deepseekv3,
)
from llmcompressor.modeling.llama4 import ( # noqa: F401
SequentialLlama4TextMoe,
)
from llmcompressor.modeling.llama4 import (
replace as replace_llama4,
)
from llmcompressor.modeling.moe_context import ( # noqa: F401
moe_calibration_context,
)
from llmcompressor.modeling.qwen3_moe import ( # noqa: F401
CalibrationQwen3MoeSparseMoeBlock,
)
from llmcompressor.modeling.qwen3_next_moe import ( # noqa: F401
CalibrationQwen3NextSparseMoeBlock,
)
from llmcompressor.modeling.qwen3_vl_moe import (
replace as replace_Qwen3VLMoE,
)
# deprecated replacement functions
from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
from llmcompressor.modeling.llama4 import replace as replace_llama4
from llmcompressor.modeling.qwen3_vl_moe import replace as replace_Qwen3VLMoE

__all__ = ["moe_calibration_context", "replace_modules_for_calibration"]
__all__ = ["replace_modules_for_calibration"]

# ---------------------- module replacements; permanent -------------------------
replacements = {
7 changes: 2 additions & 5 deletions src/llmcompressor/modeling/qwen3_moe.py
@@ -20,13 +20,10 @@
Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock,
)

from llmcompressor.modeling.moe_context import (
MoECalibrationModule,
register_moe_calibration,
)
from llmcompressor.modeling.moe_context import MoECalibrationModule


@register_moe_calibration("Qwen3MoeSparseMoeBlock")
@MoECalibrationModule.register("Qwen3MoeSparseMoeBlock")
class CalibrationQwen3MoeSparseMoeBlock(MoECalibrationModule):
"""
Calibration version of Qwen3MoeSparseMoeBlock that sends all tokens to all experts.
7 changes: 2 additions & 5 deletions src/llmcompressor/modeling/qwen3_next_moe.py
@@ -16,13 +16,10 @@

import torch

from llmcompressor.modeling.moe_context import (
MoECalibrationModule,
register_moe_calibration,
)
from llmcompressor.modeling.moe_context import MoECalibrationModule


@register_moe_calibration("Qwen3NextSparseMoeBlock")
@MoECalibrationModule.register("Qwen3NextSparseMoeBlock")
class CalibrationQwen3NextSparseMoeBlock(MoECalibrationModule):
from transformers import Qwen3NextConfig
from transformers.models.qwen3_next.modeling_qwen3_next import (
45 changes: 36 additions & 9 deletions src/llmcompressor/modeling/qwen3_vl_moe.py
@@ -1,19 +1,39 @@
import torch
from transformers import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeTextSparseMoeBlock as OriginalQwen3VLMoeTextSparseMoeBlock,
)

from llmcompressor.modeling.moe_context import MoECalibrationModule
from llmcompressor.utils.dev import skip_weights_initialize


class LinearQwen3VLMoeTextSparseMoeBlock(torch.nn.Module):
def __init__(self, config, original, calibrate_all_experts):
@MoECalibrationModule.register("Qwen3VLMoeTextSparseMoeBlock")
class CalibrateQwen3VLMoeTextSparseMoeBlock(MoECalibrationModule):
"""
Calibration version of Qwen3VLMoeTextSparseMoeBlock that sends all tokens to all
experts.
"""

is_permanent = True

def __init__(
self,
original: OriginalQwen3VLMoeTextSparseMoeBlock,
config: Qwen3VLMoeConfig,
calibrate_all_experts: bool,
):
super().__init__()
self.hidden_size = config.hidden_size
self.num_experts = config.num_experts
text_config: Qwen3VLMoeTextConfig = config.get_text_config()

self.hidden_size = text_config.hidden_size
self.num_experts = text_config.num_experts
self.top_k = original.top_k
# Note: gate was changed to be a Linear layer in transformers==4.57.0
# https://github.com/JJJYmmm/transformers/commit/f5dea1c694af8c994c769170813a8702332119ee
self.gate = original.gate
self.calibrate_all_experts = calibrate_all_experts
self.experts = SequentialQwen3VLMoeTextExperts(config, original.experts)
self.experts = SequentialQwen3VLMoeTextExperts(text_config, original.experts)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -64,6 +84,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
next_states = next_states.reshape(batch_size, sequence_length, hidden_dim)
return next_states, router_logits

def restore(self, original: torch.nn.Module) -> torch.nn.Module:
return original


class SequentialQwen3VLMoeTextExperts(torch.nn.ModuleList):
def __init__(self, config, original):
@@ -91,9 +114,13 @@ def __init__(self, config, original):
self[i].down_proj.weight.data = down.t().clone().contiguous()


def replace(config, module, calibrate_all_experts):
return LinearQwen3VLMoeTextSparseMoeBlock(
config=config.get_text_config(),
original=module,
def replace(
config: Qwen3VLMoeConfig,
original: OriginalQwen3VLMoeTextSparseMoeBlock,
calibrate_all_experts: bool,
):
return CalibrateQwen3VLMoeTextSparseMoeBlock(
original=original,
config=config,
calibrate_all_experts=calibrate_all_experts,
)
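The docstring above says the block "sends all tokens to all experts"; the following standalone sketch (illustrative, not the PR's exact forward) shows that routing pattern: every expert runs on the full token set so its weights receive calibration statistics, while only the top-k routed contributions are accumulated into the output.

import torch

def calibrate_all_experts_forward(hidden_states, gate, experts, top_k):
    # hidden_states: (batch, seq_len, hidden_dim); gate: Linear(hidden_dim, n_experts)
    batch, seq_len, dim = hidden_states.shape
    flat = hidden_states.reshape(-1, dim)
    router_logits = gate(flat)                                  # (tokens, n_experts)
    probs = torch.softmax(router_logits, dim=-1)
    topk_probs, topk_idx = probs.topk(top_k, dim=-1)

    out = torch.zeros_like(flat)
    for expert_idx, expert in enumerate(experts):
        expert_out = expert(flat)                               # run on ALL tokens (calibration)
        mask = topk_idx == expert_idx                           # (tokens, top_k)
        weight = (topk_probs * mask).sum(dim=-1, keepdim=True)  # 0 for tokens not routed here
        out = out + expert_out * weight                         # only routed tokens contribute
    return out.reshape(batch, seq_len, dim), router_logits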
25 changes: 12 additions & 13 deletions tests/llmcompressor/modeling/test_calib_qwen3_vl_moe.py
@@ -1,44 +1,43 @@
import torch
from transformers import Qwen3VLMoeConfig
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeTextSparseMoeBlock,
)

from llmcompressor.modeling.qwen3_vl_moe import LinearQwen3VLMoeTextSparseMoeBlock
from llmcompressor.modeling.qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock
from llmcompressor.utils.helpers import calibration_forward_context
from tests.testing_utils import requires_gpu


@requires_gpu
def test_calib_qwen3_vl_moe_module():
from transformers import Qwen3VLMoeTextConfig
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeTextSparseMoeBlock,
)

config = Qwen3VLMoeTextConfig()
config = Qwen3VLMoeConfig()
with torch.device("cuda"):
original = Qwen3VLMoeTextSparseMoeBlock(config).eval()
original = Qwen3VLMoeTextSparseMoeBlock(config.get_text_config()).eval()
# these are initialized as empty / all 0s which results in outputs
# from the experts being all 0
# update to use a small random value
original.experts.gate_up_proj.data.normal_(mean=0.0, std=0.02)
original.experts.down_proj.data.normal_(mean=0.0, std=0.02)

# Create dummy input tensor that simulates hidden_states
hidden_dim = config.hidden_size
hidden_dim = config.get_text_config().hidden_size
batch, seq_len = 4, 32
sample = torch.randn(batch, seq_len, hidden_dim, device="cuda")

with calibration_forward_context(original):
true_output = original(sample)

module = LinearQwen3VLMoeTextSparseMoeBlock(
config, original, calibrate_all_experts=True
module = CalibrateQwen3VLMoeTextSparseMoeBlock(
original, config, calibrate_all_experts=True
)
with calibration_forward_context(module):
output = module(sample)
assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10
assert torch.nn.functional.mse_loss(true_output[1], output[1]) < 1e-10

module = LinearQwen3VLMoeTextSparseMoeBlock(
config, original, calibrate_all_experts=False
module = CalibrateQwen3VLMoeTextSparseMoeBlock(
original, config, calibrate_all_experts=False
)
with calibration_forward_context(module):
output = module(sample)