Skip to content

Commit a2e4c3f

Browse files
authored
Revert "[cherry-pick][refactor]support gatingtopk operator generalization (#4050)" (#4352)
This reverts commit c87a77e because it breaks the ops e2e test. Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent 5ad0ccd commit a2e4c3f

File tree

3 files changed

+69
-74
lines changed

3 files changed

+69
-74
lines changed

tests/ut/quantization/test_w8a8.py

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -754,14 +754,6 @@ def setUp(self):
754754

755755
self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
756756
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
757-
"""Mock custom routing"""
758-
self.mock_custom_routing = MagicMock()
759-
self.mock_custom_routing.return_value = (torch.ones(
760-
self.num_tokens, self.top_k),
761-
torch.zeros(
762-
self.num_tokens,
763-
self.top_k,
764-
dtype=torch.int32))
765757

766758
self.mock_ctx = MagicMock()
767759
self.mock_ctx.weight_prefetch_method = MagicMock()
@@ -771,7 +763,7 @@ def setUp(self):
771763
self.addCleanup(patcher.stop)
772764
patcher.start()
773765

774-
@patch('torch_npu.npu_moe_gating_top_k')
766+
@patch('torch_npu.npu_moe_gating_top_k_softmax')
775767
def test_softmax_scoring(self, mock_topk):
776768
"""Test softmax scoring function"""
777769
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -798,14 +790,12 @@ def test_softmax_scoring(self, mock_topk):
798790
def test_sigmoid_scoring(self):
799791
"""Test sigmoid scoring function"""
800792

801-
weights, ids = select_experts(
802-
hidden_states=self.hidden_states,
803-
router_logits=self.router_logits,
804-
top_k=self.top_k,
805-
use_grouped_topk=False,
806-
renormalize=False,
807-
scoring_func="sigmoid",
808-
custom_routing_function=self.mock_custom_routing)
793+
weights, ids = select_experts(hidden_states=self.hidden_states,
794+
router_logits=self.router_logits,
795+
top_k=self.top_k,
796+
use_grouped_topk=False,
797+
renormalize=False,
798+
scoring_func="sigmoid")
809799

810800
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
811801
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -818,8 +808,7 @@ def test_invalid_scoring_func(self):
818808
top_k=self.top_k,
819809
use_grouped_topk=False,
820810
renormalize=False,
821-
scoring_func="invalid_func",
822-
custom_routing_function=self.mock_custom_routing)
811+
scoring_func="invalid_func")
823812

824813
@patch('torch.topk')
825814
def test_grouped_topk(self, mock_topk):
@@ -829,15 +818,13 @@ def test_grouped_topk(self, mock_topk):
829818
self.top_k,
830819
dtype=torch.long))
831820

832-
weights, ids = select_experts(
833-
hidden_states=self.hidden_states,
834-
router_logits=self.router_logits,
835-
top_k=self.top_k,
836-
use_grouped_topk=True,
837-
renormalize=False,
838-
topk_group=4,
839-
num_expert_group=2,
840-
custom_routing_function=self.mock_custom_routing)
821+
weights, ids = select_experts(hidden_states=self.hidden_states,
822+
router_logits=self.router_logits,
823+
top_k=self.top_k,
824+
use_grouped_topk=True,
825+
renormalize=False,
826+
topk_group=4,
827+
num_expert_group=2)
841828

842829
mock_topk.assert_called()
843830
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -859,29 +846,35 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
859846
renormalize=False,
860847
topk_group=4,
861848
num_expert_group=2,
862-
e_score_correction_bias=e_score_correction_bias,
863-
custom_routing_function=self.mock_custom_routing)
849+
e_score_correction_bias=e_score_correction_bias)
864850

865851
mock_grouped_topk.assert_called_once()
866852
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
867853
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
868854

869855
def test_custom_routing_function(self):
870856
"""Test custom routing function"""
857+
mock_custom_routing = MagicMock()
858+
mock_custom_routing.return_value = (torch.ones(self.num_tokens,
859+
self.top_k),
860+
torch.zeros(self.num_tokens,
861+
self.top_k,
862+
dtype=torch.int32))
863+
871864
weights, ids = select_experts(
872865
hidden_states=self.hidden_states,
873866
router_logits=self.router_logits,
874867
top_k=self.top_k,
875868
use_grouped_topk=False,
876869
renormalize=False,
877-
custom_routing_function=self.mock_custom_routing)
870+
custom_routing_function=mock_custom_routing)
878871

879-
self.mock_custom_routing.assert_called_once()
872+
mock_custom_routing.assert_called_once()
880873
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
881874
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
882875
self.assertEqual(ids.dtype, torch.int32)
883876

884-
@patch('torch_npu.npu_moe_gating_top_k')
877+
@patch('torch_npu.npu_moe_gating_top_k_softmax')
885878
def test_renormalize(self, mock_topk):
886879
"""Test renormalization"""
887880
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -907,13 +900,13 @@ def test_renormalize(self, mock_topk):
907900
sums = weights.sum(dim=-1)
908901
self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
909902

