From c069bf0b653ce5c044cbcb396b0085c629a5fcfd Mon Sep 17 00:00:00 2001
From: 1092626063 <1092626063@qq.com>
Date: Wed, 19 Nov 2025 10:39:28 +0800
Subject: [PATCH 1/2] [cherry-pick][refactor] support gatingtopk operator generalization (#4050)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
Picked from: https://github.com/vllm-project/vllm-ascend/pull/2958

Past: npu_moe_gating_top_k could only support the 'group_count=256' pattern.
Now:
1. npu_moe_gating_top_k supports all sizes of group_count.
2. The functionality of `torch_npu.npu_moe_gating_top_k_softmax` is included in `torch_npu.npu_moe_gating_top_k`.

CANN: depends on 8.3.RC1.

Performance:
1. GLM4.5-w8a8: TPS improves by 6%.
2. Qwen3: the same as before.

Signed-off-by: 1092626063 <1092626063@qq.com>
---
 tests/ut/quantization/test_w8a8.py      | 63 +++++++++++---------
 vllm_ascend/ascend_forward_context.py   |  1 +
 vllm_ascend/ops/moe/experts_selector.py | 79 ++++++++++++-------------
 3 files changed, 74 insertions(+), 69 deletions(-)

diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py
index 6702d2bdcbd..c2597bbe851 100644
--- a/tests/ut/quantization/test_w8a8.py
+++ b/tests/ut/quantization/test_w8a8.py
@@ -753,6 +753,14 @@ def setUp(self):
         self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
         self.router_logits = torch.randn(self.num_tokens, self.num_experts)
 
+        """Mock custom routing"""
+        self.mock_custom_routing = MagicMock()
+        self.mock_custom_routing.return_value = (torch.ones(
+            self.num_tokens, self.top_k),
+                                                 torch.zeros(
+                                                     self.num_tokens,
+                                                     self.top_k,
+                                                     dtype=torch.int32))
         self.mock_ctx = MagicMock()
         self.mock_ctx.weight_prefetch_method = MagicMock()
 
@@ -762,7 +770,7 @@ def setUp(self):
         self.addCleanup(patcher.stop)
         patcher.start()
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_softmax_scoring(self, mock_topk):
         """Test softmax scoring function"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -789,12 +797,14 @@ def test_softmax_scoring(self, mock_topk):
 
     def test_sigmoid_scoring(self):
         """Test sigmoid scoring function"""
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=False,
-                                      renormalize=False,
-                                      scoring_func="sigmoid")
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=False,
+            renormalize=False,
+            scoring_func="sigmoid",
+            custom_routing_function=self.mock_custom_routing)
 
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -807,7 +817,8 @@ def test_invalid_scoring_func(self):
                            top_k=self.top_k,
                            use_grouped_topk=False,
                            renormalize=False,
-                           scoring_func="invalid_func")
+                           scoring_func="invalid_func",
+                           custom_routing_function=self.mock_custom_routing)
 
     @patch('torch.topk')
     def test_grouped_topk(self, mock_topk):
@@ -817,13 +828,15 @@ def test_grouped_topk(self, mock_topk):
                                                 self.top_k,
                                                 dtype=torch.long))
 
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=True,
-                                      renormalize=False,
-                                      topk_group=4,
-                                      num_expert_group=2)
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=True,
+            renormalize=False,
+            topk_group=4,
+
num_expert_group=2, + custom_routing_function=self.mock_custom_routing) mock_topk.assert_called() self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) @@ -845,7 +858,8 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk): renormalize=False, topk_group=4, num_expert_group=2, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + custom_routing_function=self.mock_custom_routing) mock_grouped_topk.assert_called_once() self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) @@ -853,27 +867,20 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk): def test_custom_routing_function(self): """Test custom routing function""" - mock_custom_routing = MagicMock() - mock_custom_routing.return_value = (torch.ones(self.num_tokens, - self.top_k), - torch.zeros(self.num_tokens, - self.top_k, - dtype=torch.int32)) - weights, ids = select_experts( hidden_states=self.hidden_states, router_logits=self.router_logits, top_k=self.top_k, use_grouped_topk=False, renormalize=False, - custom_routing_function=mock_custom_routing) + custom_routing_function=self.mock_custom_routing) - mock_custom_routing.assert_called_once() + self.mock_custom_routing.assert_called_once() self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) self.assertEqual(ids.dtype, torch.int32) - @patch('torch_npu.npu_moe_gating_top_k_softmax') + @patch('torch_npu.npu_moe_gating_top_k') def test_renormalize(self, mock_topk): """Test renormalization""" mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), @@ -899,13 +906,13 @@ def test_renormalize(self, mock_topk): sums = weights.sum(dim=-1) self.assertTrue(torch.allclose(sums, torch.ones_like(sums))) - @patch('torch_npu.npu_moe_gating_top_k_softmax') + @patch('torch_npu.npu_moe_gating_top_k') def test_output_dtypes(self, mock_topk): """Test output dtypes""" mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), torch.zeros(self.num_tokens, self.top_k, - dtype=torch.long), + dtype=torch.int32), torch.arange(0, self.num_tokens * self.top_k, dtype=torch.int32).view( diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index a700fbfd775..580508ae734 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -96,6 +96,7 @@ def set_ascend_forward_context( ep_size = (get_ep_group().world_size if vllm_config.parallel_config.enable_expert_parallel else 1) + # fused_moe_state is used in torchair, it will be deleted along with torchair is_deepseek_v3_r1 = hasattr( vllm_config.model_config.hf_config, 'n_routed_experts' ) and vllm_config.model_config.hf_config.n_routed_experts == 256 diff --git a/vllm_ascend/ops/moe/experts_selector.py b/vllm_ascend/ops/moe/experts_selector.py index e511d6b554f..eb3fc848c8e 100644 --- a/vllm_ascend/ops/moe/experts_selector.py +++ b/vllm_ascend/ops/moe/experts_selector.py @@ -20,8 +20,6 @@ import torch_npu from vllm.forward_context import get_forward_context -from vllm_ascend.ascend_config import get_ascend_config - def select_experts(hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -62,21 +60,20 @@ def select_experts(hidden_states: torch.Tensor, if weight_prefetch_method: weight_prefetch_method.maybe_prefetch_moe_weight_preprocess( hidden_states, "gate_up") - topk_weights, topk_ids = _select_experts_with_fusion_ops( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=top_k, - 
use_grouped_topk=use_grouped_topk,
-        topk_group=topk_group,
-        renormalize=renormalize,
-        e_score_correction_bias=e_score_correction_bias,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-        scoring_func=scoring_func,
-        routed_scaling_factor=routed_scaling_factor,
-        global_num_experts=global_num_experts)
-
-    if topk_weights is None:
+    if custom_routing_function is None:
+        topk_weights, topk_ids = _select_experts_with_fusion_ops(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            topk_group=topk_group,
+            renormalize=renormalize,
+            e_score_correction_bias=e_score_correction_bias,
+            num_expert_group=num_expert_group,
+            scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
+            global_num_experts=global_num_experts)
+    else:
         topk_weights, topk_ids = _native_select_experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
@@ -171,34 +168,34 @@ def _select_experts_with_fusion_ops(
         e_score_correction_bias: Optional[torch.Tensor],
         topk_group: Optional[int],
         num_expert_group: Optional[int],
-        custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         routed_scaling_factor=1.0,
         global_num_experts: int = -1):
 
-    topk_weights, topk_ids = None, None
-    # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
-    global_redundant_expert_num = get_ascend_config().init_redundancy_expert
-    is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
-    if is_deepseek_v3_r1:
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-            router_logits,
-            k=top_k,  # topk currently 8
-            bias=e_score_correction_bias,
-            k_group=topk_group,  # fix: 4
-            group_count=num_expert_group,  # fix 8
-            group_select_mode=
-            1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-            renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-            norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-            # out_flag=False, # todo new api; should the third output be output
-            # y2_flag=False, # old api; should the third output be output
-            routed_scaling_factor=1,
-            eps=float(1e-20))
-    if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
-            x=router_logits, finished=None, k=top_k)
-        topk_ids = topk_ids.to(torch.int32)
+    if scoring_func == "softmax":
+        norm_type = 0
+        topk_group = 1
+        num_expert_group = 1
+    else:
+        norm_type = 1
+    if e_score_correction_bias is not None and \
+        e_score_correction_bias.dtype != router_logits.dtype:
+        e_score_correction_bias = e_score_correction_bias.to(
+            router_logits.dtype)
+    topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+        router_logits,
+        k=top_k,
+        bias=e_score_correction_bias,
+        k_group=topk_group,
+        group_count=num_expert_group,
+        group_select_mode=1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+        renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+        norm_type=norm_type,  # 0: softmax; 1: sigmoid
+        # out_flag=False, # todo new api; should the third output be output
+        # y2_flag=False, # old api; should the third output be output
+        routed_scaling_factor=1,
+        eps=float(1e-20))
+    if scoring_func == "softmax":
         topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
 
     return topk_weights, topk_ids

From 02a1a9e33b52bc09077dbef715b2d20d36087e81 Mon Sep 17 00:00:00 2001
From: 1092626063 <1092626063@qq.com>
Date: Mon, 24 Nov 2025 16:52:32 +0800
Subject: [PATCH 2/2] ops gatingtopk: fix nightly CI error

Signed-off-by: 1092626063
<1092626063@qq.com> --- tests/e2e/singlecard/ops/test_fused_moe.py | 8 +++- tests/ut/quantization/test_w8a8.py | 22 +++++------ vllm_ascend/ops/moe/experts_selector.py | 45 ++++++++++++++++++---- 3 files changed, 53 insertions(+), 22 deletions(-) diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py index 4735a5f159c..8180858d548 100644 --- a/tests/e2e/singlecard/ops/test_fused_moe.py +++ b/tests/e2e/singlecard/ops/test_fused_moe.py @@ -28,7 +28,8 @@ import torch_npu from vllm.model_executor.layers.activation import SiluAndMul -from vllm_ascend.ops.moe.experts_selector import select_experts +from vllm_ascend.ops.moe.experts_selector import (check_npu_moe_gating_top_k, + select_experts) from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather @@ -296,7 +297,10 @@ def test_select_experts( e_score_correction_bias=e_score_correction_bias, ) - if use_grouped_topk: + call_moe_gatingtopk = check_npu_moe_gating_top_k( + hidden_states, topk, topk_group, num_expert_group, scoring_func, + custom_routing_function) + if not call_moe_gatingtopk and use_grouped_topk: mock_native_grouped_topk.assert_called_once() else: mock_native_grouped_topk.assert_not_called() diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py index c2597bbe851..2ad80887f69 100644 --- a/tests/ut/quantization/test_w8a8.py +++ b/tests/ut/quantization/test_w8a8.py @@ -817,8 +817,7 @@ def test_invalid_scoring_func(self): top_k=self.top_k, use_grouped_topk=False, renormalize=False, - scoring_func="invalid_func", - custom_routing_function=self.mock_custom_routing) + scoring_func="invalid_func") @patch('torch.topk') def test_grouped_topk(self, mock_topk): @@ -828,15 +827,13 @@ def test_grouped_topk(self, mock_topk): self.top_k, dtype=torch.long)) - weights, ids = select_experts( - hidden_states=self.hidden_states, - router_logits=self.router_logits, - top_k=self.top_k, - use_grouped_topk=True, - renormalize=False, - topk_group=4, - num_expert_group=2, - custom_routing_function=self.mock_custom_routing) + weights, ids = select_experts(hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=True, + renormalize=False, + topk_group=4, + num_expert_group=2) mock_topk.assert_called() self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) @@ -858,8 +855,7 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk): renormalize=False, topk_group=4, num_expert_group=2, - e_score_correction_bias=e_score_correction_bias, - custom_routing_function=self.mock_custom_routing) + e_score_correction_bias=e_score_correction_bias) mock_grouped_topk.assert_called_once() self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) diff --git a/vllm_ascend/ops/moe/experts_selector.py b/vllm_ascend/ops/moe/experts_selector.py index eb3fc848c8e..05ec0e38491 100644 --- a/vllm_ascend/ops/moe/experts_selector.py +++ b/vllm_ascend/ops/moe/experts_selector.py @@ -60,7 +60,15 @@ def select_experts(hidden_states: torch.Tensor, if weight_prefetch_method: weight_prefetch_method.maybe_prefetch_moe_weight_preprocess( hidden_states, "gate_up") - if custom_routing_function is None: + is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k( + hidden_states=hidden_states, + top_k=top_k, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=scoring_func, + custom_routing_function=custom_routing_function) + + if 
is_support_npu_moe_gating_top_k:
         topk_weights, topk_ids = _select_experts_with_fusion_ops(
             hidden_states=hidden_states,
             router_logits=router_logits,
@@ -90,6 +98,32 @@ def select_experts(hidden_states: torch.Tensor,
     return topk_weights, topk_ids
 
 
+def check_npu_moe_gating_top_k(
+        hidden_states: torch.Tensor,
+        top_k: int,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        scoring_func: str = "softmax",
+        custom_routing_function: Optional[Callable] = None):
+    if custom_routing_function is not None:
+        return False
+    if scoring_func != "softmax" and scoring_func != "sigmoid":
+        return False
+    topk_group = topk_group if topk_group is not None else 1
+    num_expert_group = num_expert_group if num_expert_group is not None else 1
+    if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group
+            == 0 and hidden_states.shape[-1] // num_expert_group > 2):
+        return False
+    if topk_group < 1 or topk_group > num_expert_group:
+        return False
+    if top_k < 1 or \
+        top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
+        return False
+    if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:
+        return False
+    return True
+
+
 def _native_grouped_topk(
     topk_weights: torch.Tensor,
     num_expert_group: Optional[int],
@@ -172,12 +206,9 @@ def _select_experts_with_fusion_ops(
         routed_scaling_factor=1.0,
         global_num_experts: int = -1):
 
-    if scoring_func == "softmax":
-        norm_type = 0
-        topk_group = 1
-        num_expert_group = 1
-    else:
-        norm_type = 1
+    topk_group = topk_group if topk_group is not None else 1
+    num_expert_group = num_expert_group if num_expert_group is not None else 1
+    norm_type = 0 if scoring_func == "softmax" else 1
     if e_score_correction_bias is not None and \
         e_score_correction_bias.dtype != router_logits.dtype:
         e_score_correction_bias = e_score_correction_bias.to(