
Commit 50d9af6

Aya-ZIbra authored and facebook-github-bot committed
Add trtllm to triton bench (meta-pytorch#379)
Summary:

Run the C++ example:

```
FLASHINFER_CUBIN_DIR=/data/users/$USER/fbsource/fbcode/deeplearning/flashinfer/fb/cubins/ buck2 run mode/opt mode/inplace -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 //deeplearning/flashinfer/trtllm_kernel_interfaces:run_example
```

Run the Triton bench:

```
buck2 run mode/opt mode/inplace -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 //pytorch/tritonbench:run -- --op decoding_attention --only trtllm_decode_fmha --seq-len-q 1 --metrics gbps
```

Todo: Support the non-paged case.

Reviewed By: YJYJLee

Differential Revision: D81021980
1 parent ef75d59 commit 50d9af6

File tree: 2 files changed (+184, -1 lines)


tritonbench/operators/decoding_attention/operator.py

Lines changed: 43 additions & 1 deletion
```diff
@@ -57,11 +57,42 @@
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu/experimental:gen_ai_attention_ops"
     )
-
     HAS_FB_IMPORT = True
 except ImportError:
     HAS_FB_IMPORT = False
 
+# Load FlashInfer FMHA Gen library (includes TRTLLM kernels)
+torch.ops.load_library("//deeplearning/flashinfer:fmha_gen")
+
+# Initialize FlashInfer cubin loader
+try:
+    from flashinfer.jit.cubin_loader import setup_cubin_loader
+
+    # Find the loaded library from the dlopen handle
+    # The torch.ops.load_library should have loaded it already
+    lib_name = "libdeeplearning_flashinfer_fmha_gen.so"
+
+    # Try to find it in /proc/self/maps
+    found = False
+    with open('/proc/self/maps', 'r') as f:
+        for line in f:
+            if lib_name in line:
+                # Extract the path from the line
+                parts = line.strip().split()
+                if len(parts) >= 6:
+                    lib_path = ' '.join(parts[5:])
+                    setup_cubin_loader(lib_path)
+                    found = True
+                    break
+
+    if not found:
+        print(f"Warning: Could not find {lib_name} in loaded libraries")
+except Exception as e:
+    print(f"Warning: Could not initialize FlashInfer cubin loader: {e}")
+    import traceback
+    traceback.print_exc()
+
+from .trtllm_utils import trtllm_paged_attention_decode_func
 
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
@@ -663,3 +694,14 @@ def aiter_paged_fp8kv(
             k_scale_asm,
             v_scale_asm,
         )
+
+    @register_benchmark()
+    def trtllm_decode_fmha(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+        args = trtllm_paged_attention_decode_func(q, k_cache, v_cache, cache_seqlens)
+        return lambda: torch.ops.fmha_gen.trtllm_paged_attention_decode(*args)
```
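With this wiring, argument preparation (the paged-format conversion shown in the new helper below) runs once when the benchmark input is set up, and the returned lambda that tritonbench times only launches the kernel. The following is a minimal sketch, not part of the commit, of exercising the same call path directly; it assumes the `fmha_gen` extension has been loaded as in the diff above, a supported CUDA GPU, and purely illustrative shapes and dtype:

```python
import torch

# Import path assumes the new helper module added in this commit.
from tritonbench.operators.decoding_attention.trtllm_utils import (
    trtllm_paged_attention_decode_func,
)

# Illustrative decode-shaped inputs (seq_len_q = 1); sizes and bf16 are assumptions.
batch, num_qo_heads, num_kv_heads, head_dim = 8, 32, 8, 128
max_seq_len_kv = 4096  # must be a multiple of the 16-token page size used by the helper

q = torch.randn(batch, 1, num_qo_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k_cache = torch.randn(
    batch, max_seq_len_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16
)
v_cache = torch.randn_like(k_cache)
cache_seqlens = torch.randint(
    1, max_seq_len_kv, (batch,), device="cuda", dtype=torch.int32
)

# Setup work (paging, block tables, workspace allocation) happens once here ...
args = trtllm_paged_attention_decode_func(q, k_cache, v_cache, cache_seqlens)

# ... while the benchmarked callable only launches the kernel; the result is
# written into the preallocated output tensor passed as the first argument.
fn = lambda: torch.ops.fmha_gen.trtllm_paged_attention_decode(*args)
fn()
out = args[0]
```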
tritonbench/operators/decoding_attention/trtllm_utils.py

Lines changed: 141 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
TRTLLM FMHA utility functions for handling tensor conversion and kernel preparation.
"""

import torch


def trtllm_paged_attention_decode_func(q, k_cache, v_cache, cache_seqlens):
    """
    TRTLLM FMHA paged attention decode function that prepares inputs for the
    FlashInfer fmha_gen library's trtllm_paged_attention_decode kernel.

    This function converts standard KV cache tensors to paged format and prepares
    all necessary parameters for the TRTLLM kernel.

    Args:
        q: Query tensor [batch, seq_len_q, num_qo_heads, head_dim]
        k_cache: Key cache tensor [batch, max_seq_len_kv, num_kv_heads, head_dim]
        v_cache: Value cache tensor [batch, max_seq_len_kv, num_kv_heads, head_dim]
        cache_seqlens: Sequence lengths tensor [batch]

    Returns:
        Tuple of arguments for torch.ops.fmha_gen.trtllm_paged_attention_decode:
        (out, out_scale_factor, query, key_cache, value_cache, workspace_buffer,
        block_tables, seq_lens, max_kv_len, bmm1_scale, bmm2_scale, o_sf_scale,
        o_sf_vec_size, o_sf_start_index, window_left, sm_count, enable_pdl,
        workspace_size, attention_sinks)
    """

    device = q.device
    # Convert input tensors to paged format for TRTLLM FMHA
    batch_size, seq_len_q, num_qo_heads, head_dim = q.shape
    _, max_seq_len_kv, num_kv_heads, _ = k_cache.shape

    # Use page size of 16 for TRTLLM FMHA
    page_size = 16
    max_num_blocks_per_seq = (max_seq_len_kv + page_size - 1) // page_size
    total_pages = batch_size * max_num_blocks_per_seq

    # Reshape k_cache and v_cache to paged format [total_pages, num_kv_heads, page_size, head_dim]
    k_cache_paged = k_cache.view(
        batch_size, max_num_blocks_per_seq, page_size, num_kv_heads, head_dim
    )
    k_cache_paged = k_cache_paged.permute(0, 1, 3, 2, 4).contiguous()
    k_cache_paged = k_cache_paged.view(total_pages, num_kv_heads, page_size, head_dim)

    v_cache_paged = v_cache.view(
        batch_size, max_num_blocks_per_seq, page_size, num_kv_heads, head_dim
    )
    v_cache_paged = v_cache_paged.permute(0, 1, 3, 2, 4).contiguous()
    v_cache_paged = v_cache_paged.view(total_pages, num_kv_heads, page_size, head_dim)

    # Create block tables
    block_tables = torch.zeros(
        (batch_size, max_num_blocks_per_seq), dtype=torch.int32, device=device
    )
    for i in range(batch_size):
        for j in range(max_num_blocks_per_seq):
            block_tables[i, j] = i * max_num_blocks_per_seq + j

    # Create output tensor
    out = torch.zeros_like(q)

    # Create workspace buffer
    workspace_size = 128 * 1024 * 1024  # 128MB
    workspace_buffer = torch.zeros(workspace_size, dtype=torch.uint8, device=device)

    # Attention parameters
    max_seq_len = cache_seqlens.max().item()
    bmm1_scale = 1.0 / (head_dim**0.5)
    bmm2_scale = 1.0

    # Output scale factor parameters (not used for non-FP8)
    out_scale_factor = None  # Optional tensor for FP8 output scaling
    o_sf_scale = -1.0  # Output scale factor scale (disabled when -1)
    o_sf_vec_size = -1  # Output scale factor vector size (disabled when -1)
    o_sf_start_index = -1  # Output scale factor start index (disabled when -1)

    # Attention window settings
    window_left = -1  # No sliding window (disabled when -1)

    # Device settings
    sm_count = torch.cuda.get_device_properties(device).multi_processor_count

    # PDL (Programmatic Dependent Launch) settings
    enable_pdl = False

    # Attention sinks (optional)
    attention_sinks = None

    # Return tuple matching trtllm_paged_attention_decode signature:
    # void trtllm_paged_attention_decode(
    #     at::Tensor out,
    #     std::optional<at::Tensor> out_scale_factor,
    #     at::Tensor query,
    #     at::Tensor key_cache,
    #     at::Tensor value_cache,
    #     at::Tensor workspace_buffer,
    #     at::Tensor block_tables,
    #     at::Tensor seq_lens,
    #     int64_t max_kv_len,
    #     double bmm1_scale,
    #     double bmm2_scale,
    #     double o_sf_scale,
    #     int64_t o_sf_vec_size,
    #     int64_t o_sf_start_index,
    #     int64_t window_left,
    #     int64_t sm_count,
    #     bool enable_pdl,
    #     int64_t workspace_size,
    #     std::optional<at::Tensor> attention_sinks
    # )

    args = (
        out,  # out
        out_scale_factor,  # out_scale_factor (optional)
        q,  # query
        k_cache_paged,  # key_cache
        v_cache_paged,  # value_cache
        workspace_buffer,  # workspace_buffer
        block_tables,  # block_tables
        cache_seqlens,  # seq_lens
        max_seq_len,  # max_kv_len
        bmm1_scale,  # bmm1_scale
        bmm2_scale,  # bmm2_scale
        o_sf_scale,  # o_sf_scale
        o_sf_vec_size,  # o_sf_vec_size
        o_sf_start_index,  # o_sf_start_index
        window_left,  # window_left
        sm_count,  # sm_count
        enable_pdl,  # enable_pdl
        workspace_size,  # workspace_size
        attention_sinks,  # attention_sinks (optional)
    )
    return args
```
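To make the paged-layout conversion above concrete, here is a small CPU-only sketch (not part of the commit) that applies the same view/permute/view sequence to a tiny key cache and checks where one token ends up; all sizes are arbitrary illustrative values:

```python
import torch

# Tiny shapes so the layout is easy to inspect (illustrative values only).
batch_size, max_seq_len_kv, num_kv_heads, head_dim = 2, 32, 4, 8
page_size = 16
max_num_blocks_per_seq = (max_seq_len_kv + page_size - 1) // page_size  # 2
total_pages = batch_size * max_num_blocks_per_seq  # 4

k_cache = torch.randn(batch_size, max_seq_len_kv, num_kv_heads, head_dim)

# Same view/permute/view sequence as trtllm_paged_attention_decode_func.
k_cache_paged = (
    k_cache.view(batch_size, max_num_blocks_per_seq, page_size, num_kv_heads, head_dim)
    .permute(0, 1, 3, 2, 4)
    .contiguous()
    .view(total_pages, num_kv_heads, page_size, head_dim)
)

# Token t of sequence b lands in page b * max_num_blocks_per_seq + t // page_size,
# at slot t % page_size, which is exactly the mapping the sequential block table encodes.
b, t, h = 1, 20, 3
page = b * max_num_blocks_per_seq + t // page_size  # 3
slot = t % page_size  # 4
assert torch.equal(k_cache_paged[page, h, slot], k_cache[b, t, h])
```

The same mapping holds for the value cache, and the block table built in the helper simply enumerates these pages in order: block `j` of sequence `i` is page `i * max_num_blocks_per_seq + j`.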
