diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index f67063cdf42e..e1560efb3f24 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -74,7 +74,7 @@ class ModelBackendTestCase(NamedTuple): ModelBackendTestCase( model_name="Qwen/Qwen3-30B-A3B", model_kwargs=dict(max_model_len=1024), - backend=_Backend.TRITON_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, attention_fusions=0, allreduce_fusions=97, ), diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index bf4d2179af5f..a0f9952713fd 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -22,6 +22,10 @@ # one value small enough to test the schema op check NUM_BLOCKS = [32768, 2048] +# 0: use 2D kernel for decode +# 8: use 3D kernel for decode +SEQ_THRESHOLD_3D_VALUES = [0, 8] + def ref_paged_attn( query: torch.Tensor, @@ -92,6 +96,7 @@ def ref_paged_attn( @pytest.mark.parametrize("soft_cap", [None, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("q_dtype", QDTYPES) +@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES) @torch.inference_mode() def test_triton_unified_attn( seq_lens: list[tuple[int, int]], @@ -103,6 +108,7 @@ def test_triton_unified_attn( soft_cap: float | None, num_blocks: int, q_dtype: torch.dtype | None, + seq_threshold_3D: int, ) -> None: torch.set_default_device("cuda") @@ -152,6 +158,21 @@ def test_triton_unified_attn( k_descale = torch.rand(scale_shape, dtype=torch.float32) v_descale = torch.rand(scale_shape, dtype=torch.float32) + num_par_softmax_segments = 16 + head_size_padded = 1 << (head_size - 1).bit_length() # next power of 2 value + softmax_segm_output = torch.empty( + (seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded), + dtype=torch.float32, + ) + softmax_segm_max = torch.empty( + (seq_threshold_3D, num_query_heads, num_par_softmax_segments), + dtype=torch.float32, + ) + softmax_segm_expsum = torch.empty( + (seq_threshold_3D, num_query_heads, num_par_softmax_segments), + dtype=torch.float32, + ) + unified_attention( q=maybe_quantized_query, k=maybe_quantized_key_cache, @@ -169,6 +190,11 @@ def test_triton_unified_attn( q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + seq_threshold_3D=seq_threshold_3D, + num_par_softmax_segments=num_par_softmax_segments, + softmax_segm_output=softmax_segm_output, + softmax_segm_max=softmax_segm_max, + softmax_segm_expsum=softmax_segm_expsum, ) ref_output = ref_paged_attn( diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 565be1c39bec..8e1e92f118a8 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -36,14 +36,13 @@ def find_seq_idx( target_idx, num_seqs, BLOCK_Q: tl.constexpr, - use_q_block_mode: tl.constexpr, ): left: tl.int32 = 0 right = num_seqs while left < right: mid = (left + right) // 2 val = tl.load(query_start_len_ptr + mid) - mid_val = val // BLOCK_Q + mid if use_q_block_mode else val + mid_val = val // BLOCK_Q + mid if mid_val <= target_idx: left = mid + 1 @@ -105,9 +104,7 @@ def kernel_unified_attention_2d( q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) - seq_idx = find_seq_idx( - query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True - ) + seq_idx = find_seq_idx(query_start_len_ptr, 
q_block_global_idx, num_seqs, BLOCK_Q) q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx @@ -355,7 +352,7 @@ def kernel_unified_attention_2d( @triton.jit def kernel_unified_attention_3d( segm_output_ptr, - # [num_tokens, num_query_heads, num_segments, head_size] + # [num_tokens, num_query_heads, num_segments, head_size_padded] segm_max_ptr, # [num_tokens, num_query_heads, num_segments] segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] query_ptr, # [num_tokens, num_query_heads, head_size] @@ -393,32 +390,13 @@ def kernel_unified_attention_3d( stride_v_cache_1: tl.int64, # int stride_v_cache_2: tl.int64, # int stride_v_cache_3: tl.constexpr, # int - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - num_seqs: tl.int32, BLOCK_M: tl.constexpr, # int NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int ): - q_block_global_idx = tl.program_id(0) + seq_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) segm_idx = tl.program_id(2) - seq_idx = find_seq_idx( - query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True - ) - - q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx - - q_block_local_idx = q_block_global_idx - q_block_start_idx - - cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) - cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) - - cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index - - if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: - return - # sequence len for this particular sequence seq_len = tl.load(seq_lens_ptr + seq_idx) @@ -432,9 +410,9 @@ def kernel_unified_attention_3d( offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) offs_t = tl.arange(0, TILE_SIZE) - query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + query_pos = offs_m // num_queries_per_kv - query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_0 = seq_idx + query_pos query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv query_offset = ( query_offset_0[:, None] * query_stride_0 @@ -443,7 +421,7 @@ def kernel_unified_attention_3d( ) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) - query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < 1, 1, 0).to(tl.int1) query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) # Q : (BLOCK_M, HEAD_SIZE_PADDED) @@ -471,7 +449,7 @@ def kernel_unified_attention_3d( acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) # context length for this particular sequences - context_len = seq_len - cur_batch_query_len + context_len = seq_len - 1 # alibi slope for this head if USE_ALIBI_SLOPES: @@ -485,23 +463,7 @@ def kernel_unified_attention_3d( qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 ) # shape: [BLOCK_M] - # compute the length of the longest sequence prefix spanned by any - # query token in the current q_block (q_block_local_idx) - max_seq_prefix_len = ( - context_len - + q_block_local_idx * BLOCK_Q - + (BLOCK_M - 1) // num_queries_per_kv - + 1 - ) - - # adjust for potential padding in the last q_block by considering the - # actual sequence length - max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) - - # calculate the number of tiles that need to be processed to - # cover the longest sequence prefix (due to causal masking, tiles beyond - # this prefix can be skipped) - num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) + num_tiles 
= cdiv_fn(seq_len, TILE_SIZE) # iterate through tiles within current segment for j in range( @@ -509,7 +471,7 @@ def kernel_unified_attention_3d( min((segm_idx + 1) * tiles_per_segment, num_tiles), ): seq_offset = j * TILE_SIZE + offs_t - tile_mask = seq_offset < max_seq_prefix_len + tile_mask = seq_offset < seq_len physical_block_idx = tl.load( block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE @@ -646,11 +608,10 @@ def kernel_unified_attention_3d( def reduce_segments( output_ptr, # [num_tokens, num_query_heads, head_size] segm_output_ptr, - # [num_tokens, num_query_heads, max_num_segments, head_size] + # [num_tokens, num_query_heads, max_num_segments, head_size_padded] segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] seq_lens_ptr, # [num_seqs] - num_seqs, # int num_query_heads: tl.constexpr, # int out_scale_inv, # float32 output_stride_0: tl.int64, # int @@ -659,20 +620,14 @@ def reduce_segments( TILE_SIZE: tl.constexpr, # int HEAD_SIZE: tl.constexpr, # int, must be power of 2 HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int USE_FP8: tl.constexpr, # bool FP8_MIN: tl.constexpr = float8_info.min, FP8_MAX: tl.constexpr = float8_info.max, ): - query_token_idx = tl.program_id(0) + seq_idx = tl.program_id(0) query_head_idx = tl.program_id(1) - seq_idx = find_seq_idx( - query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False - ) - # sequence len for this particular sequence seq_len = tl.load(seq_lens_ptr + seq_idx) @@ -689,7 +644,7 @@ def reduce_segments( # load segment maxima segm_offset = ( - query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + seq_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + query_head_idx * NUM_SEGMENTS_PER_SEQ + tl.arange(0, NUM_SEGMENTS_PER_SEQ) ) @@ -703,7 +658,7 @@ def reduce_segments( # load, rescale, and add segment attention outputs segm_output_offset = ( - query_token_idx.to(tl.int64) + seq_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED @@ -725,7 +680,7 @@ def reduce_segments( # write result output_offset = ( - query_token_idx * output_stride_0 + seq_idx * output_stride_0 + query_head_idx * output_stride_1 + tl.arange(0, HEAD_SIZE_PADDED) ) @@ -749,6 +704,11 @@ def unified_attention( q_descale, k_descale, v_descale, + seq_threshold_3D, + num_par_softmax_segments, + softmax_segm_output, + softmax_segm_max, + softmax_segm_expsum, alibi_slopes=None, output_scale=None, qq_bias=None, @@ -793,8 +753,8 @@ def unified_attention( TILE_SIZE_PREFILL = 32 TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32 - # if batch contains a prefill - if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: + # if batch contains a prefill or number of sequences is larger than threshold + if max_seqlen_q > 1 or num_seqs > seq_threshold_3D: kernel_unified_attention_2d[ ( total_num_q_blocks, @@ -847,37 +807,10 @@ def unified_attention( USE_FP8=output_scale is not None, ) else: - # for initial version, NUM_SEGMENTS = 16 is chosen as a default - # value that showed good performance in tests - NUM_SEGMENTS = 16 - - segm_output = torch.empty( - q.shape[0], - num_query_heads, - NUM_SEGMENTS, - triton.next_power_of_2(head_size), - dtype=torch.float32, - device=q.device, - ) - 
segm_max = torch.empty( - q.shape[0], - num_query_heads, - NUM_SEGMENTS, - dtype=torch.float32, - device=q.device, - ) - segm_expsum = torch.empty( - q.shape[0], - num_query_heads, - NUM_SEGMENTS, - dtype=torch.float32, - device=q.device, - ) - - kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( - segm_output_ptr=segm_output, - segm_max_ptr=segm_max, - segm_expsum_ptr=segm_expsum, + kernel_unified_attention_3d[(num_seqs, num_kv_heads, num_par_softmax_segments)]( + segm_output_ptr=softmax_segm_output, + segm_max_ptr=softmax_segm_max, + segm_expsum_ptr=softmax_segm_expsum, query_ptr=q, key_cache_ptr=k, value_cache_ptr=v, @@ -913,19 +846,15 @@ def unified_attention( stride_v_cache_1=v.stride(1), stride_v_cache_2=v.stride(2), stride_v_cache_3=v.stride(3), - query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, - num_seqs=num_seqs, BLOCK_M=BLOCK_M, - NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments, ) - reduce_segments[(q.shape[0], num_query_heads)]( + reduce_segments[(num_seqs, num_query_heads)]( output_ptr=out, - segm_output_ptr=segm_output, - segm_max_ptr=segm_max, - segm_expsum_ptr=segm_expsum, + segm_output_ptr=softmax_segm_output, + segm_max_ptr=softmax_segm_max, + segm_expsum_ptr=softmax_segm_expsum, seq_lens_ptr=seqused_k, - num_seqs=num_seqs, num_query_heads=num_query_heads, out_scale_inv=1 / output_scale if output_scale is not None else 1.0, output_stride_0=out.stride(0), @@ -934,8 +863,6 @@ def unified_attention( TILE_SIZE=TILE_SIZE_DECODE, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), - query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, - NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments, USE_FP8=output_scale is not None, ) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 37c0ae61e65d..77a0d455479a 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -17,7 +17,7 @@ triton_reshape_and_cache_flash, ) from vllm.attention.ops.triton_unified_attention import unified_attention -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -36,6 +36,11 @@ logger = init_logger(__name__) +# constants +MIN_LAUNCH_GRID_SIZE_2D = 128 # Minimum launch grid size of 2D kernel +NUM_PAR_SOFTMAX_SEGMENTS = 16 # Number of parallel tiled softmax segments + + @dataclass class TritonAttentionMetadata: # NOTE(sang): Definition of context_len, query_len, and seq_len. @@ -54,6 +59,12 @@ class TritonAttentionMetadata: block_table: torch.Tensor slot_mapping: torch.Tensor + seq_threshold_3D: int + num_par_softmax_segments: int + softmax_segm_output: torch.Tensor + softmax_segm_max: torch.Tensor + softmax_segm_expsum: torch.Tensor + # For cascade attention. 
use_cascade: bool common_prefix_len: int @@ -87,6 +98,66 @@ def __init__( self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config) self.headdim = model_config.get_head_size() + # Check if CUDA Graphs are enabled for decode + self.decode_cudagraph_enabled = ( + self.vllm_config.compilation_config.cudagraph_mode + in ( + CUDAGraphMode.FULL_AND_PIECEWISE, + CUDAGraphMode.FULL_DECODE_ONLY, + CUDAGraphMode.FULL, + ) + ) + + # Set initial value for the threshold for the number of sequences used + # to select between the 2D and 3D kernels for decode. + self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv + if self.decode_cudagraph_enabled: + capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + if not capture_sizes: + # If no CUDA Graph capture sizes are specified, the threshold + # is reset to zero, forcing the 2D kernel to be used. + self.seq_threshold_3D = 0 + else: + # Select the CUDA Graph capture size closest to self.seq_threshold_3D + # as threshold. This ensures that each captured graph covers the + # correct execution path. + upd_seq_threshold_3D = min( + capture_sizes, + key=lambda x: abs(x - self.seq_threshold_3D), + ) + + # If the updated threshold becomes significantly larger than the + # initial value, it is reset to zero. This enforces the use of the + # 2D kernel only and ensures that the size of the allocated + # intermediate structures remains bounded. + if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D: + self.seq_threshold_3D = upd_seq_threshold_3D + else: + self.seq_threshold_3D = 0 + + self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS + headdim_padded = 1 << (self.headdim - 1).bit_length() # next power of 2 value + self.softmax_segm_output = torch.empty( + ( + self.seq_threshold_3D, + self.num_heads_q, + self.num_par_softmax_segments, + headdim_padded, + ), + dtype=torch.float32, + device=device, + ) + self.softmax_segm_max = torch.empty( + (self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments), + dtype=torch.float32, + device=device, + ) + self.softmax_segm_expsum = torch.empty( + (self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments), + dtype=torch.float32, + device=device, + ) + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> TritonAttentionMetadata: @@ -143,6 +214,11 @@ def build( prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, + seq_threshold_3D=self.seq_threshold_3D, + num_par_softmax_segments=self.num_par_softmax_segments, + softmax_segm_output=self.softmax_segm_output, + softmax_segm_max=self.softmax_segm_max, + softmax_segm_expsum=self.softmax_segm_expsum, ) return attn_metadata @@ -346,6 +422,12 @@ def forward( max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table + seq_threshold_3D = attn_metadata.seq_threshold_3D + num_par_softmax_segments = attn_metadata.num_par_softmax_segments + softmax_segm_output = attn_metadata.softmax_segm_output + softmax_segm_max = attn_metadata.softmax_segm_max + softmax_segm_expsum = attn_metadata.softmax_segm_expsum + descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) unified_attention( @@ -366,6 +448,11 @@ def forward( q_descale=None, # Not supported k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), + seq_threshold_3D=seq_threshold_3D, + num_par_softmax_segments=num_par_softmax_segments, + softmax_segm_output=softmax_segm_output, + 
softmax_segm_max=softmax_segm_max, + softmax_segm_expsum=softmax_segm_expsum, sinks=self.sinks, output_scale=output_scale, )
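
Notes: the host-side behavior introduced by this patch can be summarized as two steps: (1) pick a batch-size threshold `seq_threshold_3D` below which pure-decode batches use the 3D split-softmax kernel, snapping it to a CUDA graph capture size when full-graph capture is enabled, and (2) pre-allocate the per-segment output/max/expsum buffers once at builder init so the 3D kernel and `reduce_segments` never allocate on the hot path. The sketch below restates that logic in isolation; the helper names `pick_seq_threshold_3d` and `alloc_softmax_segment_buffers` are illustrative placeholders, not functions added by this patch, while the constants and the snapping/bounding rules mirror the builder code above.

import torch

MIN_LAUNCH_GRID_SIZE_2D = 128   # minimum 2D launch grid size, as in the diff
NUM_PAR_SOFTMAX_SEGMENTS = 16   # parallel softmax segments for the 3D kernel


def pick_seq_threshold_3d(num_heads_kv: int, capture_sizes: list[int] | None) -> int:
    """Pick the num_seqs threshold below which the 3D decode kernel is used.

    `capture_sizes` is None when full CUDA graph capture is disabled for decode;
    an empty list means capture is enabled but no sizes were specified.
    """
    threshold = MIN_LAUNCH_GRID_SIZE_2D // num_heads_kv
    if capture_sizes is None:
        # CUDA graphs not used for decode: keep the heuristic threshold.
        return threshold
    if not capture_sizes:
        # No capture sizes specified: force the 2D kernel.
        return 0
    # Snap to the capture size closest to the heuristic so that every
    # captured graph stays on a single kernel path.
    snapped = min(capture_sizes, key=lambda s: abs(s - threshold))
    # Keep the intermediate buffers bounded; fall back to the 2D kernel
    # if the snapped value is much larger than the heuristic.
    return snapped if snapped <= 4 * threshold else 0


def alloc_softmax_segment_buffers(
    threshold: int, num_query_heads: int, head_size: int, device: str = "cuda"
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pre-allocate segment output/max/expsum buffers for the 3D decode path."""
    head_size_padded = 1 << (head_size - 1).bit_length()  # next power of 2
    segm_output = torch.empty(
        (threshold, num_query_heads, NUM_PAR_SOFTMAX_SEGMENTS, head_size_padded),
        dtype=torch.float32, device=device,
    )
    segm_max = torch.empty(
        (threshold, num_query_heads, NUM_PAR_SOFTMAX_SEGMENTS),
        dtype=torch.float32, device=device,
    )
    segm_expsum = torch.empty(
        (threshold, num_query_heads, NUM_PAR_SOFTMAX_SEGMENTS),
        dtype=torch.float32, device=device,
    )
    return segm_output, segm_max, segm_expsum

With these in place, the dispatch in unified_attention follows the diff: the 2D kernel runs whenever the batch contains a prefill (max_seqlen_q > 1) or num_seqs exceeds seq_threshold_3D; otherwise the 3D kernel is launched on a (num_seqs, num_kv_heads, num_par_softmax_segments) grid and reduce_segments on (num_seqs, num_query_heads), both reading from and writing to the pre-allocated buffers.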