910-
@patch('torch_npu.npu_moe_gating_top_k')
903+
@patch('torch_npu.npu_moe_gating_top_k_softmax')
911904
def test_output_dtypes(self, mock_topk):
912905
"""Test output dtypes"""
913906
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
914907
torch.zeros(self.num_tokens,
915908
self.top_k,
916-
dtype=torch.int32),
909+
dtype=torch.long),
917910
torch.arange(0,
918911
self.num_tokens * self.top_k,
919912
dtype=torch.int32).view(

vllm_ascend/ascend_forward_context.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ def set_ascend_forward_context(
9696
ep_size = (get_ep_group().world_size if
9797
vllm_config.parallel_config.enable_expert_parallel else 1)
9898

99-
# fused_moe_state is used in torchair, it will be deleted along with torchair
10099
is_deepseek_v3_r1 = hasattr(
101100
vllm_config.model_config.hf_config, 'n_routed_experts'
102101
) and vllm_config.model_config.hf_config.n_routed_experts == 256

vllm_ascend/ops/moe/experts_selector.py

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import torch_npu
2121
from vllm.forward_context import get_forward_context
2222

23+
from vllm_ascend.ascend_config import get_ascend_config
24+
2325

2426
def select_experts(hidden_states: torch.Tensor,
2527
router_logits: torch.Tensor,
@@ -60,20 +62,21 @@ def select_experts(hidden_states: torch.Tensor,
6062
if weight_prefetch_method:
6163
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
6264
hidden_states, "gate_up")
63-
if custom_routing_function is None:
64-
topk_weights, topk_ids = _select_experts_with_fusion_ops(
65-
hidden_states=hidden_states,
66-
router_logits=router_logits,
67-
top_k=top_k,
68-
use_grouped_topk=use_grouped_topk,
69-
topk_group=topk_group,
70-
renormalize=renormalize,
71-
e_score_correction_bias=e_score_correction_bias,
72-
num_expert_group=num_expert_group,
73-
scoring_func=scoring_func,
74-
routed_scaling_factor=routed_scaling_factor,
75-
global_num_experts=global_num_experts)
76-
else:
65+
topk_weights, topk_ids = _select_experts_with_fusion_ops(
66+
hidden_states=hidden_states,
67+
router_logits=router_logits,
68+
top_k=top_k,
69+
use_grouped_topk=use_grouped_topk,
70+
topk_group=topk_group,
71+
renormalize=renormalize,
72+
e_score_correction_bias=e_score_correction_bias,
73+
num_expert_group=num_expert_group,
74+
custom_routing_function=custom_routing_function,
75+
scoring_func=scoring_func,
76+
routed_scaling_factor=routed_scaling_factor,
77+
global_num_experts=global_num_experts)
78+
79+
if topk_weights is None:
7780
topk_weights, topk_ids = _native_select_experts(
7881
hidden_states=hidden_states,
7982
router_logits=router_logits,
@@ -168,34 +171,34 @@ def _select_experts_with_fusion_ops(
168171
e_score_correction_bias: Optional[torch.Tensor],
169172
topk_group: Optional[int],
170173
num_expert_group: Optional[int],
174+
custom_routing_function: Optional[Callable] = None,
171175
scoring_func: str = "softmax",
172176
routed_scaling_factor=1.0,
173177
global_num_experts: int = -1):
174178

175-
if scoring_func == "softmax":
176-
norm_type = 0
177-
topk_group = 1
178-
num_expert_group = 1
179-
else:
180-
norm_type = 1
181-
if e_score_correction_bias is not None and \
182-
e_score_correction_bias.dtype != router_logits.dtype:
183-
e_score_correction_bias = e_score_correction_bias.to(
184-
router_logits.dtype)
185-
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
186-
router_logits,
187-
k=top_k,
188-
bias=e_score_correction_bias,
189-
k_group=topk_group,
190-
group_count=num_expert_group,
191-
group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix)
192-
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
193-
norm_type=norm_type, # 0: softmax; 1: sigmoid
194-
# out_flag=False, # todo new api; should the third output be output
195-
# y2_flag=False, # old api; should the third output be output
196-
routed_scaling_factor=1,
197-
eps=float(1e-20))
198-
if scoring_func == "softmax":
179+
topk_weights, topk_ids = None, None
180+
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
181+
global_redundant_expert_num = get_ascend_config().init_redundancy_expert
182+
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
183+
if is_deepseek_v3_r1:
184+
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
185+
router_logits,
186+
k=top_k, # topk currently 8
187+
bias=e_score_correction_bias,
188+
k_group=topk_group, # fix: 4
189+
group_count=num_expert_group, # fix 8
190+
group_select_mode=
191+
1, # 0: the maximum in the group; 1: topk2.sum(fix)
192+
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
193+
norm_type=1, # 0: softmax; 1: sigmoid(fix)
194+
# out_flag=False, # todo new api; should the third output be output
195+
# y2_flag=False, # old api; should the third output be output
196+
routed_scaling_factor=1,
197+
eps=float(1e-20))
198+
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
199+
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
200+
x=router_logits, finished=None, k=top_k)
201+
topk_ids = topk_ids.to(torch.int32)
199202
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
200203

201204
return topk_weights, topk_ids

0 commit comments

Comments
 (0)