8 changes: 8 additions & 0 deletions src/llmcompressor/modeling/__init__.py
@@ -9,5 +9,13 @@
needed for efficient compression.
"""

# trigger registration
from .deepseek_v3 import CalibrationDeepseekV3MoE # noqa: F401
from .llama4 import SequentialLlama4TextMoe # noqa: F401
from .qwen3_moe import CalibrationQwen3MoeSparseMoeBlock # noqa: F401
from .qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock # noqa: F401
from .qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock # noqa: F401
# TODO: add granite4

from .fuse import *
from .prepare import *
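
With this change, calibration modules self-register through the RegistryMixin decorator, so the imports above exist only for their side effect of populating the registry at import time. Wiring in a new model follows the same pattern; the sketch below is illustrative only: the module name "MyMoeBlock" and its class are hypothetical, while the constructor signature (original module, model config, calibrate_all_experts) mirrors the modules already in this PR.

import torch

from llmcompressor.modeling.moe_context import MoECalibrationModule


@MoECalibrationModule.register("MyMoeBlock")  # key is the original module's class name
class CalibrationMyMoeBlock(MoECalibrationModule):
    """Calibration version of a hypothetical MyMoeBlock."""

    def __init__(self, original: torch.nn.Module, config, calibrate_all_experts: bool = True):
        super().__init__()
        # a real implementation would unpack experts here; this sketch keeps the original block
        self.original = original
        self.calibrate_all_experts = calibrate_all_experts

    def forward(self, hidden_states: torch.Tensor):
        # a real implementation would route tokens through every expert during calibration
        return self.original(hidden_states)

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        # non-permanent modules hand the original block back when the context exits
        return original

The matching import line in this __init__.py (from .my_moe import CalibrationMyMoeBlock) is what actually triggers the registration.
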
7 changes: 2 additions & 5 deletions src/llmcompressor/modeling/deepseek_v3.py
@@ -4,13 +4,10 @@
DeepseekV3MoE as OriginalDeepseekV3MoE,
)

from llmcompressor.modeling.moe_context import (
MoECalibrationModule,
register_moe_calibration,
)
from llmcompressor.modeling.moe_context import MoECalibrationModule


@register_moe_calibration("DeepseekV3MoE")
@MoECalibrationModule.register("DeepseekV3MoE")
class CalibrationDeepseekV3MoE(MoECalibrationModule):
"""
Calibration version of DeepseekV3MoE that sends all tokens to all experts.
13 changes: 4 additions & 9 deletions src/llmcompressor/modeling/llama4.py
@@ -11,14 +11,11 @@
Llama4TextMoe,
)

from llmcompressor.modeling.moe_context import (
MoECalibrationModule,
register_moe_calibration,
)
from llmcompressor.modeling.moe_context import MoECalibrationModule
from llmcompressor.utils.dev import skip_weights_initialize


@register_moe_calibration("Llama4TextMoe")
@MoECalibrationModule.register("Llama4TextMoe")
class SequentialLlama4TextMoe(MoECalibrationModule):
"""
Calibration version of Llama4TextMoe that unpacks experts for sequential processing.
@@ -38,10 +35,8 @@ def __init__(
calibrate_all_experts: bool = True,
):
super().__init__()
# Extract text config from multimodal config if needed
text_config = (
config.get_text_config() if hasattr(config, "get_text_config") else config
)
# Extract text config from multimodal config
text_config: Llama4TextConfig = config.get_text_config()
self.top_k = text_config.num_experts_per_tok
self.hidden_dim = text_config.hidden_size
self.num_experts = text_config.num_local_experts
46 changes: 11 additions & 35 deletions src/llmcompressor/modeling/moe_context.py
@@ -8,28 +8,25 @@

Key components:
- MoECalibrationModule: Abstract base class for calibration modules
- MOE_CALIBRATION_MODULES: Registry mapping module class names to calibration classes
- moe_calibration_context: Context manager that applies calibration to a model
"""

import contextlib
from abc import ABC
from typing import Dict, Type

import torch
from compressed_tensors.registry import RegistryMixin, standardize_lookup_name
from loguru import logger
from tqdm import tqdm
from transformers import PreTrainedModel

__all__ = [
"MoECalibrationModule",
"MOE_CALIBRATION_MODULES",
"register_moe_calibration",
"moe_calibration_context",
]


class MoECalibrationModule(ABC, torch.nn.Module):
class MoECalibrationModule(ABC, torch.nn.Module, RegistryMixin):
"""
Abstract base class for MoE calibration modules.

@@ -62,32 +59,6 @@ def restore(self, original: torch.nn.Module) -> torch.nn.Module:
)


# Registry: module class name -> calibration module class
MOE_CALIBRATION_MODULES: Dict[str, Type[MoECalibrationModule]] = {}


def register_moe_calibration(module_class_name: str):
"""
Decorator to register a MoE calibration module.

Usage:
@register_moe_calibration("DeepseekV3MoE")
class CalibrationDeepseekV3MoE(MoECalibrationModule):
...

Args:
module_class_name: The class name of the original module to replace
"""

def decorator(cls: Type[MoECalibrationModule]) -> Type[MoECalibrationModule]:
if not issubclass(cls, MoECalibrationModule):
raise TypeError(f"{cls.__name__} must inherit from MoECalibrationModule")
MOE_CALIBRATION_MODULES[module_class_name] = cls
return cls

return decorator


@contextlib.contextmanager
def moe_calibration_context(
model: PreTrainedModel,
@@ -115,14 +86,15 @@ def moe_calibration_context(
model(**batch)
# Model is now restored (unless permanent)
"""

replaced = {}

# Step 1: Collect all MoE modules that need replacement
logger.info("Entering MoE calibration context")
logger.debug("Entering MoE calibration context")
modules_to_replace = []
for name, module in model.named_modules():
class_name = module.__class__.__name__
if class_name in MOE_CALIBRATION_MODULES:
if _is_registered(class_name, MoECalibrationModule):
modules_to_replace.append((name, module, class_name))

# Step 2: Replace modules with progress bar
@@ -131,8 +103,8 @@ def moe_calibration_context(
for name, module, class_name in tqdm(
modules_to_replace, desc="Replacing MoE modules for calibration"
):
calibration_cls = MOE_CALIBRATION_MODULES[class_name]
replacement = calibration_cls(
replacement = MoECalibrationModule.load_from_registry(
class_name,
module,
model.config,
calibrate_all_experts=calibrate_all_experts,
@@ -165,3 +137,7 @@
if not replacement.is_permanent:
restored = replacement.restore(original)
model.set_submodule(name, restored)


def _is_registered(name: str, subclass: type[RegistryMixin]) -> bool:
return standardize_lookup_name(name) in subclass.registered_names()
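
Putting the pieces together: entering the context swaps every module whose class name is registered, runs the calibration forward passes, and restores anything that is not permanent on exit. A minimal usage sketch, assuming only the API shown in this diff; the checkpoint name and calibration text are placeholders.

import torch
from compressed_tensors.registry import standardize_lookup_name
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modeling.moe_context import MoECalibrationModule, moe_calibration_context

model_id = "Qwen/Qwen3-30B-A3B"  # example MoE checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# importing llmcompressor.modeling.moe_context also runs llmcompressor/modeling/__init__.py,
# so the stock MoE blocks are already registered; lookups are standardized, as in _is_registered
assert standardize_lookup_name("Qwen3MoeSparseMoeBlock") in MoECalibrationModule.registered_names()

batch = tokenizer("a short calibration sample", return_tensors="pt")
with moe_calibration_context(model):
    with torch.no_grad():
        model(**batch)
# non-permanent replacements have been restored to their original modules at this point
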
31 changes: 5 additions & 26 deletions src/llmcompressor/modeling/prepare.py
@@ -10,33 +10,12 @@
from compressed_tensors.utils import deprecated, replace_module
from transformers import PreTrainedModel

# Import MoE calibration modules to trigger registration
from llmcompressor.modeling.deepseek_v3 import ( # noqa: F401
CalibrationDeepseekV3MoE,
)
from llmcompressor.modeling.deepseek_v3 import (
replace as replace_deepseekv3,
)
from llmcompressor.modeling.llama4 import ( # noqa: F401
SequentialLlama4TextMoe,
)
from llmcompressor.modeling.llama4 import (
replace as replace_llama4,
)
from llmcompressor.modeling.moe_context import ( # noqa: F401
moe_calibration_context,
)
from llmcompressor.modeling.qwen3_moe import ( # noqa: F401
CalibrationQwen3MoeSparseMoeBlock,
)
from llmcompressor.modeling.qwen3_next_moe import ( # noqa: F401
CalibrationQwen3NextSparseMoeBlock,
)
from llmcompressor.modeling.qwen3_vl_moe import (
replace as replace_Qwen3VLMoE,
)
# deprecated replacement functions
from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
from llmcompressor.modeling.llama4 import replace as replace_llama4
from llmcompressor.modeling.qwen3_vl_moe import replace as replace_Qwen3VLMoE

__all__ = ["moe_calibration_context", "replace_modules_for_calibration"]
__all__ = ["replace_modules_for_calibration"]

# ---------------------- module replacements; permanent -------------------------
replacements = {
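
For comparison with the context manager, the permanent path kept in this file swaps modules once and never restores them: replace_modules_for_calibration walks the model and applies the replacements mapping directly. A rough usage sketch, assuming its signature (model, calibrate_all_experts=True) is unchanged by this PR; the checkpoint is only an example.

from transformers import AutoModelForCausalLM

from llmcompressor.modeling.prepare import replace_modules_for_calibration

# example MoE checkpoint; any architecture listed in the replacements mapping works the same way
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-V3")

# the swap is permanent: the returned model keeps the calibration-friendly modules
model = replace_modules_for_calibration(model, calibrate_all_experts=True)
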
7 changes: 2 additions & 5 deletions src/llmcompressor/modeling/qwen3_moe.py
@@ -20,13 +20,10 @@
Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock,
)

from llmcompressor.modeling.moe_context import (
MoECalibrationModule,
register_moe_calibration,
)
from llmcompressor.modeling.moe_context import MoECalibrationModule


@register_moe_calibration("Qwen3MoeSparseMoeBlock")
@MoECalibrationModule.register("Qwen3MoeSparseMoeBlock")
class CalibrationQwen3MoeSparseMoeBlock(MoECalibrationModule):
"""
Calibration version of Qwen3MoeSparseMoeBlock that sends all tokens to all experts.
45 changes: 36 additions & 9 deletions src/llmcompressor/modeling/qwen3_vl_moe.py
@@ -1,19 +1,39 @@
import torch
from transformers import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeTextSparseMoeBlock as OriginalQwen3VLMoeTextSparseMoeBlock,
)

from llmcompressor.modeling.moe_context import MoECalibrationModule
from llmcompressor.utils.dev import skip_weights_initialize


class LinearQwen3VLMoeTextSparseMoeBlock(torch.nn.Module):
def __init__(self, config, original, calibrate_all_experts):
@MoECalibrationModule.register("Qwen3VLMoeTextSparseMoeBlock")
class CalibrateQwen3VLMoeTextSparseMoeBlock(MoECalibrationModule):
"""
Calibration version of Qwen3VLMoeTextSparseMoeBlock that sends all tokens to all
experts.
"""

is_permanent = True

def __init__(
self,
original: OriginalQwen3VLMoeTextSparseMoeBlock,
config: Qwen3VLMoeConfig,
calibrate_all_experts: bool,
):
super().__init__()
self.hidden_size = config.hidden_size
self.num_experts = config.num_experts
text_config: Qwen3VLMoeTextConfig = config.get_text_config()

self.hidden_size = text_config.hidden_size
self.num_experts = text_config.num_experts
self.top_k = original.top_k
# Note: gate was changed to be a Linear layer in transformers==4.57.0
# https://github.com/JJJYmmm/transformers/commit/f5dea1c694af8c994c769170813a8702332119ee
self.gate = original.gate
self.calibrate_all_experts = calibrate_all_experts
self.experts = SequentialQwen3VLMoeTextExperts(config, original.experts)
self.experts = SequentialQwen3VLMoeTextExperts(text_config, original.experts)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -64,6 +84,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
next_states = next_states.reshape(batch_size, sequence_length, hidden_dim)
return next_states, router_logits

def restore(self, original: torch.nn.Module) -> torch.nn.Module:
return original


class SequentialQwen3VLMoeTextExperts(torch.nn.ModuleList):
def __init__(self, config, original):
@@ -91,9 +114,13 @@ def __init__(self, config, original):
self[i].down_proj.weight.data = down.t().clone().contiguous()


def replace(config, module, calibrate_all_experts):
return LinearQwen3VLMoeTextSparseMoeBlock(
config=config.get_text_config(),
original=module,
def replace(
config: Qwen3VLMoeConfig,
original: OriginalQwen3VLMoeTextSparseMoeBlock,
calibrate_all_experts: bool,
):
return CalibrateQwen3VLMoeTextSparseMoeBlock(
original=original,
config=config,
calibrate_all_experts=calibrate_all_experts,
)
25 changes: 12 additions & 13 deletions tests/llmcompressor/modeling/test_calib_qwen3_vl_moe.py
@@ -1,44 +1,43 @@
import torch
from transformers import Qwen3VLMoeConfig
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeTextSparseMoeBlock,
)

from llmcompressor.modeling.qwen3_vl_moe import LinearQwen3VLMoeTextSparseMoeBlock
from llmcompressor.modeling.qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock
from llmcompressor.utils.helpers import calibration_forward_context
from tests.testing_utils import requires_gpu


@requires_gpu
def test_calib_qwen3_vl_moe_module():
from transformers import Qwen3VLMoeTextConfig
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeTextSparseMoeBlock,
)

config = Qwen3VLMoeTextConfig()
config = Qwen3VLMoeConfig()
with torch.device("cuda"):
original = Qwen3VLMoeTextSparseMoeBlock(config).eval()
original = Qwen3VLMoeTextSparseMoeBlock(config.get_text_config()).eval()
# these are initialized as empty / all 0s which results in outputs
# from the experts being all 0
# update to use a small random value
original.experts.gate_up_proj.data.normal_(mean=0.0, std=0.02)
original.experts.down_proj.data.normal_(mean=0.0, std=0.02)

# Create dummy input tensor that simulates hidden_states
hidden_dim = config.hidden_size
hidden_dim = config.get_text_config().hidden_size
batch, seq_len = 4, 32
sample = torch.randn(batch, seq_len, hidden_dim, device="cuda")

with calibration_forward_context(original):
true_output = original(sample)

module = LinearQwen3VLMoeTextSparseMoeBlock(
config, original, calibrate_all_experts=True
module = CalibrateQwen3VLMoeTextSparseMoeBlock(
original, config, calibrate_all_experts=True
)
with calibration_forward_context(module):
output = module(sample)
assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10
assert torch.nn.functional.mse_loss(true_output[1], output[1]) < 1e-10

module = LinearQwen3VLMoeTextSparseMoeBlock(
config, original, calibrate_all_experts=False
module = CalibrateQwen3VLMoeTextSparseMoeBlock(
original, config, calibrate_all_experts=False
)
with calibration_forward_context(module):
output = module(sample)