vllm_ascend/spec_decode/eagle_proposer.py (4 changes: 3 additions & 1 deletion)

@@ -116,7 +116,8 @@ def dummy_run(self,
                   num_reqs: int = 0,
                   num_tokens_across_dp: Optional[torch.Tensor] = None,
                   aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-                  batch_descriptor=None):
+                  batch_descriptor=None,
+                  dummy_compute_logits=lambda hidden_states: None):
         moe_comm_type = self.runner._select_moe_comm_method(
             num_tokens, with_prefill)
         with set_ascend_forward_context(None,
@@ -128,6 +129,7 @@
                 positions=self.positions[:num_tokens],
                 hidden_states=self.hidden_states[:num_tokens],
             )
+            dummy_compute_logits(self.hidden_states)

     def generate_token_ids(self,
                            valid_sampled_token_ids: list[list[int]],
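The new dummy_compute_logits parameter defaults to a no-op lambda, so existing callers of dummy_run are unaffected; the runner injects a real logits computation only when need_dummy_logits is set (see model_runner_v1.py below, where it is derived from lmhead_tp_enable()). A minimal sketch of this default-callback pattern, with illustrative names and shapes that are not taken from the PR:

    from typing import Callable

    import torch

    def dummy_run(hidden_states: torch.Tensor,
                  dummy_compute_logits: Callable[[torch.Tensor], None] =
                  lambda hidden_states: None) -> None:
        # Stand-in for the real dummy forward pass; the callback is always
        # invoked, and the default lambda makes it a no-op.
        dummy_compute_logits(hidden_states)

    # Existing callers need no change:
    dummy_run(torch.zeros(4, 8))

    # A caller that must exercise the LM head injects a real callback:
    dummy_run(torch.zeros(4, 8),
              dummy_compute_logits=lambda h: print("computed logits for", h.shape))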
vllm_ascend/spec_decode/mtp_proposer.py (7 changes: 5 additions & 2 deletions)

@@ -114,7 +114,8 @@ def dummy_run(self,
                   num_reqs: int = 0,
                   num_tokens_across_dp=None,
                   aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-                  batch_descriptor=None) -> None:
+                  batch_descriptor=None,
+                  dummy_compute_logits=lambda hidden_states: None) -> None:
         if not self.torchair_graph_enabled:
             # TODO: adapt enable_dbo later
             (num_tokens, num_tokens_across_dp, with_prefill,
@@ -188,6 +189,7 @@ def dummy_run(self,
                 self.model(input_ids=input_ids,
                            positions=positions,
                            hidden_states=previous_hidden_states)
+                dummy_compute_logits(previous_hidden_states)
                 if with_prefill:
                     break

@@ -490,6 +492,7 @@ def _propose(
         logits = self.model.compute_logits(sample_hidden_states)
         if lmhead_tp_enable() and num_indices < logits.shape[0]:
             logits = logits[:num_indices]
+            last_token_indices = last_token_indices[:num_indices]
         draft_token_ids = logits.argmax(dim=-1)

         if self.num_speculative_tokens == 1:
@@ -554,7 +557,7 @@ def _propose(
             # For the requests that exceed the max model length, we set the
             # sequence length to 1 to minimize their overheads in attention.
             exceeds_max_model_len_cpu = exceeds_max_model_len.to(
-                attn_metadata_i.seq_lens.device, non_blocking=True)
+                attn_metadata_i.seq_lens.device, non_blocking=False)
             attn_metadata_i.seq_lens[:batch_size].masked_fill_(
                 exceeds_max_model_len_cpu, 1)
             # Mask out the slot mappings that exceed the max model length.
vllm_ascend/worker/model_runner_v1.py (31 changes: 18 additions & 13 deletions)

@@ -2466,13 +2466,21 @@ def _dummy_run(
         need_dummy_logits = (not self.in_profile_run
                              and lmhead_tp_enable())

-        if need_dummy_logits:
-            max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
-            dummy_indices = torch.zeros(max_num_reqs_across_dp,
-                                        dtype=torch.int32)
-
-            def dummy_compute_logits(hidden_states):
-                return self.model.compute_logits(
-                    hidden_states[dummy_indices])
+        max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
+        dummy_indices = torch.zeros(max_num_reqs_across_dp,
+                                    dtype=torch.int32)
+
+        def dummy_compute_logits(hidden_states):
+            if not need_dummy_logits:
+                return None
+            return self.model.compute_logits(hidden_states[dummy_indices])
+
+        def dummy_drafter_compute_logits(hidden_states):
+            if not need_dummy_logits or self.drafter is None:
+                return
+            if hasattr(self.drafter, "model") and hasattr(
+                    self.drafter.model, "compute_logits"):
+                return self.drafter.model.compute_logits(
+                    hidden_states[dummy_indices])

         with set_ascend_forward_context(

Reviewer comment (Contributor) on lines +2470 to +2471, severity high:

The dummy_indices tensor is created on the CPU by default and is then implicitly transferred to the NPU when it is used to index the hidden_states tensor. This host-to-device copy occurs in each dummy run, which can be inefficient. To improve performance, dummy_indices should be created directly on the NPU.

Suggested change:
-        dummy_indices = torch.zeros(max_num_reqs_across_dp,
-                                    dtype=torch.int32)
+        dummy_indices = torch.zeros(max_num_reqs_across_dp,
+                                    dtype=torch.int32, device=self.device)
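To make the reviewer's point concrete, here is a minimal sketch of the implicit copy, assuming a PyTorch build with the Ascend torch_npu backend (which registers the "npu" device); the shapes and variable names are illustrative, not taken from the PR:

    # Assumption: torch_npu is installed and an NPU device is visible.
    import torch
    import torch_npu  # noqa: F401  (registers the "npu" device type)

    hidden_states = torch.randn(1024, 4096, device="npu")  # illustrative shape

    idx_cpu = torch.zeros(8, dtype=torch.int32)                # host tensor
    idx_npu = torch.zeros(8, dtype=torch.int32, device="npu")  # device tensor

    # Per the reviewer, indexing a device tensor with host-resident indices
    # triggers an implicit host-to-device copy of the indices on every call:
    out = hidden_states[idx_cpu]

    # With the indices created on the device, as the suggested change does,
    # no per-call transfer of the indices is needed:
    out = hidden_states[idx_npu]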
@@ -2494,8 +2502,7 @@ def dummy_compute_logits(hidden_states):
                 with_prefill, is_torchair_compile, input_ids, positions,
                 attn_metadata, num_tokens, intermediate_tensors,
                 inputs_embeds)
-            if need_dummy_logits:
-                dummy_compute_logits(hidden_states)
+            dummy_compute_logits(hidden_states)

             if self.drafter:
                 self.drafter.dummy_run(
@@ -2505,10 +2512,8 @@ def dummy_compute_logits(hidden_states):
                     num_reqs=num_reqs,
                     num_tokens_across_dp=num_tokens_across_dp,
                     aclgraph_runtime_mode=aclgraph_runtime_mode,
-                    batch_descriptor=batch_descriptor)
-                if need_dummy_logits:
-                    self.drafter.model.compute_logits(
-                        hidden_states[dummy_indices])
+                    batch_descriptor=batch_descriptor,
+                    dummy_compute_logits=dummy_drafter_compute_logits)
             if self.in_profile_run and self.dynamic_eplb:
                 self.model.clear_all_moe_loads()
             if not self.in_profile_run and self.dynamic_eplb: