8 changes: 6 additions & 2 deletions tests/e2e/singlecard/ops/test_fused_moe.py
@@ -28,7 +28,8 @@
import torch_npu
from vllm.model_executor.layers.activation import SiluAndMul

from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.experts_selector import (check_npu_moe_gating_top_k,
select_experts)
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather

@@ -296,7 +297,10 @@ def test_select_experts(
e_score_correction_bias=e_score_correction_bias,
)

if use_grouped_topk:
call_moe_gatingtopk = check_npu_moe_gating_top_k(
hidden_states, topk, topk_group, num_expert_group, scoring_func,
custom_routing_function)
if not call_moe_gatingtopk and use_grouped_topk:
mock_native_grouped_topk.assert_called_once()
else:
mock_native_grouped_topk.assert_not_called()
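A minimal usage sketch of the helper the test now imports (shapes and argument values below are assumed for illustration and are not part of this PR): check_npu_moe_gating_top_k reports whether the fused torch_npu.npu_moe_gating_top_k path applies, and the test only expects the native grouped-topk fallback when it does not.

import torch

from vllm_ascend.ops.moe.experts_selector import check_npu_moe_gating_top_k

# Assumed token count and hidden size, purely for illustration.
hidden_states = torch.randn(4, 128)

call_moe_gatingtopk = check_npu_moe_gating_top_k(
    hidden_states,
    top_k=2,
    topk_group=1,
    num_expert_group=1,
    scoring_func="softmax",
    custom_routing_function=None)

if call_moe_gatingtopk:
    print("fused torch_npu.npu_moe_gating_top_k path expected")
else:
    print("native grouped-topk fallback expected")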
41 changes: 22 additions & 19 deletions tests/ut/quantization/test_w8a8.py
@@ -753,6 +753,14 @@ def setUp(self):

self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
"""Mock custom routing"""
self.mock_custom_routing = MagicMock()
self.mock_custom_routing.return_value = (torch.ones(
self.num_tokens, self.top_k),
torch.zeros(
self.num_tokens,
self.top_k,
dtype=torch.int32))

self.mock_ctx = MagicMock()
self.mock_ctx.weight_prefetch_method = MagicMock()
@@ -762,7 +770,7 @@ def setUp(self):
self.addCleanup(patcher.stop)
patcher.start()

@patch('torch_npu.npu_moe_gating_top_k_softmax')
@patch('torch_npu.npu_moe_gating_top_k')
def test_softmax_scoring(self, mock_topk):
"""Test softmax scoring function"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -789,12 +797,14 @@
def test_sigmoid_scoring(self):
"""Test sigmoid scoring function"""

weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid")
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid",
custom_routing_function=self.mock_custom_routing)

self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -853,27 +863,20 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):

def test_custom_routing_function(self):
"""Test custom routing function"""
mock_custom_routing = MagicMock()
mock_custom_routing.return_value = (torch.ones(self.num_tokens,
self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.int32))

weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
custom_routing_function=mock_custom_routing)
custom_routing_function=self.mock_custom_routing)

mock_custom_routing.assert_called_once()
self.mock_custom_routing.assert_called_once()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.dtype, torch.int32)

@patch('torch_npu.npu_moe_gating_top_k_softmax')
@patch('torch_npu.npu_moe_gating_top_k')
def test_renormalize(self, mock_topk):
"""Test renormalization"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -899,13 +902,13 @@ def test_renormalize(self, mock_topk):
sums = weights.sum(dim=-1)
self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))

@patch('torch_npu.npu_moe_gating_top_k_softmax')
@patch('torch_npu.npu_moe_gating_top_k')
def test_output_dtypes(self, mock_topk):
"""Test output dtypes"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.long),
dtype=torch.int32),
torch.arange(0,
self.num_tokens * self.top_k,
dtype=torch.int32).view(
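The unit tests above now patch torch_npu.npu_moe_gating_top_k, which returns topk weights, int32 expert ids, and a third auxiliary tensor, instead of npu_moe_gating_top_k_softmax, and they reuse the shared mock_custom_routing fixture from setUp. A standalone sketch of the mock shape these tests rely on (sizes are assumed, and torch_npu is assumed to be importable):

from unittest.mock import patch

import torch

num_tokens, top_k = 16, 4  # assumed sizes for illustration

with patch("torch_npu.npu_moe_gating_top_k") as mock_topk:
    mock_topk.return_value = (
        torch.ones(num_tokens, top_k),                      # topk weights
        torch.zeros(num_tokens, top_k, dtype=torch.int32),  # topk expert ids
        torch.arange(num_tokens * top_k,
                     dtype=torch.int32).view(num_tokens, top_k),  # aux output
    )
    weights, ids, _ = mock_topk(torch.randn(num_tokens, 8), k=top_k)
    assert ids.dtype == torch.int32  # select_experts keeps expert ids as int32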
1 change: 1 addition & 0 deletions vllm_ascend/ascend_forward_context.py
@@ -96,6 +96,7 @@ def set_ascend_forward_context(
ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)

# fused_moe_state is used in torchair; it will be deleted along with torchair
is_deepseek_v3_r1 = hasattr(
vllm_config.model_config.hf_config, 'n_routed_experts'
) and vllm_config.model_config.hf_config.n_routed_experts == 256
98 changes: 63 additions & 35 deletions vllm_ascend/ops/moe/experts_selector.py
@@ -20,8 +20,6 @@
import torch_npu
from vllm.forward_context import get_forward_context

from vllm_ascend.ascend_config import get_ascend_config


def select_experts(hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -62,21 +60,28 @@ def select_experts(hidden_states: torch.Tensor,
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
hidden_states, "gate_up")
topk_weights, topk_ids = _select_experts_with_fusion_ops(
is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
renormalize=renormalize,
e_score_correction_bias=e_score_correction_bias,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
global_num_experts=global_num_experts)
custom_routing_function=custom_routing_function)

if topk_weights is None:
if is_support_npu_moe_gating_top_k:
topk_weights, topk_ids = _select_experts_with_fusion_ops(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
renormalize=renormalize,
e_score_correction_bias=e_score_correction_bias,
num_expert_group=num_expert_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
global_num_experts=global_num_experts)
else:
topk_weights, topk_ids = _native_select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
@@ -93,6 +98,32 @@
return topk_weights, topk_ids


def check_npu_moe_gating_top_k(
hidden_states: torch.Tensor,
top_k: int,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
scoring_func: str = "softmax",
custom_routing_function: Optional[Callable] = None):
if custom_routing_function is not None:
return False
if scoring_func != "softmax" and scoring_func != "sigmoid":
return False
topk_group = topk_group if topk_group is not None else 1
num_expert_group = num_expert_group if num_expert_group is not None else 1
if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group
== 0 and hidden_states.shape[-1] // num_expert_group > 2):
return False
if topk_group < 1 or topk_group > num_expert_group:
return False
if top_k < 1 or \
top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
return False
if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:
return False
return True


def _native_grouped_topk(
topk_weights: torch.Tensor,
num_expert_group: Optional[int],
@@ -171,34 +202,31 @@ def _select_experts_with_fusion_ops(
e_score_correction_bias: Optional[torch.Tensor],
topk_group: Optional[int],
num_expert_group: Optional[int],
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor=1.0,
global_num_experts: int = -1):

topk_weights, topk_ids = None, None
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
global_redundant_expert_num = get_ascend_config().init_redundancy_expert
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently 8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=
1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
x=router_logits, finished=None, k=top_k)
topk_ids = topk_ids.to(torch.int32)
topk_group = topk_group if topk_group is not None else 1
num_expert_group = num_expert_group if num_expert_group is not None else 1
norm_type = 0 if scoring_func == "softmax" else 1
if e_score_correction_bias is not None and \
e_score_correction_bias.dtype != router_logits.dtype:
e_score_correction_bias = e_score_correction_bias.to(
router_logits.dtype)
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k,
bias=e_score_correction_bias,
k_group=topk_group,
group_count=num_expert_group,
group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=norm_type, # 0: softmax; 1: sigmoid
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
if scoring_func == "softmax":
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)

return topk_weights, topk_ids
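A minimal sketch (the helper name and example values are hypothetical, not part of this diff) of the argument mapping the refactored fused path performs before calling torch_npu.npu_moe_gating_top_k: scoring_func selects norm_type (0 for softmax, 1 for sigmoid), topk_group and num_expert_group default to 1 when unset, and e_score_correction_bias is cast to the router-logit dtype when the dtypes differ.

from typing import Optional, Tuple

import torch


def fused_gating_args(
        scoring_func: str,
        router_logits: torch.Tensor,
        e_score_correction_bias: Optional[torch.Tensor],
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None
) -> Tuple[int, int, int, Optional[torch.Tensor]]:
    """Hypothetical helper mirroring the preprocessing shown in the diff above."""
    topk_group = topk_group if topk_group is not None else 1
    num_expert_group = num_expert_group if num_expert_group is not None else 1
    norm_type = 0 if scoring_func == "softmax" else 1
    bias = e_score_correction_bias
    if bias is not None and bias.dtype != router_logits.dtype:
        bias = bias.to(router_logits.dtype)
    return norm_type, topk_group, num_expert_group, bias


# Assumed DeepSeek-V3/R1-style routing shapes: 256 experts, 8 groups, top-4 groups.
logits = torch.randn(4, 256, dtype=torch.bfloat16)
bias = torch.zeros(256, dtype=torch.float32)
norm_type, k_group, group_count, cast_bias = fused_gating_args(
    "sigmoid", logits, bias, topk_group=4, num_expert_group=8)
print(norm_type, k_group, group_count, cast_bias.dtype)  # 1 4 8 torch.bfloat16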