
Commit f08bb5f

ops gatingtopk fix nightly ci error
Signed-off-by: 1092626063 <1092626063@qq.com>
1 parent: 5a4e8cd

File tree

3 files changed: 53 additions (+), 22 deletions (-)


tests/e2e/nightly/ops/test_fused_moe.py

Lines changed: 6 additions & 2 deletions
@@ -28,7 +28,8 @@
 import torch_npu
 from vllm.model_executor.layers.activation import SiluAndMul

-from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import (
+    select_experts, check_npu_moe_gating_top_k)
 from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.fused_moe.token_dispatcher import \
     TokenDispatcherWithAllGather
@@ -303,7 +304,10 @@ def test_select_experts(
         e_score_correction_bias=e_score_correction_bias,
     )

-    if use_grouped_topk:
+    call_moe_gatingtopk = check_npu_moe_gating_top_k(
+        hidden_states, topk, topk_group, num_expert_group, scoring_func,
+        custom_routing_function)
+    if not call_moe_gatingtopk and use_grouped_topk:
         mock_native_grouped_topk.assert_called_once()
     else:
         mock_native_grouped_topk.assert_not_called()

tests/ut/quantization/test_w8a8.py

Lines changed: 9 additions & 13 deletions
@@ -818,8 +818,7 @@ def test_invalid_scoring_func(self):
             top_k=self.top_k,
             use_grouped_topk=False,
             renormalize=False,
-            scoring_func="invalid_func",
-            custom_routing_function=self.mock_custom_routing)
+            scoring_func="invalid_func")

     @patch('torch.topk')
     def test_grouped_topk(self, mock_topk):
@@ -829,15 +828,13 @@ def test_grouped_topk(self, mock_topk):
             self.top_k,
             dtype=torch.long))

-        weights, ids = select_experts(
-            hidden_states=self.hidden_states,
-            router_logits=self.router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=True,
-            renormalize=False,
-            topk_group=4,
-            num_expert_group=2,
-            custom_routing_function=self.mock_custom_routing)
+        weights, ids = select_experts(hidden_states=self.hidden_states,
+                                      router_logits=self.router_logits,
+                                      top_k=self.top_k,
+                                      use_grouped_topk=True,
+                                      renormalize=False,
+                                      topk_group=4,
+                                      num_expert_group=2)

         mock_topk.assert_called()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -859,8 +856,7 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
             renormalize=False,
             topk_group=4,
             num_expert_group=2,
-            e_score_correction_bias=e_score_correction_bias,
-            custom_routing_function=self.mock_custom_routing)
+            e_score_correction_bias=e_score_correction_bias)

         mock_grouped_topk.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))

vllm_ascend/ops/fused_moe/experts_selector.py

Lines changed: 38 additions & 7 deletions
@@ -60,7 +60,15 @@ def select_experts(hidden_states: torch.Tensor,
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
             hidden_states, "gate_up")
-    if custom_routing_function is None:
+    is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
+        hidden_states=hidden_states,
+        top_k=top_k,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        scoring_func=scoring_func,
+        custom_routing_function=custom_routing_function)
+
+    if is_support_npu_moe_gating_top_k:
         topk_weights, topk_ids = _select_experts_with_fusion_ops(
             hidden_states=hidden_states,
             router_logits=router_logits,
@@ -90,6 +98,32 @@
     return topk_weights, topk_ids


+def check_npu_moe_gating_top_k(
+        hidden_states: torch.Tensor,
+        top_k: int,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        scoring_func: str = "softmax",
+        custom_routing_function: Optional[Callable] = None):
+    if custom_routing_function is not None:
+        return False
+    if scoring_func != "softmax" and scoring_func != "sigmoid":
+        return False
+    topk_group = topk_group if topk_group is not None else 1
+    num_expert_group = num_expert_group if num_expert_group is not None else 1
+    if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group
+            == 0 and hidden_states.shape[-1] // num_expert_group > 2):
+        return False
+    if topk_group < 1 or topk_group > num_expert_group:
+        return False
+    if top_k < 1 or \
+            top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
+        return False
+    if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:
+        return False
+    return True
+
+
 def _native_grouped_topk(
     topk_weights: torch.Tensor,
     num_expert_group: Optional[int],
@@ -172,12 +206,9 @@ def _select_experts_with_fusion_ops(
         routed_scaling_factor=1.0,
         global_num_experts: int = -1):

-    if scoring_func == "softmax":
-        norm_type = 0
-        topk_group = 1
-        num_expert_group = 1
-    else:
-        norm_type = 1
+    topk_group = topk_group if topk_group is not None else 1
+    num_expert_group = num_expert_group if num_expert_group is not None else 1
+    norm_type = 0 if scoring_func == "softmax" else 1
     if e_score_correction_bias is not None and \
             e_score_correction_bias.dtype != router_logits.dtype:
         e_score_correction_bias = e_score_correction_bias.to(
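
As a quick illustration of the new eligibility check, the sketch below (not part of the commit) calls check_npu_moe_gating_top_k directly; the tensor shape, group counts, and top-k values are made up for the example, and running it assumes vllm-ascend is importable.

```python
import torch

from vllm_ascend.ops.fused_moe.experts_selector import check_npu_moe_gating_top_k

# A toy tensor; the check only inspects its last dimension (64 here).
hidden_states = torch.randn(4, 64)

# Eligible for the fused NPU gating path: no custom routing function,
# softmax scoring, 8 groups dividing the last dimension evenly,
# topk_group within [1, num_expert_group], and top_k small enough.
assert check_npu_moe_gating_top_k(hidden_states=hidden_states,
                                  top_k=2,
                                  topk_group=4,
                                  num_expert_group=8,
                                  scoring_func="softmax")

# Not eligible: a custom routing function always forces the fallback path.
assert not check_npu_moe_gating_top_k(
    hidden_states=hidden_states,
    top_k=2,
    custom_routing_function=lambda *args, **kwargs: None)
```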
