
Commit 1aa196f

[MoE Calibration] Simplify MoE calibration interface (#1851)
## Introduce standardized MoE calibration interface and deprecate legacy replace_modules_for_calibration

### Summary
Implements a simplified, decorator-based registration system for MoE model calibration built around a single `MoECalibrationModule` base class, making MoE model integration easier, and deprecates the legacy `replace_modules_for_calibration` function.

### Problem
MoE model calibration currently requires module replacement logic scattered across `replace_modules_for_calibration` and manual context management. This makes contributing new MoE model support difficult and error-prone. Additionally, each model required a custom replacement function with duplicated boilerplate code.

### Relevant Issues
Fixes #1829

### Solution
**`MoECalibrationModule`** abstract base class implementation
- Only two required methods: the `from_original()` classmethod and an optional `restore()`
- `is_permanent` flag to specify whether the module replacement should be reverted via `restore()`
- Clear contract: permanent modules stay in calibration form, non-permanent modules are restored after context exit

**Decorator-Based Registration**: `@register_moe_calibration("ModuleName")` decorator
- Automatic registration in the `MOE_CALIBRATION_MODULES` registry
- Models self-register when their module is imported

**New Model Integration**: Adding MoE support requires only:
```python
@register_moe_calibration("YourMoEModule")
class CalibrationYourMoE(MoECalibrationModule):
    is_permanent = True  # or False

    @classmethod
    def from_original(cls, original, config, calibrate_all_experts=True):
        return cls(config, original, calibrate_all_experts)
```

**Dataset Arguments**: New: `moe_calibrate_all_experts: bool = True`
- Controls whether all experts see all tokens during calibration
- `True` (default): all experts receive all tokens for proper quantization statistics
- `False`: normal routing behavior (only routed experts are used)
- Used by both `oneshot()` and `DatasetArguments`
- Automatically passed to `moe_calibration_context` by pipelines

**Automatic Context Management**: `moe_calibration_context` integrated into pipelines
- Wraps calibration automatically in `oneshot.py`
- Handles module replacement and restoration transparently
- No manual context management required by users

**Backward Compatibility**: Deprecation of `replace_modules_for_calibration` with warnings
- Legacy function preserved for compatibility
- Clear migration path documented in the deprecation message

### Test Plan
- ✅ Unit tests for contextual MoE calibration with automatic module restoration
- ✅ Unit tests for permanent MoE calibration persistence
- ✅ Integration tests with Qwen3, Llama4, and DeepSeek V3 models
- ✅ Verification that all experts receive data during calibration
- ✅ Deprecation warning verification for legacy functions

### Testing
- ✅ All unit tests pass
- ✅ Calibration types working correctly
- ✅ Model structure correctly modified and restored inside/outside contexts
- ✅ Linting and type checking pass
- ✅ Backward compatibility verified with deprecation warnings

### Migration Guide
**Before**:
```python
# Required defining MoEModelConfig entries, handling context manually
from llmcompressor.modeling.prepare import replace_modules_for_calibration

model = replace_modules_for_calibration(model, calibrate_all_experts=True)
```

**After**:
```python
# Automatic - just use moe_calibration_context
from llmcompressor.modeling import moe_calibration_context

with moe_calibration_context(model, calibrate_all_experts=True):
    # Run calibration - modules replaced automatically
    for batch in dataloader:
        model(**batch)
# Modules restored automatically (if not permanent)
```

---------

Signed-off-by: Sairam Pillai <sairam.pillai61@gmail.com>
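The `moe_context` module that defines this interface is not included in the hunks shown below. As orientation only, a minimal sketch of the contract described above might look like the following; the registry handling, the `_set_submodule` helper, and the exact signatures are assumptions for illustration, not the actual implementation:

```python
from abc import ABC, abstractmethod
from contextlib import contextmanager

import torch

# Registry mapping the original module's class name to its calibration class.
MOE_CALIBRATION_MODULES = {}


def register_moe_calibration(original_class_name: str):
    """Decorator: register a calibration module under the original class name."""

    def decorator(cls):
        MOE_CALIBRATION_MODULES[original_class_name] = cls
        return cls

    return decorator


class MoECalibrationModule(torch.nn.Module, ABC):
    """Base class for MoE calibration replacement modules."""

    # Permanent modules stay in calibration form after the context exits.
    is_permanent: bool = False

    @classmethod
    @abstractmethod
    def from_original(cls, original, config, calibrate_all_experts: bool = True):
        """Build the calibration module from the original MoE module."""

    def restore(self, original):
        """Module to put back after calibration (default: the original)."""
        return original


def _set_submodule(model: torch.nn.Module, name: str, module: torch.nn.Module):
    parent_name, _, child = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child, module)


@contextmanager
def moe_calibration_context(model: torch.nn.Module, calibrate_all_experts: bool = True):
    """Swap registered MoE modules for calibration; restore non-permanent ones on exit."""
    # Assumes a Hugging Face model exposing .config and .named_modules().
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if type(module).__name__ in MOE_CALIBRATION_MODULES
    ]
    replaced = {}
    for name, original in targets:
        calib_cls = MOE_CALIBRATION_MODULES[type(original).__name__]
        calib = calib_cls.from_original(original, model.config, calibrate_all_experts)
        replaced[name] = (original, calib)
        _set_submodule(model, name, calib)
    try:
        yield model
    finally:
        for name, (original, calib) in replaced.items():
            if not calib.is_permanent:
                _set_submodule(model, name, calib.restore(original))
```

Keying the registry on the original module's class name is what lets `moe_calibration_context` discover replaceable submodules by type alone, which is why adding support for a new MoE model reduces to the `@register_moe_calibration` example above.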
1 parent 0f346cf commit 1aa196f

File tree: 19 files changed, +460 / -225 lines


examples/multimodal_vision/llama4_example.py
Lines changed: 12 additions & 6 deletions

@@ -3,18 +3,24 @@
 from transformers import Llama4ForConditionalGeneration, Llama4Processor
 
 from llmcompressor import oneshot
-from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import GPTQModifier
 
 # Select model and load it.
 model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
 model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
 processor = Llama4Processor.from_pretrained(model_id)
-# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`.
-# This change allows compatibility with vllm.
-# To apply your own custom module for experimentation, consider updating
-# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
-model = replace_modules_for_calibration(model)
+# MoE calibration is now handled automatically by the pipeline.
+# The `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`)
+# will be applied during calibration to enable proper
+# expert calibration and vLLM compatibility.
+# These replace the original `Llama4TextMoe` class from
+# `transformers.models.llama4.modeling_llama4`.
+#
+# NOTE: This restructuring is specifically required for vLLM compatibility.
+# To define custom calibration logic, create a new calibration module in
+# modeling/llama4.py that inherits from `MoECalibrationModule`, and register
+# it using the `@register_moe_calibration` decorator with the appropriate
+# module class name (e.g., "Llama4TextMoe").
 
 DATASET_ID = "neuralmagic/calibration"
 NUM_CALIBRATION_SAMPLES = 512
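The NOTE in the updated comment block is the intended extension point. A hedged sketch of what such a custom registration might look like (the class body here is illustrative and not part of this commit; only `MoECalibrationModule`, `register_moe_calibration`, and the `from_original`/`is_permanent` contract come from the PR):

```python
import torch

from llmcompressor.modeling.moe_context import (
    MoECalibrationModule,
    register_moe_calibration,
)


@register_moe_calibration("Llama4TextMoe")
class MyExperimentalLlama4MoE(MoECalibrationModule):
    """Illustrative custom calibration module for Llama4TextMoe (not part of this PR)."""

    is_permanent = False  # restore the original module when calibration ends

    def __init__(self, original, config, calibrate_all_experts=True):
        super().__init__()
        self.original = original
        self.calibrate_all_experts = calibrate_all_experts

    @classmethod
    def from_original(cls, original, config, calibrate_all_experts=True):
        return cls(original, config, calibrate_all_experts)

    def forward(self, hidden_states: torch.Tensor):
        # Custom calibration-time behavior would go here; this sketch simply
        # defers to the original module's routing.
        return self.original(hidden_states)
```

Because registration happens at import time, such a module only needs to be imported before calibration starts for the context to pick it up.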

examples/quantization_w4a4_fp4/README.md
Lines changed: 4 additions & 4 deletions

@@ -84,11 +84,11 @@ We have successfully created an `nvfp4` model!
 
 # Quantizing MoEs
 
-To quantize MoEs, a few additional steps are required. An example quantizing Llama4 can be found under `llama4_example.py`. Here, we replace all `Llama4TextMoe` modules by calling `replace_modules_for_calibration`. This replacement allows us to:
+To quantize MoEs, MoE calibration is now handled automatically by the pipeline. An example quantizing Llama4 can be found under `llama4_example.py`. The pipeline automatically applies the appropriate MoE calibration context which:
 
-1. Linearize the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization.
-2. Ensure experts are quantized correctly as not all experts are activated during calibration
+1. Linearizes the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization.
+2. Ensures experts are quantized correctly as not all experts are activated during calibration
 
-Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model does not require additional linearization as required by the Llama4 model. However, similar to Llama4, in order to ensure the experts are quantized correctly, we can pass in `calibrate_moe_context` which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance.
+Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model uses contextual MoE calibration which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance.
 
 
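For reference, the simplified flow described above reduces to a single `oneshot` call; a condensed sketch along the lines of `qwen_30b_a3b.py` (the dataset name, recipe settings, and ignore patterns here are illustrative, not prescriptive):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "Qwen/Qwen3-30B-A3B"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# NVFP4 quantization of Linear layers, leaving the router gates and lm_head alone.
recipe = QuantizationModifier(
    targets="Linear", scheme="NVFP4", ignore=["lm_head", "re:.*mlp.gate$"]
)

# No manual module replacement: oneshot enters moe_calibration_context itself.
oneshot(
    model=model,
    dataset="ultrachat_200k",  # illustrative; any calibration dataset works
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
    moe_calibrate_all_experts=True,  # default: every expert sees every token
)
```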

examples/quantization_w4a4_fp4/llama4_example.py
Lines changed: 6 additions & 6 deletions

@@ -3,18 +3,18 @@
 from transformers import Llama4ForConditionalGeneration, Llama4Processor
 
 from llmcompressor import oneshot
-from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import QuantizationModifier
 
 # Select model and load it.
 model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
 model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
 processor = Llama4Processor.from_pretrained(model_id)
-# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`.
-# This change allows compatibility with vllm.
-# To apply your own custom module for experimentation, consider updating
-# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
-model = replace_modules_for_calibration(model)
+# MoE calibration is now handled automatically by the pipeline.
+# The `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`)
+# will be applied during calibration to enable
+# proper expert calibration and vLLM compatibility.
+# These replace the original `Llama4TextMoe` class from
+# `transformers.models.llama4.modeling_llama4`.
 
 DATASET_ID = "neuralmagic/calibration"
 NUM_CALIBRATION_SAMPLES = 20

examples/quantization_w4a4_fp4/qwen_30b_a3b.py
Lines changed: 10 additions & 5 deletions

@@ -59,18 +59,23 @@ def tokenize(sample):
 )
 
 # Apply quantization.
-# We see `calibrate_moe_context` to True to update all `Qwen3MoeSparseMoeBlock`
-# during calibration.
+# MoE calibration is now handled automatically by the pipeline.
+# We set `moe_calibrate_all_experts` to True to ensure all experts receive
+# calibration data. This temporarily updates the model definition to use
+# `CalibrationQwen3MoeSparseMoeBlock` (from `llmcompressor.modeling.qwen3_moe`)
+# which replaces the original `Qwen3MoeSparseMoeBlock` class from
+# `transformers.models.qwen3_moe.modeling_qwen3_moe`. This updates how the
+# forward pass is handled in the MoE block during calibration.
 # Feel free to update the definition under
-# llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with
-# this behaviour and evaluate its impact on quantization performance
+# llm-compressor/src/llmcompressor/modeling/qwen3_moe.py to play around with
+# this behavior and evaluate its impact on quantization performance.
 oneshot(
     model=model,
     dataset=ds,
     recipe=recipe,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
-    calibrate_moe_context=True,
+    moe_calibrate_all_experts=True,
 )
 
 
examples/quantization_w8a8_fp8/llama4_fp8_block_example.py
Lines changed: 6 additions & 2 deletions

@@ -1,7 +1,6 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from llmcompressor import oneshot
-from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import QuantizationModifier
 from llmcompressor.utils import dispatch_for_generation
 
@@ -10,7 +9,12 @@
 # Load model.
 model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-model = replace_modules_for_calibration(model)
+# MoE calibration is now handled automatically by the pipeline.
+# The `SequentialLlama4TextMoe` modules (from `llmcompressor.modeling.llama4`)
+# will be applied during calibration to enable
+# proper expert calibration and vLLM compatibility.
+# These replace the original `Llama4TextMoe` class from
+# `transformers.models.llama4.modeling_llama4`.
 # Configure the quantization algorithm and scheme.
 # In this case, we:
 # * quantize the weights to fp8 with block size 128 via ptq

examples/quantizing_moe/deepseek_r1_example.py
Lines changed: 5 additions & 2 deletions

@@ -2,7 +2,6 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 from llmcompressor import oneshot
-from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import GPTQModifier
 
 # Select model and load it.
@@ -20,7 +19,11 @@
     model_id, torch_dtype="auto", config=config
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = replace_modules_for_calibration(model)
+# MoE calibration is now handled automatically by the pipeline.
+# The `CalibrationDeepseekV3MoE` modules (from `llmcompressor.modeling.deepseek_v3`)
+# will be applied during calibration to enable proper expert calibration.
+# These replace the original `DeepseekV3MoE` class from
+# `transformers.models.deepseek_v3.modeling_deepseek_v3`.
 
 # Select calibration dataset.
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"

src/llmcompressor/args/dataset_arguments.py
Lines changed: 12 additions & 10 deletions

@@ -126,16 +126,6 @@ class DatasetArguments(CustomDatasetArguments):
         default=512,
         metadata={"help": "Number of samples to use for one-shot calibration"},
     )
-    calibrate_moe_context: bool = field(
-        default=False,
-        metadata={
-            "help": "If during calibration, the MoE context should be enabled "
-            "for the given model. This usually involves updating all MoE modules "
-            "in the model for the duration of calibration. See moe_context under "
-            "modeling/prepare.py for a list of supported MoEs and their updated "
-            "module definitions"
-        },
-    )
     shuffle_calibration_samples: bool | None = field(
         default=True,
         metadata={
@@ -181,6 +171,18 @@ class DatasetArguments(CustomDatasetArguments):
             ),
         },
     )
+    moe_calibrate_all_experts: bool = field(
+        default=True,
+        metadata={
+            "help": (
+                "Whether to calibrate all experts during MoE model calibration. "
+                "When True, all experts will see all tokens during calibration, "
+                "ensuring proper quantization statistics for all experts. "
+                "When False, only routed experts will be used. "
+                "Only relevant for MoE models. Default is True."
+            ),
+        },
+    )
     # --- pipeline arguments --- #
     pipeline: str | None = field(
         default="independent",
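When arguments are built programmatically rather than passed as `oneshot(...)` keywords, the flag can be set on `DatasetArguments` directly. A brief hedged sketch (assuming `DatasetArguments` is importable from `llmcompressor.args`; the dataset name is illustrative):

```python
from llmcompressor.args import DatasetArguments

# moe_calibrate_all_experts=False keeps normal routing, so only the experts the
# router actually selects receive calibration data.
dataset_args = DatasetArguments(
    dataset="HuggingFaceH4/ultrachat_200k",
    num_calibration_samples=512,
    moe_calibrate_all_experts=False,
)
```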

src/llmcompressor/entrypoints/oneshot.py
Lines changed: 15 additions & 8 deletions

@@ -20,6 +20,7 @@
 from llmcompressor.core.session_functions import active_session
 from llmcompressor.datasets import get_calibration_dataloader
 from llmcompressor.entrypoints.utils import post_process, pre_process
+from llmcompressor.modeling.moe_context import moe_calibration_context
 from llmcompressor.pipelines import CalibrationPipeline
 
 __all__ = ["Oneshot", "oneshot"]
@@ -209,11 +210,16 @@ def apply_recipe_modifiers(
         user_pipeline = self.dataset_args.pipeline
         modifiers = session.lifecycle.recipe.modifiers
         pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline)
-        pipeline(
+        # Apply MoE calibration context for the entire calibration process
+        with moe_calibration_context(
             self.model,
-            calibration_dataloader,
-            self.dataset_args,
-        )
+            calibrate_all_experts=self.dataset_args.moe_calibrate_all_experts,
+        ):
+            pipeline(
+                self.model,
+                calibration_dataloader,
+                self.dataset_args,
+            )
 
         session.finalize()
 
@@ -252,7 +258,7 @@ def oneshot(
     overwrite_cache: bool = False,
    preprocessing_num_workers: Optional[int] = None,
     min_tokens_per_module: Optional[float] = None,
-    calibrate_moe_context: bool = False,
+    moe_calibrate_all_experts: bool = True,
     quantization_aware_calibration: bool = True,
     # Miscellaneous arguments
     output_dir: Optional[str] = None,
@@ -316,9 +322,10 @@
         preprocessing.
     :param min_tokens_per_module: Minimum percentage of tokens per
         module, relevant for MoE models.
-    :param calibrate_moe_context: If during calibration, the MoE context should be
-        enabled for the given model. This usually involves updating all MoE modules
-        in the model for the duration of calibration.
+    :param moe_calibrate_all_experts: Whether to calibrate all experts during MoE
+        model calibration. When True, all experts will see all tokens during
+        calibration, ensuring proper quantization statistics. When False, only
+        routed experts will be used. Only relevant for MoE models. Default is True.
     :param quantization_aware_calibration: Whether to enable quantization-aware
        calibration in the sequential pipeline. When True, quantization is applied
        during forward pass in calibration. When False, quantization is disabled

src/llmcompressor/modeling/deepseek_v3.py
Lines changed: 21 additions & 6 deletions

@@ -4,17 +4,25 @@
     DeepseekV3MoE as OriginalDeepseekV3MoE,
 )
 
+from llmcompressor.modeling.moe_context import (
+    MoECalibrationModule,
+    register_moe_calibration,
+)
+
 
-class DeepseekV3MoECalibrate(torch.nn.Module):
+@register_moe_calibration("DeepseekV3MoE")
+class CalibrationDeepseekV3MoE(MoECalibrationModule):
     """
-    Patched DeepseekV3MoE which sends all tokens to all experts for calibration
+    Calibration version of DeepseekV3MoE that sends all tokens to all experts.
     """
 
+    is_permanent = True
+
     def __init__(
         self,
-        config: DeepseekV3Config,
         original: OriginalDeepseekV3MoE,
-        calibrate_all_experts: bool,
+        config: DeepseekV3Config,
+        calibrate_all_experts: bool = True,
     ):
         super().__init__()
         self.config = config
@@ -65,11 +73,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
+# Legacy function for backward compatibility
 def replace(
     config: DeepseekV3Config,
     module: OriginalDeepseekV3MoE,
     calibrate_all_experts: bool,
 ):
-    return DeepseekV3MoECalibrate(
-        config=config, original=module, calibrate_all_experts=calibrate_all_experts
+    """
+    Legacy replacement function.
+    Use CalibrationDeepseekV3MoE instead.
+    """
+    return CalibrationDeepseekV3MoE(
+        module,
+        config,
+        calibrate_all_experts=calibrate_all_experts,
     )

src/llmcompressor/modeling/llama4.py
Lines changed: 36 additions & 11 deletions

@@ -11,27 +11,47 @@
     Llama4TextMoe,
 )
 
+from llmcompressor.modeling.moe_context import (
+    MoECalibrationModule,
+    register_moe_calibration,
+)
 from llmcompressor.utils.dev import skip_weights_initialize
 
 
-class SequentialLlama4TextMoe(torch.nn.Module):
+@register_moe_calibration("Llama4TextMoe")
+class SequentialLlama4TextMoe(MoECalibrationModule):
+    """
+    Calibration version of Llama4TextMoe that unpacks experts for sequential processing.
+
+    This module:
+    1. Unpacks the packed expert weights (3D -> 2D) for calibration
+    2. Optionally sends all tokens to all experts during calibration
+    3. Stays in unpacked form (permanent) for vLLM compatibility
+    """
+
+    is_permanent = True
+
     def __init__(
         self,
-        config: Llama4TextConfig,
         original: Llama4TextMoe,
-        calibrate_all_experts: bool,
+        config: Llama4Config,
+        calibrate_all_experts: bool = True,
     ):
         super().__init__()
-        self.top_k = config.num_experts_per_tok
-        self.hidden_dim = config.hidden_size
-        self.num_experts = config.num_local_experts
-
-        self.experts = SequentialLlama4TextExperts(config, original.experts)
+        # Extract text config from multimodal config if needed
+        text_config = (
+            config.get_text_config() if hasattr(config, "get_text_config") else config
+        )
+        self.top_k = text_config.num_experts_per_tok
+        self.hidden_dim = text_config.hidden_size
+        self.num_experts = text_config.num_local_experts
+
+        self.experts = SequentialLlama4TextExperts(text_config, original.experts)
         self.router = original.router
         self.shared_expert = original.shared_expert
         self.calibrate_all_experts = calibrate_all_experts
 
-    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tensor]:
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
         router_scores, router_logits = self.router(hidden_states)  # transformers>=4.54
 
@@ -74,9 +94,14 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts):
             self[i].down_proj.weight.data = down.t().contiguous()
 
 
+# Legacy function for backward compatibility
 def replace(config: Llama4Config, module: Llama4TextMoe, calibrate_all_experts: bool):
+    """
+    Legacy replacement function.
+    Use SequentialLlama4TextMoe instead.
+    """
     return SequentialLlama4TextMoe(
-        config=config.get_text_config(),
-        original=module,
+        module,
+        config,
         calibrate_all_experts=calibrate_all_experts,
     )
