@@ -166,10 +166,12 @@ class ChunkedContextMetadata:
166166 actual_chunk_seq_lengths : list [int ]
167167 actual_seq_lengths_kv : list [int ]
168168 starts : torch .Tensor
169+ chunk_seq_mask_filtered_indices : torch .Tensor
169170 chunked_req_mask : Optional [list [bool ]] = None
170171 local_context_lens_allranks : Optional [list [list [int ]]] = None
171172 cp_kv_recover_idx_for_chunk : Optional [list [int ]] = None
172173 kv_inverse_idx_for_chunk : Optional [list [int ]] = None
174+ batch_chunk_seq_mask : Optional [list [bool ]] = None
173175
174176 """ Prefill Specific Metadata for Ascend"""
175177 pcp_metadata : Optional [AscendPCPMetadata ] = None
@@ -401,6 +403,14 @@ def build(
401403 cp_kv_recover_idx_for_chunk .to (torch .float32 )
402404 ) if cp_kv_recover_idx_for_chunk is not None else None
403405
406+ batch_chunk_seq_mask = (
407+ local_context_lens_allranks [:, self .pcp_rank ,
408+ self .dcp_rank ] == 0 )
409+ batch_chunk_seq_mask = torch .repeat_interleave (
410+ batch_chunk_seq_mask ,
411+ repeats = (query_lens * self .pcp_size ).to (self .device ))
412+ chunk_seq_mask_filtered_indices = filter_chunked_req_indices (
413+ query_lens , chunked_req_mask ).to (self .device )
404414 chunked_context_metadata = \
405415 AscendMetadataForPrefill .ChunkedContextMetadata (
406416 actual_chunk_seq_lengths = torch .cumsum (query_lens * pcp_size , dim = 0 ),
@@ -409,7 +419,9 @@ def build(
409419 starts = local_chunk_starts ,
410420 local_context_lens_allranks = local_context_lens_allranks ,
411421 cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk ,
412- kv_inverse_idx_for_chunk = kv_inverse_idx_for_chunk
422+ kv_inverse_idx_for_chunk = kv_inverse_idx_for_chunk ,
423+ batch_chunk_seq_mask = batch_chunk_seq_mask ,
424+ chunk_seq_mask_filtered_indices = chunk_seq_mask_filtered_indices
413425 )
414426 attn_mask_seqlens = common_long_seq_metadata .attn_mask_seqlens
415427 head_attn_nomask_seqlens = common_long_seq_metadata .head_attn_nomask_seqlens
@@ -571,10 +583,15 @@ def full_graph_attention(self,
571583 query : torch .Tensor ,
572584 key : torch .Tensor ,
573585 value : torch .Tensor ,
586+ kv_cache : Tuple [torch .Tensor ],
574587 attn_metadata : AscendMetadata ,
575588 output : torch .Tensor ,
576589 num_tokens = 0 ):
577- if attn_metadata .attn_state == AscendAttentionState .PrefillNoCache :
590+ if self .pcp_size * self .dcp_size > 1 :
591+ intermediate_output = self ._forward_pcp_dcp (
592+ query , key , value , kv_cache , attn_metadata , output )
593+ return intermediate_output , query .shape [0 ]
594+ elif attn_metadata .attn_state == AscendAttentionState .PrefillNoCache :
578595 block_size = 128
579596 block_table = None
580597 actual_seq_lengths_kv = attn_metadata .query_start_loc_list
@@ -1276,9 +1293,7 @@ def _update_chunk_attn_out_lse_with_current_attn_out_lse(
12761293 self .pcp_rank * num_tokens :(self .pcp_rank + 1 ) * num_tokens , :, :]
12771294
12781295 assert attn_output_full_chunk .shape == current_attn_output_prefill .shape and attn_lse_full_chunk .shape == current_attn_lse_prefill .shape
1279- seq_len = attn_metadata .query_lens .detach ().clone ()
1280- filtered_indices = filter_chunked_req_indices (
1281- seq_len , attn_metadata .prefill .chunked_context .chunked_req_mask )
1296+ filtered_indices = attn_metadata .prefill .chunked_context .chunk_seq_mask_filtered_indices
12821297
12831298 attn_output_prefill_filtered = current_attn_output_prefill [
12841299 filtered_indices , :, :]
@@ -1322,9 +1337,11 @@ def _compute_prefill_context(self, query: torch.Tensor,
13221337
13231338 local_chunked_kv_lens_rank = local_chunked_kv_lens [:, self .pcp_rank ,
13241339 self .dcp_rank ]
1340+ total_toks = local_chunked_kv_lens_rank .sum ()
13251341
13261342 key , value = self ._load_kv_for_chunk (attn_metadata , kv_cache ,
1327- local_chunked_kv_lens_rank , query )
1343+ local_chunked_kv_lens_rank , query ,
1344+ total_toks )
13281345 if self .dcp_size > 1 :
13291346 num_heads = self .num_heads * self .dcp_size
13301347 else :
@@ -1340,7 +1357,7 @@ def _compute_prefill_context(self, query: torch.Tensor,
13401357 dtype = torch .float32 ,
13411358 device = query .device )
13421359
1343- if not torch . all ( local_chunked_kv_lens_rank == 0 ). item () :
1360+ if total_toks > 0 :
13441361 prefix_chunk_output , prefix_chunk_lse = torch .ops .npu .npu_fused_infer_attention_score (
13451362 query ,
13461363 key ,
@@ -1358,6 +1375,14 @@ def _compute_prefill_context(self, query: torch.Tensor,
13581375 actual_seq_lengths_kv ,
13591376 actual_seq_lengths = attn_metadata .prefill .chunked_context .
13601377 actual_chunk_seq_lengths )
1378+ batch_chunk_seq_mask = attn_metadata .prefill .chunked_context .batch_chunk_seq_mask
1379+ out_mask = batch_chunk_seq_mask [:, None , None ].expand_as (
1380+ prefix_chunk_output )
1381+ prefix_chunk_output = torch .where (out_mask , 0 , prefix_chunk_output )
1382+ lse_mask = batch_chunk_seq_mask [:, None ,
1383+ None ].expand_as (prefix_chunk_lse )
1384+ prefix_chunk_lse = torch .where (lse_mask , - torch .inf ,
1385+ prefix_chunk_lse )
13611386
13621387 prefix_output , prefix_lse = self ._update_chunk_attn_out_lse (
13631388 prefix_chunk_output , prefix_chunk_lse )
@@ -1413,14 +1438,12 @@ def _update_chunk_attn_out_lse(self, prefix_chunk_output,
14131438 return prefix_output , prefix_lse
14141439
14151440 def _load_kv_for_chunk (self , attn_metadata , kv_cache ,
1416- local_chunked_kv_lens_rank , query ):
1441+ local_chunked_kv_lens_rank , query , total_toks ):
14171442 cache_key = kv_cache [0 ]
14181443 cache_value = kv_cache [1 ]
14191444 num_heads = cache_key .size (2 )
14201445 head_size = kv_cache [0 ].size (- 1 )
14211446
1422- total_toks = local_chunked_kv_lens_rank .sum ()
1423-
14241447 key = torch .empty (total_toks ,
14251448 num_heads ,
14261449 head_size ,
@@ -1579,7 +1602,7 @@ def forward(
15791602 query , attn_metadata , output )
15801603 else :
15811604 intermediate_output , num_tokens = self .full_graph_attention (
1582- query , key , value , attn_metadata , output )
1605+ query , key , value , kv_cache , attn_metadata , output )
15831606 output [:num_tokens ] = intermediate_output [:num_tokens ]
15841607
15851608 return output
0 commit comments