Commit 23a613a

dsikka authored and HDCharles committed
Add tests
1 parent fb83b7f commit 23a613a

1 file changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
import torch

from llmcompressor.modeling.qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock
from llmcompressor.utils.helpers import calibration_forward_context
from tests.testing_utils import requires_gpu


@requires_gpu
def test_calib_qwen3_moe_module():
    from transformers import Qwen3NextConfig
    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock

    config = Qwen3NextConfig()
    with torch.device("cuda"):
        original = Qwen3NextSparseMoeBlock(config).eval()

    # Create dummy input tensor that simulates hidden_states
    hidden_dim = config.hidden_size
    batch, seq_len = 4, 32
    sample = torch.randn(batch, seq_len, hidden_dim, device="cuda")

    with calibration_forward_context(original):
        true_output = original(sample)

    module = CalibrationQwen3NextSparseMoeBlock(
        original, config, calibrate_all_experts=True
    )

    with calibration_forward_context(module):
        output = module(sample)
    # assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10
    assert torch.nn.functional.mse_loss(true_output[1], output[1]) < 1e-10

    module = CalibrationQwen3NextSparseMoeBlock(
        original, config, calibrate_all_experts=False
    )
    with calibration_forward_context(module):
        output = module(sample)
    # assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10
    assert torch.nn.functional.mse_loss(true_output[1], output[1]) < 1e-10
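
For context: in the usual Hugging Face MoE convention, a sparse MoE block returns a tuple whose second element holds the router logits, which is what the active assertions compare; the commented-out assertions on element 0 (the routed hidden states) were left disabled in this commit. The general pattern such a calibration wrapper typically implements, running every token through every expert during calibration while leaving the routed output untouched, can be sketched as below. This is a minimal, hypothetical illustration (SimpleMoEBlock and CalibrationMoEBlock are invented names), not the actual llmcompressor implementation of CalibrationQwen3NextSparseMoeBlock.

import torch
import torch.nn as nn


class SimpleMoEBlock(nn.Module):
    """Toy top-1 MoE block: a router plus a list of expert MLPs."""

    def __init__(self, hidden_size: int = 16, num_experts: int = 4):
        super().__init__()
        self.router = nn.Linear(hidden_size, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size, bias=False) for _ in range(num_experts)]
        )

    def forward(self, hidden_states: torch.Tensor):
        router_logits = self.router(hidden_states)   # (batch, seq, num_experts)
        top1 = router_logits.argmax(dim=-1)          # chosen expert per token
        out = torch.zeros_like(hidden_states)
        for idx, expert in enumerate(self.experts):
            # Only tokens routed to this expert contribute to the output
            mask = (top1 == idx).unsqueeze(-1).to(hidden_states.dtype)
            out = out + mask * expert(hidden_states)
        return out, router_logits


class CalibrationMoEBlock(nn.Module):
    """Wraps an MoE block so every expert sees every token during calibration,
    while the returned output still honors the original routing."""

    def __init__(self, original: SimpleMoEBlock, calibrate_all_experts: bool = True):
        super().__init__()
        self.original = original
        self.calibrate_all_experts = calibrate_all_experts

    def forward(self, hidden_states: torch.Tensor):
        if self.calibrate_all_experts:
            # Run every expert on the full input so observers/hooks attached to
            # each expert receive calibration statistics; the results are discarded.
            for expert in self.original.experts:
                expert(hidden_states)
        # The output (and router logits) still come from the original routed
        # computation, so downstream behavior is unchanged.
        return self.original(hidden_states)


if __name__ == "__main__":
    torch.manual_seed(0)
    block = SimpleMoEBlock().eval()
    x = torch.randn(2, 8, 16)
    ref_out, ref_logits = block(x)
    calib_out, calib_logits = CalibrationMoEBlock(block, calibrate_all_experts=True)(x)
    assert torch.allclose(ref_out, calib_out)
    assert torch.allclose(ref_logits, calib_logits)

In this pattern, disabling calibrate_all_experts simply skips the extra expert passes, so both configurations should reproduce the original router logits exactly, which is what the two assertions in the test above check.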
