2 changes: 1 addition & 1 deletion tests/compile/test_fusions_e2e.py
@@ -74,7 +74,7 @@ class ModelBackendTestCase(NamedTuple):
ModelBackendTestCase(
model_name="Qwen/Qwen3-30B-A3B",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=0,
allreduce_fusions=97,
),
26 changes: 26 additions & 0 deletions tests/kernels/attention/test_triton_unified_attention.py
@@ -22,6 +22,10 @@
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]

# 0: decode always falls back to the 2D kernel
# 8: decode batches with at most 8 sequences use the 3D (split-softmax) kernel
SEQ_THRESHOLD_3D_VALUES = [0, 8]


def ref_paged_attn(
query: torch.Tensor,
@@ -92,6 +96,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("soft_cap", [None, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("q_dtype", QDTYPES)
@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES)
@torch.inference_mode()
def test_triton_unified_attn(
seq_lens: list[tuple[int, int]],
@@ -103,6 +108,7 @@ def test_triton_unified_attn(
soft_cap: float | None,
num_blocks: int,
q_dtype: torch.dtype | None,
seq_threshold_3D: int,
) -> None:
torch.set_default_device("cuda")

@@ -152,6 +158,21 @@ def test_triton_unified_attn(
k_descale = torch.rand(scale_shape, dtype=torch.float32)
v_descale = torch.rand(scale_shape, dtype=torch.float32)

num_par_softmax_segments = 16
head_size_padded = 1 << (head_size - 1).bit_length() # next power of 2 value
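# e.g. head_size = 80 -> head_size_padded = 128; a head_size that is already a power of
# two maps to itself (matches triton.next_power_of_2, used for HEAD_SIZE_PADDED in unified_attention)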
softmax_segm_output = torch.empty(
(seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
dtype=torch.float32,
)
softmax_segm_max = torch.empty(
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
dtype=torch.float32,
)
softmax_segm_expsum = torch.empty(
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
dtype=torch.float32,
)

unified_attention(
q=maybe_quantized_query,
k=maybe_quantized_key_cache,
@@ -169,6 +190,11 @@
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
seq_threshold_3D=seq_threshold_3D,
num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output,
softmax_segm_max=softmax_segm_max,
softmax_segm_expsum=softmax_segm_expsum,
)

ref_output = ref_paged_attn(
137 changes: 32 additions & 105 deletions vllm/attention/ops/triton_unified_attention.py
@@ -36,14 +36,13 @@ def find_seq_idx(
target_idx,
num_seqs,
BLOCK_Q: tl.constexpr,
use_q_block_mode: tl.constexpr,
):
left: tl.int32 = 0
right = num_seqs
while left < right:
mid = (left + right) // 2
val = tl.load(query_start_len_ptr + mid)
mid_val = val // BLOCK_Q + mid if use_q_block_mode else val
mid_val = val // BLOCK_Q + mid
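# first q-block index of sequence `mid`: sequence i's q-blocks start at
# cu_seqlens[i] // BLOCK_Q + i (cf. q_block_start_idx in kernel_unified_attention_2d)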

if mid_val <= target_idx:
left = mid + 1
@@ -105,9 +104,7 @@ def kernel_unified_attention_2d(
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)

seq_idx = find_seq_idx(
query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
)
seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q)

q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx

@@ -355,7 +352,7 @@ def kernel_unified_attention_2d(
@triton.jit
def kernel_unified_attention_3d(
segm_output_ptr,
# [num_tokens, num_query_heads, num_segments, head_size]
# [num_tokens, num_query_heads, num_segments, head_size_padded]
segm_max_ptr, # [num_tokens, num_query_heads, num_segments]
segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments]
query_ptr, # [num_tokens, num_query_heads, head_size]
@@ -393,32 +390,13 @@ def kernel_unified_attention_3d(
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.constexpr, # int
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
):
q_block_global_idx = tl.program_id(0)
seq_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
segm_idx = tl.program_id(2)
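# launch grid is (num_seqs, num_kv_heads, num_par_softmax_segments): the 3D kernel
# handles exactly one decode query token per sequence (program id 0 is the sequence
# index), with the sequence's KV tiles split across NUM_SEGMENTS_PER_SEQ segments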

seq_idx = find_seq_idx(
query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
)

q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx

q_block_local_idx = q_block_global_idx - q_block_start_idx

cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index

if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return

# sequence len for this particular sequence
seq_len = tl.load(seq_lens_ptr + seq_idx)

@@ -432,9 +410,9 @@ def kernel_unified_attention_3d(
offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
offs_t = tl.arange(0, TILE_SIZE)
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
query_pos = offs_m // num_queries_per_kv

query_offset_0 = cur_batch_in_all_start_index + query_pos
query_offset_0 = seq_idx + query_pos
query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
query_offset = (
query_offset_0[:, None] * query_stride_0
@@ -443,7 +421,7 @@
)

dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
query_mask_0 = tl.where(query_pos < 1, 1, 0).to(tl.int1)
query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

# Q : (BLOCK_M, HEAD_SIZE_PADDED)
@@ -471,7 +449,7 @@
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

# context length for this particular sequence
context_len = seq_len - cur_batch_query_len
context_len = seq_len - 1
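# (pure decode: each sequence contributes exactly one query token, so everything
# before the last position is context)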

# alibi slope for this head
if USE_ALIBI_SLOPES:
@@ -485,31 +463,15 @@
qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
) # shape: [BLOCK_M]

# compute the length of the longest sequence prefix spanned by any
# query token in the current q_block (q_block_local_idx)
max_seq_prefix_len = (
context_len
+ q_block_local_idx * BLOCK_Q
+ (BLOCK_M - 1) // num_queries_per_kv
+ 1
)

# adjust for potential padding in the last q_block by considering the
# actual sequence length
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)

# calculate the number of tiles that need to be processed to
# cover the longest sequence prefix (due to causal masking, tiles beyond
# this prefix can be skipped)
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
num_tiles = cdiv_fn(seq_len, TILE_SIZE)

# iterate through tiles within current segment
for j in range(
segm_idx * tiles_per_segment,
min((segm_idx + 1) * tiles_per_segment, num_tiles),
):
seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len
tile_mask = seq_offset < seq_len

physical_block_idx = tl.load(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
@@ -646,11 +608,10 @@ def kernel_unified_attention_3d(
def reduce_segments(
output_ptr, # [num_tokens, num_query_heads, head_size]
segm_output_ptr,
# [num_tokens, num_query_heads, max_num_segments, head_size]
# [num_tokens, num_query_heads, max_num_segments, head_size_padded]
segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments]
segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments]
seq_lens_ptr, # [num_seqs]
num_seqs, # int
num_query_heads: tl.constexpr, # int
out_scale_inv, # float32
output_stride_0: tl.int64, # int
@@ -659,20 +620,14 @@ def reduce_segments(
TILE_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int, must be power of 2
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
USE_FP8: tl.constexpr, # bool
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
):
query_token_idx = tl.program_id(0)
seq_idx = tl.program_id(0)
query_head_idx = tl.program_id(1)
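# launch grid is (num_seqs, num_query_heads): in the decode-only 3D path each sequence
# has a single query token, so seq_idx is also the output token index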

seq_idx = find_seq_idx(
query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False
)

# sequence len for this particular sequence
seq_len = tl.load(seq_lens_ptr + seq_idx)

@@ -689,7 +644,7 @@

# load segment maxima
segm_offset = (
query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
seq_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
+ query_head_idx * NUM_SEGMENTS_PER_SEQ
+ tl.arange(0, NUM_SEGMENTS_PER_SEQ)
)
@@ -703,7 +658,7 @@

# load, rescale, and add segment attention outputs
segm_output_offset = (
query_token_idx.to(tl.int64)
seq_idx.to(tl.int64)
* (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
+ query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
+ tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED
@@ -725,7 +680,7 @@

# write result
output_offset = (
query_token_idx * output_stride_0
seq_idx * output_stride_0
+ query_head_idx * output_stride_1
+ tl.arange(0, HEAD_SIZE_PADDED)
)
@@ -749,6 +704,11 @@ def unified_attention(
q_descale,
k_descale,
v_descale,
seq_threshold_3D,
num_par_softmax_segments,
softmax_segm_output,
softmax_segm_max,
softmax_segm_expsum,
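# expected workspace shapes (cf. the unit test above; the leading dim must cover the
# decode tokens, which the test sizes to seq_threshold_3D):
#   softmax_segm_output: [*, num_query_heads, num_par_softmax_segments, head_size_padded]
#   softmax_segm_max / softmax_segm_expsum: [*, num_query_heads, num_par_softmax_segments]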
alibi_slopes=None,
output_scale=None,
qq_bias=None,
@@ -793,8 +753,8 @@
TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32

# if batch contains a prefill
if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
# if the batch contains a prefill or the number of sequences exceeds seq_threshold_3D
if max_seqlen_q > 1 or num_seqs > seq_threshold_3D:
kernel_unified_attention_2d[
(
total_num_q_blocks,
@@ -847,37 +807,10 @@
USE_FP8=output_scale is not None,
)
else:
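# 3D decode path: relies on the caller-provided softmax_segm_* workspaces instead of
# allocating them here; it is only taken when num_seqs <= seq_threshold_3D (and
# max_seqlen_q == 1), so workspaces sized for seq_threshold_3D tokens suffice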
# for initial version, NUM_SEGMENTS = 16 is chosen as a default
# value that showed good performance in tests
NUM_SEGMENTS = 16

segm_output = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
triton.next_power_of_2(head_size),
dtype=torch.float32,
device=q.device,
)
segm_max = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)
segm_expsum = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)

kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
segm_output_ptr=segm_output,
segm_max_ptr=segm_max,
segm_expsum_ptr=segm_expsum,
kernel_unified_attention_3d[(num_seqs, num_kv_heads, num_par_softmax_segments)](
segm_output_ptr=softmax_segm_output,
segm_max_ptr=softmax_segm_max,
segm_expsum_ptr=softmax_segm_expsum,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
@@ -913,19 +846,15 @@
stride_v_cache_1=v.stride(1),
stride_v_cache_2=v.stride(2),
stride_v_cache_3=v.stride(3),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
)
reduce_segments[(q.shape[0], num_query_heads)](
reduce_segments[(num_seqs, num_query_heads)](
output_ptr=out,
segm_output_ptr=segm_output,
segm_max_ptr=segm_max,
segm_expsum_ptr=segm_expsum,
segm_output_ptr=softmax_segm_output,
segm_max_ptr=softmax_segm_max,
segm_expsum_ptr=softmax_segm_expsum,
seq_lens_ptr=seqused_k,
num_seqs=num_seqs,
num_query_heads=num_query_heads,
out_scale_inv=1 / output_scale if output_scale is not None else 1.0,
output_stride_0=out.stride(0),
@@ -934,8 +863,6 @@
TILE_SIZE=TILE_SIZE_DECODE,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
USE_FP8=output_scale is not None,
)
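For reference, the per-(token, head) reduction that reduce_segments performs over the
num_par_softmax_segments partial results follows the usual max-rescaled softmax combine.
A minimal NumPy sketch (combine_segments and its argument names are illustrative, not
kernel identifiers):

import numpy as np

def combine_segments(segm_out, segm_max, segm_expsum):
    # segm_out: [num_segments, head_size]; segm_max, segm_expsum: [num_segments]
    m = segm_max.max()                        # global max over segments
    alpha = np.exp(segm_max - m)              # per-segment rescale factor
    expsum = (alpha * segm_expsum).sum()      # combined softmax denominator
    out = (alpha[:, None] * segm_out).sum(0)  # rescale and add segment outputs
    return out / expsum                       # normalized attention output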