
Commit e6fdfed

dsikka authored and HDCharles committed
add additional test; add hints
1 parent 9bce0d8 commit e6fdfed

File tree

2 files changed: +63 -4 lines changed


src/llmcompressor/modeling/qwen3_next_moe.py

Lines changed: 7 additions & 2 deletions
@@ -24,6 +24,11 @@
 
 @register_moe_calibration("Qwen3NextSparseMoeBlock")
 class CalibrationQwen3NextSparseMoeBlock(MoECalibrationModule):
+    from transformers import Qwen3NextConfig
+    from transformers.models.qwen3_next.modeling_qwen3_next import (
+        Qwen3NextSparseMoeBlock,
+    )
+
     """
     Calibration version of Qwen3NextSparseMoeBlock that sends all tokens to all experts.
     """
@@ -32,8 +37,8 @@ class CalibrationQwen3NextSparseMoeBlock(MoECalibrationModule):
 
     def __init__(
         self,
-        original,
-        config,
+        original: Qwen3NextSparseMoeBlock,
+        config: Qwen3NextConfig,
         calibrate_all_experts: bool = True,
     ):
         super().__init__()
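
For context, a minimal sketch of what the new hints describe: the calibration wrapper takes an original Qwen3NextSparseMoeBlock plus the model's Qwen3NextConfig. The direct construction below is illustrative only and assumes the wrapper can be built by hand (the test that follows goes through moe_calibration_context instead); variable names are not from the commit.

# Illustrative sketch, not part of the commit.
from transformers import AutoModelForCausalLM
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock

from llmcompressor.modeling.qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock
from llmcompressor.utils.dev import skip_weights_download

# Build the model skeleton without fetching the 80B checkpoint weights.
with skip_weights_download():
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")

# Wrap the first sparse-MoE block found; model.config is the Qwen3NextConfig
# instance the new annotation asks for.
for name, module in model.named_modules():
    if isinstance(module, Qwen3NextSparseMoeBlock):
        calib_block = CalibrationQwen3NextSparseMoeBlock(
            original=module,
            config=model.config,
            calibrate_all_experts=True,
        )
        break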

tests/llmcompressor/modeling/test_calib_qwen3_next.py

Lines changed: 56 additions & 2 deletions
@@ -1,8 +1,62 @@
+import contextlib
+from functools import partial
+
+import pytest
 import torch
+from transformers import AutoModelForCausalLM
 
+from llmcompressor.modeling.moe_context import moe_calibration_context
 from llmcompressor.modeling.qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock
-from llmcompressor.utils.helpers import calibration_forward_context
-from tests.testing_utils import requires_gpu
+from llmcompressor.utils.dev import skip_weights_download
+from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context
+from tests.testing_utils import requires_cadence, requires_gpu
+
+
+@requires_cadence("weekly")
+@pytest.mark.parametrize("model_stub", ["Qwen/Qwen3-Next-80B-A3B-Instruct"])
+def test_calib_replace_qwen3moe_all_experts(model_stub):
+    with skip_weights_download():
+        model = AutoModelForCausalLM.from_pretrained(model_stub)
+
+    # Qwen3MoE layer replacement is temporary within the context
+    with contextlib.ExitStack() as stack:
+        stack.enter_context(calibration_forward_context(model))
+        stack.enter_context(DisableQuantization(model))
+        stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True))
+
+        # Find one MoE layer
+        moe_layer = None
+        for name, module in model.named_modules():
+            if isinstance(module, CalibrationQwen3NextSparseMoeBlock):
+                moe_layer = module
+                break
+
+        assert moe_layer is not None
+
+        num_experts = len(moe_layer.experts)
+        expert_triggered = [False for _ in range(num_experts)]
+
+        # Define the hook function
+        def hook_fn(i, module, input, output):
+            expert_triggered[i] = True
+
+        # Attach hooks using functools.partial to bind each index
+        for i, expert in enumerate(moe_layer.experts):
+            expert.register_forward_hook(partial(hook_fn, i))
+
+        # Create dummy input tensor that simulates hidden_states
+        hidden_dim = model.config.hidden_size
+        batch, seq_len = 4, 32
+        sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32)
+
+        # Forward through the MoE layer directly
+        with torch.no_grad():
+            _ = moe_layer(sample)
+
+        # Assert all experts are used
+        assert all(
+            expert_triggered
+        ), f"Not all experts were triggered: {expert_triggered}"
 
 
 @requires_gpu
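
The hook wiring in the new test relies on functools.partial to freeze each expert's index before registering the hook, sidestepping Python's late-binding closures. Below is a self-contained sketch of the same pattern, using plain Linear layers as stand-in "experts"; the names here are illustrative, not from the test.

# Stand-alone illustration of the partial-bound forward-hook pattern above.
from functools import partial

import torch

experts = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
triggered = [False] * len(experts)

def hook_fn(i, module, inputs, output):
    # partial(hook_fn, i) binds i at registration time, so each hook
    # flips only its own expert's flag.
    triggered[i] = True

for i, expert in enumerate(experts):
    expert.register_forward_hook(partial(hook_fn, i))

with torch.no_grad():
    x = torch.randn(2, 8)
    for expert in experts:
        _ = expert(x)

assert all(triggered), f"Not all experts were triggered: {triggered}"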
