[cherry-pick pr-4254] bugfix for mtp>1 when lm_head_tp>1 #4360
base: v0.11.0-dev
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to the Contributing and Testing guides.
Code Review
This pull request addresses a bug in speculative decoding with multi-token prediction (MTP) and LM-head tensor parallelism (lm_head_tp > 1), where an incorrect number of compute_logits calls in the dummy run caused model execution to hang. The fix introduces a dummy_compute_logits function that is passed down to the proposers' dummy_run methods and called within the speculative loop, so the dummy run matches the actual execution flow. The changes also include related correctness fixes for tensor slicing and device synchronization. My review suggests a performance optimization: avoid an unnecessary host-to-device copy during the dummy run by creating a tensor directly on the target device.
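For illustration, here is a minimal, self-contained sketch of the pattern the fix introduces. `TinyModel`, `dummy_run`'s signature, and the loop body are hypothetical stand-ins, not the actual vLLM Ascend code:

```python
import torch


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in for the target model; not vLLM Ascend code."""

    def __init__(self, hidden: int = 8, vocab: int = 16):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden, vocab)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # In the real model this call enters the lm_head tensor-parallel
        # group, so its call count must match between the dummy run and
        # execute_model, or the ranks desynchronize.
        return self.lm_head(hidden_states)


def dummy_run(model: TinyModel, hidden_states: torch.Tensor,
              num_speculative_tokens: int) -> None:
    def dummy_compute_logits(hs: torch.Tensor) -> torch.Tensor:
        return model.compute_logits(hs)

    # One compute_logits per speculative step, mirroring the real
    # speculative loop instead of calling it only once.
    for _ in range(num_speculative_tokens):
        dummy_compute_logits(hidden_states)


dummy_run(TinyModel(), torch.randn(4, 8), num_speculative_tokens=3)
```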
```python
dummy_indices = torch.zeros(max_num_reqs_across_dp,
                            dtype=torch.int32)
```
The dummy_indices tensor is created on the CPU by default and is then implicitly transferred to the NPU when used for indexing the hidden_states tensor. This host-to-device copy occurs in each dummy run, which can be inefficient. To improve performance, dummy_indices should be created directly on the NPU.
Suggested change:
```diff
-dummy_indices = torch.zeros(max_num_reqs_across_dp,
-                            dtype=torch.int32)
+dummy_indices = torch.zeros(max_num_reqs_across_dp,
+                            dtype=torch.int32, device=self.device)
```
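As a standalone illustration of the copy the suggestion removes (assuming a machine with the torch_npu backend installed, where "npu" is a valid device string):

```python
import torch
import torch_npu  # noqa: F401  # assumed available; registers the "npu" device

hidden_states = torch.randn(64, 128, device="npu")

# CPU-resident indices: indexing an NPU tensor triggers an implicit
# host-to-device transfer on every dummy run.
idx_cpu = torch.zeros(8, dtype=torch.int32)
_ = hidden_states[idx_cpu]

# Device-resident indices: the per-run copy disappears.
idx_npu = torch.zeros(8, dtype=torch.int32, device="npu")
_ = hidden_states[idx_npu]
```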
This pull request has conflicts; please resolve them before we can evaluate the pull request.
Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
Force-pushed from 837d924 to dab8df5
What this PR does / why we need it?
Previously, the dummy run executed compute_logits only once, regardless of num_speculative_tokens. This caused execute_model to hang on compute_logits when LM-head tensor parallelism (lm_head_tp) exceeded 1. The fix ensures that compute_logits executes the correct number of times during the dummy run, matching num_speculative_tokens.
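To see why mismatched call counts hang rather than merely waste work: with lm_head_tp > 1, each compute_logits involves a collective across the LM-head TP group, and every rank must reach it the same number of times. The sketch below assumes an already-initialized process group, and the all-gather layout is illustrative, not the exact vLLM Ascend implementation:

```python
import torch
import torch.distributed as dist


def compute_logits_tp(hidden_states: torch.Tensor,
                      lm_head_shard: torch.Tensor,
                      group) -> torch.Tensor:
    # Each rank produces its vocab shard, then joins an all-gather.
    # If the dummy run issues this collective once while execute_model
    # issues it once per speculative step, the extra calls on the other
    # ranks block forever waiting for peers that never arrive.
    local_logits = hidden_states @ lm_head_shard
    gathered = [torch.empty_like(local_logits)
                for _ in range(dist.get_world_size(group=group))]
    dist.all_gather(gathered, local_logits, group=group)
    return torch.cat(gathered, dim=-1)
```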
Does this PR introduce any user-facing change?
No.
How was this patch tested?