From 8126fa1004b9538906554b2257558ac025581577 Mon Sep 17 00:00:00 2001
From: Jan van Lunteren
Date: Mon, 3 Nov 2025 11:28:56 -0500
Subject: [PATCH 01/10] remove prefill support from 3d kernel

Signed-off-by: Jan van Lunteren
---
 .../attention/ops/triton_unified_attention.py | 87 ++++---------------
 1 file changed, 19 insertions(+), 68 deletions(-)

diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index 565be1c39bec..dfdcf82745ce 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -36,14 +36,13 @@ def find_seq_idx(
     target_idx,
     num_seqs,
     BLOCK_Q: tl.constexpr,
-    use_q_block_mode: tl.constexpr,
 ):
     left: tl.int32 = 0
     right = num_seqs
     while left < right:
         mid = (left + right) // 2
         val = tl.load(query_start_len_ptr + mid)
-        mid_val = val // BLOCK_Q + mid if use_q_block_mode else val
+        mid_val = val // BLOCK_Q + mid

         if mid_val <= target_idx:
             left = mid + 1
@@ -106,7 +105,7 @@ def kernel_unified_attention_2d(
     kv_head_idx = tl.program_id(1)

     seq_idx = find_seq_idx(
-        query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
+        query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q
     )

     q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx
@@ -393,32 +392,13 @@ def kernel_unified_attention_3d(
     stride_v_cache_1: tl.int64,  # int
     stride_v_cache_2: tl.int64,  # int
     stride_v_cache_3: tl.constexpr,  # int
-    query_start_len_ptr,  # [num_seqs+1]
-    BLOCK_Q: tl.constexpr,  # int
-    num_seqs: tl.int32,
     BLOCK_M: tl.constexpr,  # int
     NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
 ):
-    q_block_global_idx = tl.program_id(0)
+    seq_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)
     segm_idx = tl.program_id(2)

-    seq_idx = find_seq_idx(
-        query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
-    )
-
-    q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx
-
-    q_block_local_idx = q_block_global_idx - q_block_start_idx
-
-    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
-    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
-
-    cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
-
-    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
-        return
-
     # sequence len for this particular sequence
     seq_len = tl.load(seq_lens_ptr + seq_idx)
@@ -432,9 +412,9 @@ def kernel_unified_attention_3d(
     offs_m = tl.arange(0, BLOCK_M)
     offs_d = tl.arange(0, HEAD_SIZE_PADDED)
     offs_t = tl.arange(0, TILE_SIZE)
-    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
+    query_pos = offs_m // num_queries_per_kv

-    query_offset_0 = cur_batch_in_all_start_index + query_pos
+    query_offset_0 = seq_idx + query_pos
     query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
     query_offset = (
         query_offset_0[:, None] * query_stride_0
@@ -443,7 +423,7 @@
     )

     dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
-    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
+    query_mask_0 = tl.where(query_pos < 1, 1, 0).to(tl.int1)
     query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

     # Q : (BLOCK_M, HEAD_SIZE_PADDED)
@@ -471,7 +451,7 @@
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

     # context length for this particular sequences
-    context_len = seq_len - cur_batch_query_len
+    context_len = seq_len - 1

     # alibi slope for this head
     if USE_ALIBI_SLOPES:
@@ -485,23 +465,7 @@ def kernel_unified_attention_3d(
         qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
     )  # shape: [BLOCK_M]

-    # compute the length of the longest sequence prefix spanned by any
-    # query token in the current q_block (q_block_local_idx)
-    max_seq_prefix_len = (
-        context_len
-        + q_block_local_idx * BLOCK_Q
-        + (BLOCK_M - 1) // num_queries_per_kv
-        + 1
-    )
-
-    # adjust for potential padding in the last q_block by considering the
-    # actual sequence length
-    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
-
-    # calculate the number of tiles that need to be processed to
-    # cover the longest sequence prefix (due to causal masking, tiles beyond
-    # this prefix can be skipped)
-    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
+    num_tiles = cdiv_fn(seq_len, TILE_SIZE)

     # iterate through tiles within current segment
     for j in range(
@@ -509,7 +473,7 @@ def kernel_unified_attention_3d(
         min((segm_idx + 1) * tiles_per_segment, num_tiles),
     ):
         seq_offset = j * TILE_SIZE + offs_t
-        tile_mask = seq_offset < max_seq_prefix_len
+        tile_mask = seq_offset < seq_len

         physical_block_idx = tl.load(
             block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
@@ -650,7 +614,6 @@ def reduce_segments(
     segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
     segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
     seq_lens_ptr,  # [num_seqs]
-    num_seqs,  # int
     num_query_heads: tl.constexpr,  # int
     out_scale_inv,  # float32
     output_stride_0: tl.int64,  # int
@@ -659,20 +622,14 @@ def reduce_segments(
     TILE_SIZE: tl.constexpr,  # int
     HEAD_SIZE: tl.constexpr,  # int, must be power of 2
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
-    query_start_len_ptr,  # [num_seqs+1]
-    BLOCK_Q: tl.constexpr,  # int
     NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
     USE_FP8: tl.constexpr,  # bool
     FP8_MIN: tl.constexpr = float8_info.min,
     FP8_MAX: tl.constexpr = float8_info.max,
 ):
-    query_token_idx = tl.program_id(0)
+    seq_idx = tl.program_id(0)
     query_head_idx = tl.program_id(1)

-    seq_idx = find_seq_idx(
-        query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False
-    )
-
     # sequence len for this particular sequence
     seq_len = tl.load(seq_lens_ptr + seq_idx)
@@ -689,7 +646,7 @@ def reduce_segments(
     # load segment maxima
     segm_offset = (
-        query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
+        seq_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
         + query_head_idx * NUM_SEGMENTS_PER_SEQ
         + tl.arange(0, NUM_SEGMENTS_PER_SEQ)
     )
@@ -703,7 +660,7 @@ def reduce_segments(
     # load, rescale, and add segment attention outputs
     segm_output_offset = (
-        query_token_idx.to(tl.int64)
+        seq_idx.to(tl.int64)
         * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
         + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
         + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED
@@ -725,7 +682,7 @@ def reduce_segments(
     # write result
     output_offset = (
-        query_token_idx * output_stride_0
+        seq_idx * output_stride_0
         + query_head_idx * output_stride_1
         + tl.arange(0, HEAD_SIZE_PADDED)
     )
@@ -793,7 +750,7 @@ def unified_attention(
     TILE_SIZE_PREFILL = 32
     TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32

-    # if batch contains a prefill
+    # if batch contains a prefill or launch grid size is larger than threshold
     if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
         kernel_unified_attention_2d[
             (
                 total_num_q_blocks,
@@ -852,7 +809,7 @@ def unified_attention(
         NUM_SEGMENTS = 16

         segm_output = torch.empty(
-            q.shape[0],
+            num_seqs,
             num_query_heads,
             NUM_SEGMENTS,
             triton.next_power_of_2(head_size),
@@ -860,21 +817,21 @@ def unified_attention(
             device=q.device,
         )
         segm_max = torch.empty(
-            q.shape[0],
+            num_seqs,
             num_query_heads,
             NUM_SEGMENTS,
             dtype=torch.float32,
             device=q.device,
         )
         segm_expsum = torch.empty(
-            q.shape[0],
+            num_seqs,
             num_query_heads,
             NUM_SEGMENTS,
             dtype=torch.float32,
             device=q.device,
         )

-        kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
+        kernel_unified_attention_3d[(num_seqs, num_kv_heads, NUM_SEGMENTS)](
             segm_output_ptr=segm_output,
             segm_max_ptr=segm_max,
             segm_expsum_ptr=segm_expsum,
@@ -913,19 +870,15 @@ def unified_attention(
             stride_v_cache_1=v.stride(1),
             stride_v_cache_2=v.stride(2),
             stride_v_cache_3=v.stride(3),
-            query_start_len_ptr=cu_seqlens_q,
-            BLOCK_Q=BLOCK_Q,
-            num_seqs=num_seqs,
             BLOCK_M=BLOCK_M,
             NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
         )
-        reduce_segments[(q.shape[0], num_query_heads)](
+        reduce_segments[(num_seqs, num_query_heads)](
             output_ptr=out,
             segm_output_ptr=segm_output,
             segm_max_ptr=segm_max,
             segm_expsum_ptr=segm_expsum,
             seq_lens_ptr=seqused_k,
-            num_seqs=num_seqs,
             num_query_heads=num_query_heads,
             out_scale_inv=1 / output_scale if output_scale is not None else 1.0,
             output_stride_0=out.stride(0),
@@ -934,8 +887,6 @@ def unified_attention(
             TILE_SIZE=TILE_SIZE_DECODE,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
-            query_start_len_ptr=cu_seqlens_q,
-            BLOCK_Q=BLOCK_Q,
             NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
             USE_FP8=output_scale is not None,
         )

From 4a54c088e23d2738d14a30576a68bbf179b1466b Mon Sep 17 00:00:00 2001
From: Jan van Lunteren
Date: Mon, 3 Nov 2025 12:45:59 -0500
Subject: [PATCH 02/10] formatting

Signed-off-by: Jan van Lunteren
---
 vllm/attention/ops/triton_unified_attention.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index dfdcf82745ce..f4925ee91181 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -104,9 +104,7 @@ def kernel_unified_attention_2d(
     q_block_global_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)

-    seq_idx = find_seq_idx(
-        query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q
-    )
+    seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q)

     q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx

From 84c5cd79d769af1ba365513d81ee2213f6a21cf3 Mon Sep 17 00:00:00 2001
From: Jan van Lunteren
Date: Fri, 7 Nov 2025 12:53:43 +0000
Subject: [PATCH 03/10] adapt 3D kernel for full CUDA Graph support

Signed-off-by: Jan van Lunteren
---
 .../attention/ops/triton_unified_attention.py | 52 +++++------------
 vllm/v1/attention/backends/triton_attn.py     | 57 ++++++++++++++++++-
 2 files changed, 71 insertions(+), 38 deletions(-)

diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index f4925ee91181..099bb35f9ecc 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -704,6 +704,11 @@ def unified_attention(
     q_descale,
     k_descale,
     v_descale,
+    seq_threshold_3D,
+    num_par_softmax_segments,
+    softmax_segm_output,
+    softmax_segm_max,
+    softmax_segm_expsum,
     alibi_slopes=None,
     output_scale=None,
     qq_bias=None,
@@ -749,7 +754,7 @@ def unified_attention(
     TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32

     # if batch contains a prefill or launch grid size is larger than threshold
-    if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
+    if max_seqlen_q > 1 or num_seqs > seq_threshold_3D:
         kernel_unified_attention_2d[
             (
                 total_num_q_blocks,
@@ -802,37 +807,10 @@ def unified_attention(
             USE_FP8=output_scale is not None,
         )
     else:
-        # for initial version, NUM_SEGMENTS = 16 is chosen as a default
-        # value that showed good performance in tests
-        NUM_SEGMENTS = 16
-
-        segm_output = torch.empty(
-            num_seqs,
-            num_query_heads,
-            NUM_SEGMENTS,
-            triton.next_power_of_2(head_size),
-            dtype=torch.float32,
-            device=q.device,
-        )
-        segm_max = torch.empty(
-            num_seqs,
-            num_query_heads,
-            NUM_SEGMENTS,
-            dtype=torch.float32,
-            device=q.device,
-        )
-        segm_expsum = torch.empty(
-            num_seqs,
-            num_query_heads,
-            NUM_SEGMENTS,
-            dtype=torch.float32,
-            device=q.device,
-        )
-
-        kernel_unified_attention_3d[(num_seqs, num_kv_heads, NUM_SEGMENTS)](
-            segm_output_ptr=segm_output,
-            segm_max_ptr=segm_max,
-            segm_expsum_ptr=segm_expsum,
+        kernel_unified_attention_3d[(num_seqs, num_kv_heads, num_par_softmax_segments)](
+            segm_output_ptr=softmax_segm_output,
+            segm_max_ptr=softmax_segm_max,
+            segm_expsum_ptr=softmax_segm_expsum,
             query_ptr=q,
             key_cache_ptr=k,
             value_cache_ptr=v,
@@ -869,13 +847,13 @@ def unified_attention(
             stride_v_cache_2=v.stride(2),
             stride_v_cache_3=v.stride(3),
             BLOCK_M=BLOCK_M,
-            NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
+            NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
         )
         reduce_segments[(num_seqs, num_query_heads)](
             output_ptr=out,
-            segm_output_ptr=segm_output,
-            segm_max_ptr=segm_max,
-            segm_expsum_ptr=segm_expsum,
+            segm_output_ptr=softmax_segm_output,
+            segm_max_ptr=softmax_segm_max,
+            segm_expsum_ptr=softmax_segm_expsum,
             seq_lens_ptr=seqused_k,
             num_query_heads=num_query_heads,
             out_scale_inv=1 / output_scale if output_scale is not None else 1.0,
@@ -885,6 +863,6 @@ def unified_attention(
             TILE_SIZE=TILE_SIZE_DECODE,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
-            NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
+            NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
             USE_FP8=output_scale is not None,
         )
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index b1d34dbfd172..869c03e510d4 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -18,7 +18,7 @@
     triton_reshape_and_cache_flash,
 )
 from vllm.attention.ops.triton_unified_attention import unified_attention
-from vllm.config import VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
@@ -35,6 +35,12 @@
 logger = init_logger(__name__)

+# constants
+MIN_LAUNCH_GRID_SIZE_2D = 128 # Minimum launch grid size of 2D kernel
+NUM_PAR_SOFTMAX_SEGMENTS = 16 # Number of parallel tiled softmax segments
+
+
 @dataclass
 class TritonAttentionMetadata:
     # NOTE(sang): Definition of context_len, query_len, and seq_len.
@@ -53,6 +59,12 @@ class TritonAttentionMetadata:
     block_table: torch.Tensor
     slot_mapping: torch.Tensor

+    seq_threshold_3D: int
+    num_par_softmax_segments: int
+    softmax_segm_output: torch.Tensor
+    softmax_segm_max: torch.Tensor
+    softmax_segm_expsum: torch.Tensor
+
     # For cascade attention.
     use_cascade: bool
     common_prefix_len: int
@@ -86,6 +98,33 @@ def __init__(
         self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
         self.headdim = model_config.get_head_size()

+        # Check if CUDA Graphs are enabled for decode
+        self.decode_cudagraph_enabled = self.vllm_config.compilation_config.cudagraph_mode in (CUDAGraphMode.FULL_AND_PIECEWISE, CUDAGraphMode.FULL_DECODE_ONLY, CUDAGraphMode.FULL)
+
+        # Set initial value for the threshold for the number of sequences used
+        # to select between the 2D and 3D kernels for decode.
+        self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv
+        if self.decode_cudagraph_enabled:
+            # Select the CUDA Graph capture size closest to self.seq_threshold_3D
+            # as threshold. This ensures that each captured graph covers the
+            # correct execution path.
+            upd_seq_threshold_3D = min(self.vllm_config.compilation_config.cudagraph_capture_sizes, key=lambda x: abs(x - self.seq_threshold_3D))
+
+            # If the updated threshold becomes significantly larger than the
+            # initial value, it is reset to zero. This enforces the use of the
+            # 2D kernel only and ensures that the size of the allocated
+            # intermediate structures remains bounded.
+            if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D:
+                self.seq_threshold_3D = upd_seq_threshold_3D
+            else:
+                self.seq_threshold_3D = 0
+
+        self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
+        self.softmax_segm_output = torch.empty((self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments, self.headdim), dtype=torch.float32, device=device)
+        self.softmax_segm_max = torch.empty((self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments), dtype=torch.float32, device=device)
+        self.softmax_segm_expsum = torch.empty((self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments), dtype=torch.float32, device=device)
+
+
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
     ) -> TritonAttentionMetadata:
@@ -142,6 +181,11 @@ def build(
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
+            seq_threshold_3D=self.seq_threshold_3D,
+            num_par_softmax_segments=self.num_par_softmax_segments,
+            softmax_segm_output=self.softmax_segm_output,
+            softmax_segm_max=self.softmax_segm_max,
+            softmax_segm_expsum=self.softmax_segm_expsum,
         )
         return attn_metadata

@@ -346,6 +390,12 @@ def forward(
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table

+        seq_threshold_3D = attn_metadata.seq_threshold_3D
+        num_par_softmax_segments = attn_metadata.num_par_softmax_segments
+        softmax_segm_output = attn_metadata.softmax_segm_output
+        softmax_segm_max = attn_metadata.softmax_segm_max
+        softmax_segm_expsum = attn_metadata.softmax_segm_expsum
+
         descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

         unified_attention(
@@ -366,6 +416,11 @@ def forward(
             q_descale=None,  # Not supported
             k_descale=layer._k_scale.expand(descale_shape),
             v_descale=layer._v_scale.expand(descale_shape),
+            seq_threshold_3D = seq_threshold_3D,
+            num_par_softmax_segments = num_par_softmax_segments,
+            softmax_segm_output = softmax_segm_output,
+            softmax_segm_max = softmax_segm_max,
+            softmax_segm_expsum = softmax_segm_expsum,
             sinks=self.sinks,
             output_scale=output_scale,
         )

From 39f52b413722ade07b870ed2daf1f68c152155eb Mon Sep 17 00:00:00 2001
From: Jan van Lunteren
Date: Fri, 7 Nov 2025 12:59:46 +0000
Subject: [PATCH 04/10] formatting

Signed-off-by: Jan van Lunteren
---
 vllm/v1/attention/backends/triton_attn.py | 55 ++++++++++++++++-------
 1 file changed, 40 insertions(+), 15 deletions(-)

diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 869c03e510d4..68d3e3a3c484 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -36,9 +36,8 @@
 # constants
-MIN_LAUNCH_GRID_SIZE_2D = 128 # Minimum launch grid size of 2D kernel
-NUM_PAR_SOFTMAX_SEGMENTS = 16 # Number of parallel tiled softmax segments
-
+MIN_LAUNCH_GRID_SIZE_2D = 128  # Minimum launch grid size of 2D kernel
+NUM_PAR_SOFTMAX_SEGMENTS = 16  # Number of parallel tiled softmax segments


 @dataclass
@@ -99,7 +98,14 @@ def __init__(
         self.headdim = model_config.get_head_size()

         # Check if CUDA Graphs are enabled for decode
-        self.decode_cudagraph_enabled = self.vllm_config.compilation_config.cudagraph_mode in (CUDAGraphMode.FULL_AND_PIECEWISE, CUDAGraphMode.FULL_DECODE_ONLY, CUDAGraphMode.FULL)
+        self.decode_cudagraph_enabled = (
+            self.vllm_config.compilation_config.cudagraph_mode
+            in (
+                CUDAGraphMode.FULL_AND_PIECEWISE,
+                CUDAGraphMode.FULL_DECODE_ONLY,
+                CUDAGraphMode.FULL,
+            )
+        )

         # Set initial value for the threshold for the number of sequences used
         # to select between the 2D and 3D kernels for decode.
@@ -107,8 +113,11 @@ def __init__(
         if self.decode_cudagraph_enabled:
             # Select the CUDA Graph capture size closest to self.seq_threshold_3D
             # as threshold. This ensures that each captured graph covers the
-            # correct execution path. 
-            upd_seq_threshold_3D = min(self.vllm_config.compilation_config.cudagraph_capture_sizes, key=lambda x: abs(x - self.seq_threshold_3D))
+            # correct execution path.
+            upd_seq_threshold_3D = min(
+                self.vllm_config.compilation_config.cudagraph_capture_sizes,
+                key=lambda x: abs(x - self.seq_threshold_3D),
+            )

             # If the updated threshold becomes significantly larger than the
             # initial value, it is reset to zero. This enforces the use of the
@@ -120,10 +129,26 @@ def __init__(
                 self.seq_threshold_3D = 0

         self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
-        self.softmax_segm_output = torch.empty((self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments, self.headdim), dtype=torch.float32, device=device)
-        self.softmax_segm_max = torch.empty((self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments), dtype=torch.float32, device=device)
-        self.softmax_segm_expsum = torch.empty((self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments), dtype=torch.float32, device=device)
-
+        self.softmax_segm_output = torch.empty(
+            (
+                self.seq_threshold_3D,
+                self.num_heads_q,
+                self.num_par_softmax_segments,
+                self.headdim,
+            ),
+            dtype=torch.float32,
+            device=device,
+        )
+        self.softmax_segm_max = torch.empty(
+            (self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
+            dtype=torch.float32,
+            device=device,
+        )
+        self.softmax_segm_expsum = torch.empty(
+            (self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
+            dtype=torch.float32,
+            device=device,
+        )

     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
@@ -416,11 +441,11 @@ def forward(
             q_descale=None,  # Not supported
             k_descale=layer._k_scale.expand(descale_shape),
             v_descale=layer._v_scale.expand(descale_shape),
-            seq_threshold_3D = seq_threshold_3D,
-            num_par_softmax_segments = num_par_softmax_segments,
-            softmax_segm_output = softmax_segm_output,
-            softmax_segm_max = softmax_segm_max,
-            softmax_segm_expsum = softmax_segm_expsum,
+            seq_threshold_3D=seq_threshold_3D,
+            num_par_softmax_segments=num_par_softmax_segments,
+            softmax_segm_output=softmax_segm_output,
+            softmax_segm_max=softmax_segm_max,
+            softmax_segm_expsum=softmax_segm_expsum,
             sinks=self.sinks,
             output_scale=output_scale,
         )

From 3102959ddb338b8844a27e86b3b1cf1a360c1ebf Mon Sep 17 00:00:00 2001
From: Jan van Lunteren
Date: Fri, 7 Nov 2025 10:57:59 -0500
Subject: [PATCH 05/10] update unit test

Signed-off-by: Jan van Lunteren
---
 .../test_triton_unified_attention.py | 25 +++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py
index bf4d2179af5f..a9575270bee8 100644
--- a/tests/kernels/attention/test_triton_unified_attention.py
+++ b/tests/kernels/attention/test_triton_unified_attention.py
@@ -22,6 +22,10 @@
 # one value small enough to test the schema op check
 NUM_BLOCKS = [32768, 2048]

+# 0: use 2D kernel for decode
+# 8: use 3D kernel for decode
+SEQ_THRESHOLD_3D_VALUES = [0, 8]
+

 def ref_paged_attn(
     query: torch.Tensor,
@@ -92,6 +96,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("soft_cap", [None, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("q_dtype", QDTYPES)
+@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES)
 @torch.inference_mode()
 def test_triton_unified_attn(
     seq_lens: list[tuple[int, int]],
@@ -103,6 +108,7 @@ def test_triton_unified_attn(
     soft_cap: float | None,
     num_blocks: int,
     q_dtype: torch.dtype | None,
+    seq_threshold_3D: int,
 ) -> None:
     torch.set_default_device("cuda")

@@ -152,6 +158,20 @@ def test_triton_unified_attn(
     k_descale = torch.rand(scale_shape, dtype=torch.float32)
     v_descale = torch.rand(scale_shape, dtype=torch.float32)

+    num_par_softmax_segments = 16
+    softmax_segm_output = torch.empty(
+        (seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size),
+        dtype=torch.float32,
+    )
+    softmax_segm_max = torch.empty(
+        (seq_threshold_3D, num_query_heads, num_par_softmax_segments),
+        dtype=torch.float32,
+    )
+    softmax_segm_expsum = torch.empty(
+        (seq_threshold_3D, num_query_heads, num_par_softmax_segments),
+        dtype=torch.float32,
+    )
+
     unified_attention(
         q=maybe_quantized_query,
         k=maybe_quantized_key_cache,
@@ -169,6 +189,11 @@ def test_triton_unified_attn(
         q_descale=q_descale,
         k_descale=k_descale,
         v_descale=v_descale,
+        seq_threshold_3D=seq_threshold_3D,
+        num_par_softmax_segments=num_par_softmax_segments,
+        softmax_segm_output=softmax_segm_output,
+        softmax_segm_max=softmax_segm_max,
+        softmax_segm_expsum=softmax_segm_expsum,
     )

     ref_output = ref_paged_attn(

From 9bcc1fbcfee5f2f903a6fd0424c95337b110e623 Mon Sep 17 00:00:00 2001
From: Jan van Lunteren
Date: Fri, 7 Nov 2025 11:12:49 -0500
Subject: [PATCH 06/10] corrected comment

Signed-off-by: Jan van Lunteren
---
 vllm/attention/ops/triton_unified_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index 099bb35f9ecc..911ae8944f57 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -753,7 +753,7 @@ def unified_attention(
     TILE_SIZE_PREFILL = 32
     TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32

-    # if batch contains a prefill or launch grid size is larger than threshold
+    # if batch contains a prefill or number of sequences is larger than threshold
     if max_seqlen_q > 1 or num_seqs > seq_threshold_3D:
         kernel_unified_attention_2d[
             (

From 53d7b8bd9077853ba3db4142d674e22640213b01 Mon Sep 17 00:00:00 2001
From: Jan van Lunteren
Date: Mon, 10 Nov 2025 03:25:49 -0500
Subject: [PATCH 07/10] added check for empty cudagraph_capture_sizes

Signed-off-by: Jan van Lunteren
---
 vllm/v1/attention/backends/triton_attn.py | 36 +++++++++++++----------
 1 file changed, 21 insertions(+), 15 deletions(-)

diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 9e5369cfbacd..f8717b763760 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -110,22 +110,28 @@ def __init__(
         # to select between the 2D and 3D kernels for decode.
         self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv
         if self.decode_cudagraph_enabled:
-            # Select the CUDA Graph capture size closest to self.seq_threshold_3D
-            # as threshold. This ensures that each captured graph covers the
-            # correct execution path.
-            upd_seq_threshold_3D = min(
-                self.vllm_config.compilation_config.cudagraph_capture_sizes,
-                key=lambda x: abs(x - self.seq_threshold_3D),
-            )
-
-            # If the updated threshold becomes significantly larger than the
-            # initial value, it is reset to zero. This enforces the use of the
-            # 2D kernel only and ensures that the size of the allocated
-            # intermediate structures remains bounded.
-            if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D:
-                self.seq_threshold_3D = upd_seq_threshold_3D
-            else:
+            capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
+            if not capture_sizes:
+                # If no CUDA Graph capture sizes are specified, the threshold
+                # is reset to zero, forcing the 2D kernel to be used.
                 self.seq_threshold_3D = 0
+            else:
+                # Select the CUDA Graph capture size closest to self.seq_threshold_3D
+                # as threshold. This ensures that each captured graph covers the
+                # correct execution path.
+                upd_seq_threshold_3D = min(
+                    capture_sizes,
+                    key=lambda x: abs(x - self.seq_threshold_3D),
+                )
+
+                # If the updated threshold becomes significantly larger than the
+                # initial value, it is reset to zero. This enforces the use of the
+                # 2D kernel only and ensures that the size of the allocated
+                # intermediate structures remains bounded.
+                if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D:
+                    self.seq_threshold_3D = upd_seq_threshold_3D
+                else:
+                    self.seq_threshold_3D = 0

         self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
         self.softmax_segm_output = torch.empty(

From a70bf68be6018e96ecb8858c2d299e222e0602f6 Mon Sep 17 00:00:00 2001
From: Jan van Lunteren
Date: Mon, 10 Nov 2025 04:20:02 -0500
Subject: [PATCH 08/10] allocate softmax buffers with padded head dimension

Signed-off-by: Jan van Lunteren
---
 tests/kernels/attention/test_triton_unified_attention.py | 3 ++-
 vllm/attention/ops/triton_unified_attention.py           | 4 ++--
 vllm/v1/attention/backends/triton_attn.py                | 3 ++-
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py
index a9575270bee8..a0f9952713fd 100644
--- a/tests/kernels/attention/test_triton_unified_attention.py
+++ b/tests/kernels/attention/test_triton_unified_attention.py
@@ -159,8 +159,9 @@ def test_triton_unified_attn(
     v_descale = torch.rand(scale_shape, dtype=torch.float32)

     num_par_softmax_segments = 16
+    head_size_padded = 1 << (head_size - 1).bit_length()  # next power of 2 value
     softmax_segm_output = torch.empty(
-        (seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size),
+        (seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
         dtype=torch.float32,
     )
     softmax_segm_max = torch.empty(
diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index 911ae8944f57..8e1e92f118a8 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -352,7 +352,7 @@ def kernel_unified_attention_2d(
 @triton.jit
 def kernel_unified_attention_3d(
     segm_output_ptr,
-    # [num_tokens, num_query_heads, num_segments, head_size]
+    # [num_tokens, num_query_heads, num_segments, head_size_padded]
     segm_max_ptr,  # [num_tokens, num_query_heads, num_segments]
     segm_expsum_ptr,  # [num_tokens, num_query_heads, num_segments]
     query_ptr,  # [num_tokens, num_query_heads, head_size]
@@ -608,7 +608,7 @@ def reduce_segments(
     output_ptr,  # [num_tokens, num_query_heads, head_size]
     segm_output_ptr,
-    # [num_tokens, num_query_heads, max_num_segments, head_size]
+    # [num_tokens, num_query_heads, max_num_segments, head_size_padded]
     segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
     segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
     seq_lens_ptr,  # [num_seqs]
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index f8717b763760..b490f0e86c2f 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -134,12 +134,13 @@ def __init__(
                     self.seq_threshold_3D = 0

         self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
+        headdim_padded = 1 << (self.headdim - 1).bit_length()  # next power of 2 value
         self.softmax_segm_output = torch.empty(
             (
                 self.seq_threshold_3D,
                 self.num_heads_q,
                 self.num_par_softmax_segments,
-                self.headdim,
+                headdim_padded,
             ),
             dtype=torch.float32,
             device=device,
         )

From a62aa118fb6ed96c7668e4c08515cebf5c42fda9 Mon Sep 17 00:00:00 2001
From: Jan van Lunteren
Date: Tue, 11 Nov 2025 09:39:40 -0500
Subject: [PATCH 09/10] fix failing ruff check

Signed-off-by: Jan van Lunteren
---
 tests/compile/test_fusions_e2e.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py
index f67063cdf42e..e1560efb3f24 100644
--- a/tests/compile/test_fusions_e2e.py
+++ b/tests/compile/test_fusions_e2e.py
@@ -74,7 +74,7 @@ class ModelBackendTestCase(NamedTuple):
     ModelBackendTestCase(
         model_name="Qwen/Qwen3-30B-A3B",
         model_kwargs=dict(max_model_len=1024),
-        backend=_Backend.TRITON_ATTN,
+        backend=AttentionBackendEnum.TRITON_ATTN,
         attention_fusions=0,
         allreduce_fusions=97,
     ),

From 90e746ac726e5a72773a83836918cdb2abb90b07 Mon Sep 17 00:00:00 2001
From: Jan van Lunteren
Date: Thu, 13 Nov 2025 12:31:45 -0500
Subject: [PATCH 10/10] remove dependencies on other PRs

Signed-off-by: Jan van Lunteren
---
 .../attention/ops/triton_unified_attention.py | 79 +++++++++++++++----
 1 file changed, 65 insertions(+), 14 deletions(-)

diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index 8e1e92f118a8..8a28b2e4bba4 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -36,13 +36,14 @@ def find_seq_idx(
     target_idx,
     num_seqs,
     BLOCK_Q: tl.constexpr,
+    use_q_block_mode: tl.constexpr,
 ):
     left: tl.int32 = 0
     right = num_seqs
     while left < right:
         mid = (left + right) // 2
         val = tl.load(query_start_len_ptr + mid)
-        mid_val = val // BLOCK_Q + mid
+        mid_val = val // BLOCK_Q + mid if use_q_block_mode else val

         if mid_val <= target_idx:
             left = mid + 1
@@ -104,7 +105,9 @@ def kernel_unified_attention_2d(
     q_block_global_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)

-    seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q)
+    seq_idx = find_seq_idx(
+        query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
+    )

     q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx

@@ -390,13 +393,32 @@ def kernel_unified_attention_3d(
     stride_v_cache_1: tl.int64,  # int
     stride_v_cache_2: tl.int64,  # int
     stride_v_cache_3: tl.constexpr,  # int
+    query_start_len_ptr,  # [num_seqs+1]
+    BLOCK_Q: tl.constexpr,  # int
+    num_seqs: tl.int32,
     BLOCK_M: tl.constexpr,  # int
     NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
 ):
-    seq_idx = tl.program_id(0)
+    q_block_global_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)
     segm_idx = tl.program_id(2)

+    seq_idx = find_seq_idx(
+        query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
+    )
+
+    q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx
+
+    q_block_local_idx = q_block_global_idx - q_block_start_idx
+
+    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
+    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
+
+    cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
+
+    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
+        return
+
     # sequence len for this particular sequence
     seq_len = tl.load(seq_lens_ptr + seq_idx)

@@ -410,9 +432,9 @@ def kernel_unified_attention_3d(
     offs_m = tl.arange(0, BLOCK_M)
     offs_d = tl.arange(0, HEAD_SIZE_PADDED)
     offs_t = tl.arange(0, TILE_SIZE)
-    query_pos = offs_m // num_queries_per_kv
+    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

-    query_offset_0 = seq_idx + query_pos
+    query_offset_0 = cur_batch_in_all_start_index + query_pos
     query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
     query_offset = (
         query_offset_0[:, None] * query_stride_0
@@ -421,7 +443,7 @@
     )

     dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
-    query_mask_0 = tl.where(query_pos < 1, 1, 0).to(tl.int1)
+    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
     query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

     # Q : (BLOCK_M, HEAD_SIZE_PADDED)
@@ -449,7 +471,7 @@
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

     # context length for this particular sequences
-    context_len = seq_len - 1
+    context_len = seq_len - cur_batch_query_len

     # alibi slope for this head
     if USE_ALIBI_SLOPES:
@@ -463,7 +485,23 @@ def kernel_unified_attention_3d(
         qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
     )  # shape: [BLOCK_M]

-    num_tiles = cdiv_fn(seq_len, TILE_SIZE)
+    # compute the length of the longest sequence prefix spanned by any
+    # query token in the current q_block (q_block_local_idx)
+    max_seq_prefix_len = (
+        context_len
+        + q_block_local_idx * BLOCK_Q
+        + (BLOCK_M - 1) // num_queries_per_kv
+        + 1
+    )
+
+    # adjust for potential padding in the last q_block by considering the
+    # actual sequence length
+    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
+
+    # calculate the number of tiles that need to be processed to
+    # cover the longest sequence prefix (due to causal masking, tiles beyond
+    # this prefix can be skipped)
+    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)

     # iterate through tiles within current segment
     for j in range(
@@ -471,7 +509,7 @@ def kernel_unified_attention_3d(
         min((segm_idx + 1) * tiles_per_segment, num_tiles),
     ):
         seq_offset = j * TILE_SIZE + offs_t
-        tile_mask = seq_offset < seq_len
+        tile_mask = seq_offset < max_seq_prefix_len

         physical_block_idx = tl.load(
             block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
@@ -608,10 +646,11 @@ def reduce_segments(
     output_ptr,  # [num_tokens, num_query_heads, head_size]
     segm_output_ptr,
-    # [num_tokens, num_query_heads, max_num_segments, head_size_padded]
+    # [num_tokens, num_query_heads, max_num_segments, head_size]
     segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
     segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
     seq_lens_ptr,  # [num_seqs]
+    num_seqs,  # int
     num_query_heads: tl.constexpr,  # int
     out_scale_inv,  # float32
     output_stride_0: tl.int64,  # int
@@ -620,14 +659,20 @@ def reduce_segments(
     TILE_SIZE: tl.constexpr,  # int
     HEAD_SIZE: tl.constexpr,  # int, must be power of 2
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
+    query_start_len_ptr,  # [num_seqs+1]
+    BLOCK_Q: tl.constexpr,  # int
     NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
     USE_FP8: tl.constexpr,  # bool
     FP8_MIN: tl.constexpr = float8_info.min,
     FP8_MAX: tl.constexpr = float8_info.max,
 ):
-    seq_idx = tl.program_id(0)
+    query_token_idx = tl.program_id(0)
     query_head_idx = tl.program_id(1)

+    seq_idx = find_seq_idx(
+        query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False
+    )
+
     # sequence len for this particular sequence
     seq_len = tl.load(seq_lens_ptr + seq_idx)

@@ -644,7 +689,7 @@ def reduce_segments(
     # load segment maxima
     segm_offset = (
-        seq_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
+        query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
         + query_head_idx * NUM_SEGMENTS_PER_SEQ
         + tl.arange(0, NUM_SEGMENTS_PER_SEQ)
     )
@@ -658,7 +703,7 @@ def reduce_segments(
     # load, rescale, and add segment attention outputs
     segm_output_offset = (
-        seq_idx.to(tl.int64)
+        query_token_idx.to(tl.int64)
         * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
         + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
         + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED
@@ -680,7 +725,7 @@ def reduce_segments(
     # write result
     output_offset = (
-        seq_idx * output_stride_0
+        query_token_idx * output_stride_0
         + query_head_idx * output_stride_1
         + tl.arange(0, HEAD_SIZE_PADDED)
     )
@@ -846,6 +891,9 @@ def unified_attention(
             stride_v_cache_1=v.stride(1),
             stride_v_cache_2=v.stride(2),
             stride_v_cache_3=v.stride(3),
+            query_start_len_ptr=cu_seqlens_q,
+            BLOCK_Q=BLOCK_Q,
+            num_seqs=num_seqs,
             BLOCK_M=BLOCK_M,
             NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
         )
@@ -855,6 +903,7 @@ def unified_attention(
             segm_max_ptr=softmax_segm_max,
             segm_expsum_ptr=softmax_segm_expsum,
             seq_lens_ptr=seqused_k,
+            num_seqs=num_seqs,
             num_query_heads=num_query_heads,
             out_scale_inv=1 / output_scale if output_scale is not None else 1.0,
             output_stride_0=out.stride(0),
@@ -863,6 +912,8 @@ def unified_attention(
             TILE_SIZE=TILE_SIZE_DECODE,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
+            query_start_len_ptr=cu_seqlens_q,
+            BLOCK_Q=BLOCK_Q,
             NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
             USE_FP8=output_scale is not None,
         )
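
Taken together, the series changes the unified_attention Python API: the caller now allocates the softmax segment scratch buffers once, sized by seq_threshold_3D, and passes them in, and the 3D split-softmax path is used only for pure-decode batches with at most seq_threshold_3D sequences. The following is a minimal illustrative sketch (not part of the patch) of that calling convention; the sizes are hypothetical and the buffer shapes mirror the updated unit test.

    import torch
    import triton

    # Hypothetical sizes, for illustration only.
    num_query_heads, head_size = 32, 128
    seq_threshold_3D = 8             # decode batches up to this size take the 3D path
    num_par_softmax_segments = 16    # mirrors NUM_PAR_SOFTMAX_SEGMENTS
    head_size_padded = triton.next_power_of_2(head_size)  # padded head dim (patch 08)

    # Caller-owned scratch buffers, reused across steps (CUDA-Graph friendly).
    softmax_segm_output = torch.empty(
        (seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
        dtype=torch.float32, device="cuda",
    )
    softmax_segm_max = torch.empty(
        (seq_threshold_3D, num_query_heads, num_par_softmax_segments),
        dtype=torch.float32, device="cuda",
    )
    softmax_segm_expsum = torch.empty_like(softmax_segm_max)

    # unified_attention(..., existing q/k/v/out/cu_seqlens_q/seqused_k arguments ...,
    #                   seq_threshold_3D=seq_threshold_3D,
    #                   num_par_softmax_segments=num_par_softmax_segments,
    #                   softmax_segm_output=softmax_segm_output,
    #                   softmax_segm_max=softmax_segm_max,
    #                   softmax_segm_expsum=softmax_segm_expsum)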