
Commit a3e9673

Delphine-Nic authored
[long seq feat] GQA support long-prefill-token-threshold and fix bug (#4209)
### What this PR does / why we need it?

GQA chunked prefill with pcp and dcp now supports long-prefill-token-threshold. The markdown-format results are as below:

| dataset | version | metric | mode | vllm-api-general-chat |
| ----- | ----- | ----- | ----- | ----- |
| gsm8kdataset | - | accuracy | gen | 96.13 |

- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@2918c1b

---------

Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Signed-off-by: Delphine-Nic <t00608739@china.huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <t00608739@china.huawei.com>
1 parent 97daf7f commit a3e9673

File tree

2 files changed: +51 −26 lines
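The description above ties GQA chunked prefill under prefill/decode context parallel (pcp/dcp) to vLLM's long-prefill-token-threshold scheduler option. As background only, here is a minimal offline-inference sketch of turning that option on; the model name and all numbers are placeholders, and whether your vLLM/vllm-ascend version exposes the option through the `LLM(...)` constructor (rather than only the CLI) should be checked against the docs. The pcp/dcp parallel sizes are deliberately left out because their exact knob names are not shown in this diff.

```python
# Hedged sketch (not part of this PR): enabling chunked prefill with a
# long-prefill token threshold through vLLM's offline API. Model name and
# sizes are placeholders; verify the kwarg spelling for your vLLM version.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",   # placeholder model
    enable_chunked_prefill=True,
    max_num_batched_tokens=8192,
    long_prefill_token_threshold=2048,  # prefills longer than this are scheduled as "long"
)

out = llm.generate(["Hello, world"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)
```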

vllm_ascend/attention/attention_v1.py

Lines changed: 34 additions & 11 deletions
@@ -166,10 +166,12 @@ class ChunkedContextMetadata:
         actual_chunk_seq_lengths: list[int]
         actual_seq_lengths_kv: list[int]
         starts: torch.Tensor
+        chunk_seq_mask_filtered_indices: torch.Tensor
         chunked_req_mask: Optional[list[bool]] = None
         local_context_lens_allranks: Optional[list[list[int]]] = None
         cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
         kv_inverse_idx_for_chunk: Optional[list[int]] = None
+        batch_chunk_seq_mask: Optional[list[bool]] = None
 
     """ Prefill Specific Metadata for Ascend"""
     pcp_metadata: Optional[AscendPCPMetadata] = None
@@ -401,6 +403,14 @@ def build(
                 cp_kv_recover_idx_for_chunk.to(torch.float32)
             ) if cp_kv_recover_idx_for_chunk is not None else None
 
+            batch_chunk_seq_mask = (
+                local_context_lens_allranks[:, self.pcp_rank,
+                                            self.dcp_rank] == 0)
+            batch_chunk_seq_mask = torch.repeat_interleave(
+                batch_chunk_seq_mask,
+                repeats=(query_lens * self.pcp_size).to(self.device))
+            chunk_seq_mask_filtered_indices = filter_chunked_req_indices(
+                query_lens, chunked_req_mask).to(self.device)
             chunked_context_metadata = \
                 AscendMetadataForPrefill.ChunkedContextMetadata(
                     actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0),
@@ -409,7 +419,9 @@ def build(
                     starts=local_chunk_starts,
                     local_context_lens_allranks=local_context_lens_allranks,
                     cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
-                    kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk
+                    kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
+                    batch_chunk_seq_mask=batch_chunk_seq_mask,
+                    chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices
                 )
             attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
             head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
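The new build() code turns a per-request condition (this rank's slice of local_context_lens_allranks being zero) into a per-token mask by repeating each request's flag query_len * pcp_size times. A standalone sketch of that expansion, with made-up shapes and values, purely to show the mechanism:

```python
import torch

# Sketch of the repeat_interleave expansion used in build(); numbers are made up.
pcp_size = 2
query_lens = torch.tensor([3, 1, 2])            # query tokens per request
local_context_lens = torch.tensor([0, 5, 0])    # this rank's cached context per request

per_request_mask = local_context_lens == 0      # [True, False, True]
per_token_mask = torch.repeat_interleave(
    per_request_mask,
    repeats=query_lens * pcp_size)              # one flag per (replicated) query token

print(per_token_mask)                           # 6x True, 2x False, 4x True (12 entries)
```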
@@ -571,10 +583,15 @@ def full_graph_attention(self,
                              query: torch.Tensor,
                              key: torch.Tensor,
                              value: torch.Tensor,
+                             kv_cache: Tuple[torch.Tensor],
                              attn_metadata: AscendMetadata,
                              output: torch.Tensor,
                              num_tokens=0):
-        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+        if self.pcp_size * self.dcp_size > 1:
+            intermediate_output = self._forward_pcp_dcp(
+                query, key, value, kv_cache, attn_metadata, output)
+            return intermediate_output, query.shape[0]
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             block_size = 128
             block_table = None
             actual_seq_lengths_kv = attn_metadata.query_start_loc_list
@@ -1276,9 +1293,7 @@ def _update_chunk_attn_out_lse_with_current_attn_out_lse(
             self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :]
 
         assert attn_output_full_chunk.shape == current_attn_output_prefill.shape and attn_lse_full_chunk.shape == current_attn_lse_prefill.shape
-        seq_len = attn_metadata.query_lens.detach().clone()
-        filtered_indices = filter_chunked_req_indices(
-            seq_len, attn_metadata.prefill.chunked_context.chunked_req_mask)
+        filtered_indices = attn_metadata.prefill.chunked_context.chunk_seq_mask_filtered_indices
 
         attn_output_prefill_filtered = current_attn_output_prefill[
             filtered_indices, :, :]
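This hunk swaps a per-call recomputation (cloning query_lens and re-running filter_chunked_req_indices) for the indices precomputed once in the metadata builder. The helper's exact semantics are not shown in this diff; purely as an illustration, a hypothetical stand-in that returns the flat token indices of the chunked requests could look like this:

```python
import torch

def filter_chunked_req_indices_sketch(query_lens: torch.Tensor,
                                      chunked_req_mask: list[bool]) -> torch.Tensor:
    """Hypothetical stand-in (not the real filter_chunked_req_indices):
    return the flat token indices of requests whose mask entry is True."""
    starts = torch.cumsum(query_lens, dim=0) - query_lens   # first token of each request
    picked = [torch.arange(s, s + n)
              for s, n, keep in zip(starts.tolist(), query_lens.tolist(), chunked_req_mask)
              if keep]
    return torch.cat(picked) if picked else torch.empty(0, dtype=torch.long)

# Requests of 2, 3 and 1 tokens; only the first and last were chunked.
print(filter_chunked_req_indices_sketch(torch.tensor([2, 3, 1]), [True, False, True]))
# tensor([0, 1, 5])
```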
@@ -1322,9 +1337,11 @@ def _compute_prefill_context(self, query: torch.Tensor,
 
         local_chunked_kv_lens_rank = local_chunked_kv_lens[:, self.pcp_rank,
                                                            self.dcp_rank]
+        total_toks = local_chunked_kv_lens_rank.sum()
 
         key, value = self._load_kv_for_chunk(attn_metadata, kv_cache,
-                                             local_chunked_kv_lens_rank, query)
+                                             local_chunked_kv_lens_rank, query,
+                                             total_toks)
         if self.dcp_size > 1:
             num_heads = self.num_heads * self.dcp_size
         else:
@@ -1340,7 +1357,7 @@ def _compute_prefill_context(self, query: torch.Tensor,
                                         dtype=torch.float32,
                                         device=query.device)
 
-        if not torch.all(local_chunked_kv_lens_rank == 0).item():
+        if total_toks > 0:
             prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
                 query,
                 key,
@@ -1358,6 +1375,14 @@ def _compute_prefill_context(self, query: torch.Tensor,
                 actual_seq_lengths_kv,
                 actual_seq_lengths=attn_metadata.prefill.chunked_context.
                 actual_chunk_seq_lengths)
+            batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask
+            out_mask = batch_chunk_seq_mask[:, None, None].expand_as(
+                prefix_chunk_output)
+            prefix_chunk_output = torch.where(out_mask, 0, prefix_chunk_output)
+            lse_mask = batch_chunk_seq_mask[:, None,
+                                            None].expand_as(prefix_chunk_lse)
+            prefix_chunk_lse = torch.where(lse_mask, -torch.inf,
+                                           prefix_chunk_lse)
 
             prefix_output, prefix_lse = self._update_chunk_attn_out_lse(
                 prefix_chunk_output, prefix_chunk_lse)
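The added lines zero the chunk attention output and push the chunk LSE to -inf for tokens whose request has no KV on this (pcp, dcp) rank, so the later merge in _update_chunk_attn_out_lse gives them no weight. The merge itself is not shown here; the sketch below uses the standard log-sum-exp combination only to illustrate why an LSE of -inf drops a partial result out of the final sum:

```python
import torch

# Sketch: merging two partial attention results (out, lse) with the standard
# log-sum-exp combination. An lse of -inf makes that partial result contribute
# nothing, which is what the masking above relies on.
def merge(out_a, lse_a, out_b, lse_b):
    lse = torch.logaddexp(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse)
    w_b = torch.exp(lse_b - lse)
    return w_a * out_a + w_b * out_b, lse

out_a, lse_a = torch.tensor([1.0]), torch.tensor([0.5])
out_b, lse_b = torch.tensor([123.0]), torch.tensor([float("-inf")])  # masked-out partial

merged_out, merged_lse = merge(out_a, lse_a, out_b, lse_b)
print(merged_out, merged_lse)   # tensor([1.]) tensor([0.5000]) -- b is ignored
```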
@@ -1413,14 +1438,12 @@ def _update_chunk_attn_out_lse(self, prefix_chunk_output,
         return prefix_output, prefix_lse
 
     def _load_kv_for_chunk(self, attn_metadata, kv_cache,
-                           local_chunked_kv_lens_rank, query):
+                           local_chunked_kv_lens_rank, query, total_toks):
         cache_key = kv_cache[0]
         cache_value = kv_cache[1]
         num_heads = cache_key.size(2)
         head_size = kv_cache[0].size(-1)
 
-        total_toks = local_chunked_kv_lens_rank.sum()
-
         key = torch.empty(total_toks,
                           num_heads,
                           head_size,
@@ -1579,7 +1602,7 @@ def forward(
                 query, attn_metadata, output)
         else:
             intermediate_output, num_tokens = self.full_graph_attention(
-                query, key, value, attn_metadata, output)
+                query, key, value, kv_cache, attn_metadata, output)
         output[:num_tokens] = intermediate_output[:num_tokens]
 
         return output

vllm_ascend/worker/model_runner_v1.py

Lines changed: 17 additions & 15 deletions
@@ -294,21 +294,23 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
         self.block_size = vllm_config.cache_config.block_size
-        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
-                                           self.block_size)
-        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
-        decode_max_num_seqs = getattr(self.scheduler_config,
-                                      'decode_max_num_seqs', 0)
-        self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
-                                decode_max_num_seqs)
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
+        self.dcp_size = get_dcp_group().world_size
+        self.dcp_rank = get_dcp_group().rank_in_group
         self.pcp_size = get_prefill_context_model_parallel_world_size(
         ) if prefill_context_parallel_enable() else 1
         self.pcp_rank = get_prefill_context_model_parallel_rank(
         ) if self.pcp_size > 1 else 0
-        self.dcp_size = get_dcp_group().world_size
-        self.dcp_rank = get_dcp_group().rank_in_group
+        decode_max_num_seqs = getattr(self.scheduler_config,
+                                      'decode_max_num_seqs', 0)
+        self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
+                                decode_max_num_seqs)
+        if self.pcp_size > 1:
+            self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
+        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
+                                           self.block_size)
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         self.device = device
         if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
             self.prefetch_stream = torch.npu.Stream(device=device)
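The reordered __init__ now computes max_num_reqs before the block budget and, when prefill context parallel is active, pads max_model_len by 2 * pcp_size * max_num_reqs, presumably to leave headroom in the per-request block table for the extra per-rank tokens. A quick arithmetic sketch with made-up sizes (cdiv is ceiling division, as in vLLM's utils):

```python
# Made-up numbers, just to show how the padding feeds into the block budget.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)   # ceiling division for positive ints

block_size, max_model_len = 128, 32768
pcp_size, max_num_reqs = 2, 16

padded_len = max_model_len + 2 * pcp_size * max_num_reqs     # 32768 + 64 = 32832
print(cdiv(max_model_len, block_size), cdiv(padded_len, block_size))   # 256 257
```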
@@ -1007,23 +1009,20 @@ def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
 
     def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
+        # pcp situation.
         if self.pcp_size > 1:
             return None
         if self.attn_mask_builder is None:
             raise ValueError("Attn mask builder is None")
+        # dcp situation.
         if self.dcp_size > 1:
             return self.attn_mask_builder.get_splitfuse_attn_mask()
         # Pooling situation.
         if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
             return self.attn_mask_builder.get_pooling_mask(self.device)
         # Chunk Prefill situation.
         elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
-            if self.dcp_size > 1:
-                max_seq_len = max(seq_lens.max().item(), 0)
-                return self.attn_mask_builder.get_attn_mask(
-                    max_seq_len, self.dtype, self.device)
-            else:
-                return self.attn_mask_builder.get_splitfuse_attn_mask()
+            return self.attn_mask_builder.get_splitfuse_attn_mask()
 
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
@@ -1039,6 +1038,9 @@ def _make_attention_mask(self, seq_lens, position,
             return None
 
     def _make_fia_attention_mask(self) -> torch.Tensor:
+        # pcp situation.
+        if self.pcp_size > 1:
+            return None
         if self.attn_mask_builder is None:
             raise ValueError("Attn mask builder is None")
         return self.attn_mask_builder.get_splitfuse_attn_mask()
