[refactor]support gatingtopk operator generalization #4356
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request adds support for a fused gatingtopk MoE operator. The changes introduce a check function, _check_npu_moe_gating_top_k, to determine whether the fused operator can be used, and refactor the expert-selection logic to use this check. The tests have been updated to match.
My review identified a critical bug in _check_npu_moe_gating_top_k (and in its duplicated copy in the test file) that can lead to a ZeroDivisionError: the checks run in the wrong order, using num_expert_group as a divisor before validating it. Please see the detailed comments for the fix; a sketch of the new dispatch pattern follows for context.
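A minimal sketch of that dispatch pattern, assuming hypothetical callable parameters for the fused and native paths (only the check function itself is named in the PR):

```python
# Sketch of the refactored expert-selection dispatch described above.
# check_fn stands in for _check_npu_moe_gating_top_k; fused_fn and
# native_fn are hypothetical stand-ins for the fused NPU operator path
# and the existing native grouped top-k fallback.
def select_experts(hidden_states, top_k, num_expert_group, topk_group,
                   check_fn, fused_fn, native_fn):
    # Use the fused operator only when the configuration passes the
    # compatibility check; otherwise fall back to the native path.
    if check_fn(hidden_states, top_k, num_expert_group, topk_group):
        return fused_fn(hidden_states, top_k, num_expert_group, topk_group)
    return native_fn(hidden_states, top_k, num_expert_group, topk_group)
```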
```python
if top_k < 1 or \
        top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
    return False
if topk_group < 1 or topk_group > num_expert_group:
    return False
if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:
    return False
if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group
        == 0 and hidden_states.shape[-1] // num_expert_group > 2):
    return False
```
This function has a potential ZeroDivisionError. The variable num_expert_group is used as a divisor on lines 100 and 104, but it is only checked to be greater than 0 on line 106. If num_expert_group is 0, this will cause a crash. To fix this, the check for num_expert_group > 0 should be performed before any division by it. It would also be better to import and use the _check_npu_moe_gating_top_k function from vllm_ascend.ops.moe.experts_selector to avoid code duplication and ensure consistency.
Suggested change (move the num_expert_group guard ahead of the divisions):
```python
if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group
        == 0 and hidden_states.shape[-1] // num_expert_group > 2):
    return False
if top_k < 1 or \
        top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
    return False
if topk_group < 1 or topk_group > num_expert_group:
    return False
if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:
    return False
```
```python
if top_k < 1 or \
        top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
    return False
if topk_group < 1 or topk_group > num_expert_group:
    return False
if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:
    return False
if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group
        == 0 and hidden_states.shape[-1] // num_expert_group > 2):
    return False
```
There is a potential ZeroDivisionError in this function. The variable num_expert_group is used as a divisor on lines 115 and 119 before it is checked to be greater than 0 on line 121. If num_expert_group is 0, the code will crash. The check for num_expert_group > 0 should be moved before it is used in any division operations to prevent this.
Suggested change (move the num_expert_group guard ahead of the divisions):
```python
if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group
        == 0 and hidden_states.shape[-1] // num_expert_group > 2):
    return False
if top_k < 1 or \
        top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
    return False
if topk_group < 1 or topk_group > num_expert_group:
    return False
if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:
    return False
```
…lm-project#4050)

### What this PR does / why we need it?
pick from: vllm-project#2958

Past: npu_moe_gating_top_k could only support the 'group_count=256' pattern
Now:
1. npu_moe_gating_top_k supports all sizes of group_count
2. the functionality of `torch_npu.npu_moe_gating_top_k_softmax` is included in `torch_npu.npu_moe_gating_top_k`

CANN: depends on 8.3.RC1

Performance:
1. GLM4.5-w8a8, TPS improves 6%
2. Qwen3, the same as before

Signed-off-by: 1092626063 <1092626063@qq.com>
What this PR does / why we need it?
This PR is cherry-picked from #2958 and #4340.
Past:
npu_moe_gating_top_k could only support the 'group_count=256' pattern
Now:
1. npu_moe_gating_top_k supports all sizes of group_count
2. the functionality of `torch_npu.npu_moe_gating_top_k_softmax` is included in `torch_npu.npu_moe_gating_top_k`

CANN: depends on 8.3.RC1
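For context, a pure-torch sketch of the group-limited top-k semantics the fused kernel covers (illustrative only; the real operator's exact signature, renormalization, and scoring options are not reproduced here):

```python
import torch

# Illustrative reference for group-limited top-k gating: pick the best
# topk_group expert groups per token, then the top_k experts within them.
def grouped_topk_reference(logits: torch.Tensor, top_k: int,
                           num_expert_group: int, topk_group: int):
    # logits: (num_tokens, num_experts); num_experts must be divisible by
    # num_expert_group, which the check function also enforces.
    num_tokens, num_experts = logits.shape
    scores = logits.softmax(dim=-1)
    # Score each group by its best expert; keep the topk_group best groups.
    group_scores = scores.view(num_tokens, num_expert_group,
                               -1).max(dim=-1).values
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    # Expand the group mask back to per-expert granularity.
    score_mask = group_mask.unsqueeze(-1).expand(
        num_tokens, num_expert_group,
        num_experts // num_expert_group).reshape(num_tokens, num_experts)
    # Top-k experts restricted to the selected groups.
    masked_scores = scores.masked_fill(score_mask == 0, 0.0)
    topk_weights, topk_ids = masked_scores.topk(top_k, dim=-1)
    return topk_weights, topk_ids
```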
Performance:
1. GLM4.5-w8a8, TPS improves 6%
2. Qwen3, the same as before
Does this PR introduce any user-facing change?
How was this patch tested?