diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index ec3751b2ce6..2eae492353e 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -116,7 +116,8 @@ def dummy_run(self,
                   num_reqs: int = 0,
                   num_tokens_across_dp: Optional[torch.Tensor] = None,
                   aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-                  batch_descriptor=None):
+                  batch_descriptor=None,
+                  dummy_compute_logits=lambda hidden_states: None):
         moe_comm_type = self.runner._select_moe_comm_method(
             num_tokens, with_prefill)
         with set_ascend_forward_context(None,
@@ -128,6 +129,7 @@ def dummy_run(self,
                 positions=self.positions[:num_tokens],
                 hidden_states=self.hidden_states[:num_tokens],
             )
+            dummy_compute_logits(self.hidden_states)
 
     def generate_token_ids(self,
                            valid_sampled_token_ids: list[list[int]],
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index d8b25e8cd03..168daae6dec 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -114,7 +114,8 @@ def dummy_run(self,
                   num_reqs: int = 0,
                   num_tokens_across_dp=None,
                   aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-                  batch_descriptor=None) -> None:
+                  batch_descriptor=None,
+                  dummy_compute_logits=lambda hidden_states: None) -> None:
         if not self.torchair_graph_enabled:
             # TODO: adapt enable_dbo later
             (num_tokens, num_tokens_across_dp, with_prefill,
@@ -188,6 +189,7 @@ def dummy_run(self,
                 self.model(input_ids=input_ids,
                            positions=positions,
                            hidden_states=previous_hidden_states)
+                dummy_compute_logits(previous_hidden_states)
             if with_prefill:
                 break
 
@@ -490,6 +492,7 @@ def _propose(
         logits = self.model.compute_logits(sample_hidden_states)
         if lmhead_tp_enable() and num_indices < logits.shape[0]:
             logits = logits[:num_indices]
+            last_token_indices = last_token_indices[:num_indices]
         draft_token_ids = logits.argmax(dim=-1)
 
         if self.num_speculative_tokens == 1:
@@ -554,7 +557,7 @@ def _propose(
             # For the requests that exceed the max model length, we set the
             # sequence length to 1 to minimize their overheads in attention.
             exceeds_max_model_len_cpu = exceeds_max_model_len.to(
-                attn_metadata_i.seq_lens.device, non_blocking=True)
+                attn_metadata_i.seq_lens.device, non_blocking=False)
             attn_metadata_i.seq_lens[:batch_size].masked_fill_(
                 exceeds_max_model_len_cpu, 1)
             # Mask out the slot mappings that exceed the max model length.
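The proposer-side hunks in both files apply the same callback-injection pattern: dummy_run grows a dummy_compute_logits parameter defaulting to a no-op lambda, so the lmhead-tensor-parallel dummy logits pass costs nothing unless the runner wires in a real callback, and the call happens inside the proposer's forward context. A minimal self-contained sketch of that pattern follows; ToyProposer and its shapes are hypothetical stand-ins, not the eagle/MTP proposer classes.

# Sketch of the no-op-default callback pattern; ToyProposer is an
# illustrative stand-in, not a vllm-ascend class.
from typing import Callable, Optional

import torch


class ToyProposer:

    def __init__(self, max_tokens: int = 16, hidden_size: int = 8) -> None:
        self.hidden_states = torch.zeros(max_tokens, hidden_size)

    def dummy_run(
        self,
        num_tokens: int,
        dummy_compute_logits: Callable[
            [torch.Tensor],
            Optional[torch.Tensor]] = lambda hidden_states: None,
    ) -> None:
        # ... the dummy forward pass of the draft model would go here ...
        # The default is a no-op, so this call is unconditional: only the
        # caller decides whether a real logits pass happens.
        dummy_compute_logits(self.hidden_states[:num_tokens])


proposer = ToyProposer()
proposer.dummy_run(4)  # default no-op: no logits computed
proposer.dummy_run(
    4, dummy_compute_logits=lambda h: h @ torch.randn(h.shape[-1], 32))

Defaulting to a no-op lambda rather than None keeps every call site a single branch-free line, which is the shape both proposers adopt in the hunks above.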
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 5ad4340663c..74df4e7242e 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2466,13 +2466,21 @@ def _dummy_run(
         need_dummy_logits = (not self.in_profile_run
                              and lmhead_tp_enable())
 
-        if need_dummy_logits:
-            max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
-            dummy_indices = torch.zeros(max_num_reqs_across_dp,
-                                        dtype=torch.int32)
-
-            def dummy_compute_logits(hidden_states):
-                return self.model.compute_logits(
+        max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
+        dummy_indices = torch.zeros(max_num_reqs_across_dp,
+                                    dtype=torch.int32)
+
+        def dummy_compute_logits(hidden_states):
+            if not need_dummy_logits:
+                return None
+            return self.model.compute_logits(hidden_states[dummy_indices])
+
+        def dummy_drafter_compute_logits(hidden_states):
+            if not need_dummy_logits or self.drafter is None:
+                return
+            if hasattr(self.drafter, "model") and hasattr(
+                    self.drafter.model, "compute_logits"):
+                return self.drafter.model.compute_logits(
                     hidden_states[dummy_indices])
 
         with set_ascend_forward_context(
@@ -2494,8 +2502,7 @@ def dummy_compute_logits(hidden_states):
                 with_prefill, is_torchair_compile, input_ids, positions,
                 attn_metadata, num_tokens, intermediate_tensors,
                 inputs_embeds)
-            if need_dummy_logits:
-                dummy_compute_logits(hidden_states)
+            dummy_compute_logits(hidden_states)
 
         if self.drafter:
             self.drafter.dummy_run(
@@ -2505,10 +2512,8 @@ def dummy_compute_logits(hidden_states):
                 num_reqs=num_reqs,
                 num_tokens_across_dp=num_tokens_across_dp,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
-                batch_descriptor=batch_descriptor)
-            if need_dummy_logits:
-                self.drafter.model.compute_logits(
-                    hidden_states[dummy_indices])
+                batch_descriptor=batch_descriptor,
+                dummy_compute_logits=dummy_drafter_compute_logits)
         if self.in_profile_run and self.dynamic_eplb:
             self.model.clear_all_moe_loads()
         if not self.in_profile_run and self.dynamic_eplb:
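The runner side completes the pattern: the need_dummy_logits guard moves inside the closures so both call sites become unconditional, and the drafter variant duck-types with hasattr before touching drafter.model.compute_logits. A rough standalone sketch of those two closures; ToyLMHead and make_dummy_logits_fns are hypothetical names, not the vllm-ascend API.

# Standalone sketch of the guard-inside-the-closure refactor; ToyLMHead
# and make_dummy_logits_fns are illustrative, not vllm-ascend APIs.
import torch


class ToyLMHead:

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states @ torch.randn(hidden_states.shape[-1], 32)


def make_dummy_logits_fns(need_dummy_logits: bool, model, drafter,
                          dummy_indices: torch.Tensor):

    def dummy_compute_logits(hidden_states):
        # The guard lives here, not at the call site.
        if not need_dummy_logits:
            return None
        return model.compute_logits(hidden_states[dummy_indices])

    def dummy_drafter_compute_logits(hidden_states):
        if not need_dummy_logits or drafter is None:
            return None
        # Duck-typing: only call through if this drafter actually exposes
        # a model with compute_logits.
        if hasattr(drafter, "model") and hasattr(drafter.model,
                                                 "compute_logits"):
            return drafter.model.compute_logits(hidden_states[dummy_indices])
        return None

    return dummy_compute_logits, dummy_drafter_compute_logits


main_fn, drafter_fn = make_dummy_logits_fns(
    need_dummy_logits=True,
    model=ToyLMHead(),
    drafter=None,  # e.g. no speculative decoding configured
    dummy_indices=torch.zeros(4, dtype=torch.long),
)
hidden = torch.randn(16, 8)
assert main_fn(hidden).shape == (4, 32)
assert drafter_fn(hidden) is None  # safe no-op without a drafter

Passing dummy_drafter_compute_logits into drafter.dummy_run, rather than calling drafter.model.compute_logits from the runner afterwards, also moves the drafter's dummy logits pass inside its set_ascend_forward_context block, as the proposer-side hunks above show.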