
Commit 9bce0d8

dsikka authored and HDCharles committed
update test
1 parent 23a613a · commit 9bce0d8

2 files changed, 8 insertions(+), 5 deletions(-)

src/llmcompressor/modeling/qwen3_next_moe.py

2 additions & 2 deletions

@@ -38,7 +38,7 @@ def __init__(
     ):
         super().__init__()
         self.num_experts = config.num_experts
-        self.top_k = config.top_k
+        self.top_k = original.top_k
         self.norm_topk_prob = config.norm_topk_prob

         # gating
@@ -56,7 +56,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         router_logits = self.gate(hidden_states)

         routing_weights = torch.nn.functional.softmax(
-            router_logits, dim=1, dtype=torch.float
+            router_logits, dim=-1, dtype=torch.float
        )
         routing_weights, selected_experts = torch.topk(
             routing_weights, self.top_k, dim=-1
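
For reference, the router math these hunks touch can be sketched in isolation. The snippet below is an illustrative reconstruction, not the repository's CalibrationQwen3NextSparseMoeBlock: it assumes router_logits is a 2D tensor of shape (num_tokens, num_experts), and mirrors the softmax-then-topk lines from the second hunk, with the softmax taken over the expert axis (dim=-1).

import torch

# Illustrative sketch of MoE router math; the shapes below are assumptions,
# not taken from llm-compressor's CalibrationQwen3NextSparseMoeBlock.
num_tokens, num_experts, top_k = 4, 8, 2
router_logits = torch.randn(num_tokens, num_experts)

# Softmax over the expert axis turns logits into per-token routing
# probabilities; each row sums to 1.
routing_weights = torch.nn.functional.softmax(
    router_logits, dim=-1, dtype=torch.float
)

# Select the top_k experts per token, as in the block's forward pass.
routing_weights, selected_experts = torch.topk(
    routing_weights, top_k, dim=-1
)
print(selected_experts.shape)  # torch.Size([4, 2])

For a 2D logits tensor, dim=1 and dim=-1 coincide; dim=-1 additionally stays on the expert axis if the logits keep a leading batch or sequence dimension.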

tests/llmcompressor/modeling/test_calib_qwen3_next.py

6 additions & 3 deletions

@@ -8,7 +8,10 @@
 @requires_gpu
 def test_calib_qwen3_moe_module():
     from transformers import Qwen3NextConfig
-    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
+    from transformers.models.qwen3_next.modeling_qwen3_next import (
+        Qwen3NextSparseMoeBlock,
+    )
+
     config = Qwen3NextConfig()
     with torch.device("cuda"):
         original = Qwen3NextSparseMoeBlock(config).eval()
@@ -27,13 +30,13 @@ def test_calib_qwen3_moe_module():

     with calibration_forward_context(module):
         output = module(sample)
-        #assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10
+        assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10
         assert torch.nn.functional.mse_loss(true_output[1], output[1]) < 1e-10

     module = CalibrationQwen3NextSparseMoeBlock(
         original, config, calibrate_all_experts=False
     )
     with calibration_forward_context(module):
         output = module(sample)
-        #assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10
+        assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10
         assert torch.nn.functional.mse_loss(true_output[1], output[1]) < 1e-10
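
As a side note on the re-enabled assertions: the pattern is a plain output-equivalence check, running the original block and the calibration wrapper on the same input and requiring the MSE between their outputs to be near zero. Below is a minimal, self-contained sketch of that pattern using a stand-in torch.nn.Linear rather than llm-compressor's modules.

import torch

# Minimal sketch of the "wrapper matches the original" check used above;
# the modules here are stand-ins, not llm-compressor APIs.
original = torch.nn.Linear(8, 8).eval()
wrapped = original  # imagine a calibration wrapper expected to be output-equivalent

sample = torch.randn(2, 8)
with torch.no_grad():
    true_output = original(sample)
    output = wrapped(sample)

# mse_loss returns a 0-dim tensor; comparing against a small float threshold
# asserts near-exact numerical agreement between the two forward passes.
assert torch.nn.functional.mse_loss(true_output, output) < 1e-10

The test applies the same comparison to both elements of the block's output tuple, with calibrate_all_experts set to True and then False.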
