Skip to content

Commit 837d924

Browse files
committed
bugfix for mtp>1 when lm_head_tp>1
Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
1 parent c506ba6 commit 837d924

File tree

3 files changed

+26
-15
lines changed

3 files changed

+26
-15
lines changed

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ def dummy_run(self,
116116
num_reqs: int = 0,
117117
num_tokens_across_dp: Optional[torch.Tensor] = None,
118118
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
119-
batch_descriptor=None):
119+
batch_descriptor=None,
120+
dummy_compute_logits=lambda hidden_states: None):
120121
moe_comm_type = self.runner._select_moe_comm_method(
121122
num_tokens, with_prefill)
122123
with set_ascend_forward_context(None,
@@ -128,6 +129,7 @@ def dummy_run(self,
128129
positions=self.positions[:num_tokens],
129130
hidden_states=self.hidden_states[:num_tokens],
130131
)
132+
dummy_compute_logits(self.hidden_states)
131133

132134
def generate_token_ids(self,
133135
valid_sampled_token_ids: list[list[int]],

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ def dummy_run(self,
113113
num_reqs: int = 0,
114114
num_tokens_across_dp=None,
115115
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
116-
batch_descriptor=None) -> None:
116+
batch_descriptor=None,
117+
dummy_compute_logits=lambda hidden_states: None) -> None:
117118
if not self.torchair_graph_enabled:
118119
# TODO: adapt enable_dbo later
119120
(num_tokens, num_tokens_across_dp, with_prefill,
@@ -187,6 +188,7 @@ def dummy_run(self,
187188
self.model(input_ids=input_ids,
188189
positions=positions,
189190
hidden_states=previous_hidden_states)
191+
dummy_compute_logits(previous_hidden_states)
190192
if with_prefill:
191193
break
192194

@@ -489,6 +491,7 @@ def _propose(
489491
logits = self.model.compute_logits(sample_hidden_states)
490492
if lmhead_tp_enable() and num_indices < logits.shape[0]:
491493
logits = logits[:num_indices]
494+
last_token_indices = last_token_indices[:num_indices]
492495
draft_token_ids = logits.argmax(dim=-1)
493496

494497
if self.num_speculative_tokens == 1:
@@ -553,7 +556,7 @@ def _propose(
553556
# For the requests that exceed the max model length, we set the
554557
# sequence length to 1 to minimize their overheads in attention.
555558
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
556-
attn_metadata_i.seq_lens.device, non_blocking=True)
559+
attn_metadata_i.seq_lens.device, non_blocking=False)
557560
attn_metadata_i.seq_lens[:batch_size].masked_fill_(
558561
exceeds_max_model_len_cpu, 1)
559562
# Mask out the slot mappings that exceed the max model length.

vllm_ascend/worker/model_runner_v1.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2475,13 +2475,21 @@ def _dummy_run(
24752475
need_dummy_logits = (not self.in_profile_run
24762476
and lmhead_tp_enable())
24772477

2478-
if need_dummy_logits:
2479-
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
2480-
dummy_indices = torch.zeros(max_num_reqs_across_dp,
2481-
dtype=torch.int32)
2482-
2483-
def dummy_compute_logits(hidden_states):
2484-
return self.model.compute_logits(
2478+
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
2479+
dummy_indices = torch.zeros(max_num_reqs_across_dp,
2480+
dtype=torch.int32)
2481+
2482+
def dummy_compute_logits(hidden_states):
2483+
if not need_dummy_logits:
2484+
return None
2485+
return self.model.compute_logits(hidden_states[dummy_indices])
2486+
2487+
def dummy_drafter_compute_logits(hidden_states):
2488+
if not need_dummy_logits or self.drafter is None:
2489+
return
2490+
if hasattr(self.drafter, "model") and hasattr(
2491+
self.drafter.model, "compute_logits"):
2492+
return self.drafter.model.compute_logits(
24852493
hidden_states[dummy_indices])
24862494

24872495
with set_ascend_forward_context(
@@ -2503,8 +2511,7 @@ def dummy_compute_logits(hidden_states):
25032511
with_prefill, is_torchair_compile, input_ids, positions,
25042512
attn_metadata, num_tokens, intermediate_tensors,
25052513
inputs_embeds)
2506-
if need_dummy_logits:
2507-
dummy_compute_logits(hidden_states)
2514+
dummy_compute_logits(hidden_states)
25082515

25092516
if self.drafter:
25102517
self.drafter.dummy_run(
@@ -2514,9 +2521,8 @@ def dummy_compute_logits(hidden_states):
25142521
num_reqs=num_reqs,
25152522
num_tokens_across_dp=num_tokens_across_dp,
25162523
aclgraph_runtime_mode=aclgraph_runtime_mode,
2517-
batch_descriptor=batch_descriptor)
2518-
if need_dummy_logits:
2519-
dummy_compute_logits(hidden_states)
2524+
batch_descriptor=batch_descriptor,
2525+
dummy_compute_logits=dummy_drafter_compute_logits)
25202526
if self.in_profile_run and self.dynamic_eplb:
25212527
self.model.clear_all_moe_loads()
25222528
if not self.in_profile_run and self.dynamic_eplb:

0 commit comments

Comments
 (0)