File tree Expand file tree Collapse file tree 3 files changed +28
-9
lines changed
examples/quantization_w4a4_fp4
src/llmcompressor/modeling Expand file tree Collapse file tree 3 files changed +28
-9
lines changed Original file line number Diff line number Diff line change @@ -68,18 +68,22 @@ def tokenize(sample):
6868)
6969
7070# Apply quantization.
71- # We see `calibrate_moe_context` to True to update all `Qwen3MoeSparseMoeBlock`
72- # during calibration.
71+ # MoE calibration is now handled automatically by the pipeline.
72+ # We set `moe_calibrate_all_experts` to True to ensure all experts receive
73+ # calibration data. This temporarily updates the model definition to use
74+ # `CalibrationQwen3NextSparseMoeBlock` (from `llmcompressor.modeling.qwen3_next_moe`)
75+ # which replaces the original `Qwen3NextSparseMoeBlock` class.
76+ # This updates how the forward pass is handled in the MoE block during calibration.
7377# Feel free to update the definition under
74- # llm-compressor/src/llmcompressor/modeling/qwen3_moe .py` to play around with
75- # this behaviour and evaluate its impact on quantization performance
78+ # llm-compressor/src/llmcompressor/modeling/qwen3_next_moe.py to play around with
79+ # this behavior and evaluate its impact on quantization performance.
7680oneshot (
7781 model = model ,
7882 dataset = ds ,
7983 recipe = recipe ,
8084 max_seq_length = MAX_SEQUENCE_LENGTH ,
8185 num_calibration_samples = NUM_CALIBRATION_SAMPLES ,
82- calibrate_moe_context = True ,
86+ moe_calibrate_all_experts = True ,
8387)
8488
8589
Original file line number Diff line number Diff line change 2929from llmcompressor .modeling .qwen3_moe import ( # noqa: F401
3030 CalibrationQwen3MoeSparseMoeBlock ,
3131)
32+ from llmcompressor .modeling .qwen3_next_moe import ( # noqa: F401
33+ CalibrationQwen3NextSparseMoeBlock ,
34+ )
3235from llmcompressor .modeling .qwen3_vl_moe import (
3336 replace as replace_Qwen3VLMoE ,
3437)
Original file line number Diff line number Diff line change 1616
1717import torch
1818
19+ from llmcompressor .modeling .moe_context import (
20+ MoECalibrationModule ,
21+ register_moe_calibration ,
22+ )
23+
24+
25+ @register_moe_calibration ("Qwen3NextSparseMoeBlock" )
26+ class CalibrationQwen3NextSparseMoeBlock (MoECalibrationModule ):
27+ """
28+ Calibration version of Qwen3NextSparseMoeBlock that sends all tokens to all experts.
29+ """
30+
31+ is_permanent = False
1932
20- class Qwen3NextSparseMoeBlock (torch .nn .Module ):
2133 def __init__ (
2234 self ,
23- config ,
2435 original ,
25- calibrate_all_experts : bool ,
36+ config ,
37+ calibrate_all_experts : bool = True ,
2638 ):
2739 super ().__init__ ()
2840 self .num_experts = config .num_experts
@@ -109,6 +121,6 @@ def replace(
109121 module ,
110122 calibrate_all_experts ,
111123):
112- return Qwen3NextSparseMoeBlock (
124+ return CalibrationQwen3NextSparseMoeBlock (
113125 config = config , original = module , calibrate_all_experts = calibrate_all_experts
114126 )
You can’t perform that action at this time.
0 commit comments