|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +""" |
| 8 | +TRTLLM FMHA utility functions for handling tensor conversion and kernel preparation. |
| 9 | +""" |
| 10 | + |
| 11 | +import torch |
| 12 | + |
| 13 | + |
| 14 | +def trtllm_paged_attention_decode_func(q, k_cache, v_cache, cache_seqlens): |
| 15 | + """ |
| 16 | + TRTLLM FMHA paged attention decode function that prepares inputs for the |
| 17 | + FlashInfer fmha_gen library's trtllm_paged_attention_decode kernel. |
| 18 | +
|
| 19 | + This function converts standard KV cache tensors to paged format and prepares |
| 20 | + all necessary parameters for the TRTLLM kernel. |
| 21 | +
|
| 22 | + Args: |
| 23 | + q: Query tensor [batch, seq_len_q, num_qo_heads, head_dim] |
| 24 | + k_cache: Key cache tensor [batch, max_seq_len_kv, num_kv_heads, head_dim] |
| 25 | + v_cache: Value cache tensor [batch, max_seq_len_kv, num_kv_heads, head_dim] |
| 26 | + cache_seqlens: Sequence lengths tensor [batch] |
| 27 | +
|
| 28 | + Returns: |
| 29 | + Tuple of arguments for torch.ops.fmha_gen.trtllm_paged_attention_decode: |
| 30 | + (out, out_scale_factor, query, key_cache, value_cache, workspace_buffer, |
| 31 | + block_tables, seq_lens, max_kv_len, bmm1_scale, bmm2_scale, o_sf_scale, |
| 32 | + o_sf_vec_size, o_sf_start_index, window_left, sm_count, enable_pdl, |
| 33 | + workspace_size, attention_sinks) |
| 34 | + """ |
| 35 | + |
| 36 | + device = q.device |
| 37 | + # Convert input tensors to paged format for TRTLLM FMHA |
| 38 | + batch_size, seq_len_q, num_qo_heads, head_dim = q.shape |
| 39 | + _, max_seq_len_kv, num_kv_heads, _ = k_cache.shape |
| 40 | + |
| 41 | + # Use page size of 16 for TRTLLM FMHA |
| 42 | + page_size = 16 |
| 43 | + max_num_blocks_per_seq = (max_seq_len_kv + page_size - 1) // page_size |
| 44 | + total_pages = batch_size * max_num_blocks_per_seq |
| 45 | + |
| 46 | + # Reshape k_cache and v_cache to paged format [total_pages, num_kv_heads, page_size, head_dim] |
| 47 | + k_cache_paged = k_cache.view( |
| 48 | + batch_size, max_num_blocks_per_seq, page_size, num_kv_heads, head_dim |
| 49 | + ) |
| 50 | + k_cache_paged = k_cache_paged.permute(0, 1, 3, 2, 4).contiguous() |
| 51 | + k_cache_paged = k_cache_paged.view(total_pages, num_kv_heads, page_size, head_dim) |
| 52 | + |
| 53 | + v_cache_paged = v_cache.view( |
| 54 | + batch_size, max_num_blocks_per_seq, page_size, num_kv_heads, head_dim |
| 55 | + ) |
| 56 | + v_cache_paged = v_cache_paged.permute(0, 1, 3, 2, 4).contiguous() |
| 57 | + v_cache_paged = v_cache_paged.view(total_pages, num_kv_heads, page_size, head_dim) |
| 58 | + |
| 59 | + # Create block tables |
| 60 | + block_tables = torch.zeros( |
| 61 | + (batch_size, max_num_blocks_per_seq), dtype=torch.int32, device=device |
| 62 | + ) |
| 63 | + for i in range(batch_size): |
| 64 | + for j in range(max_num_blocks_per_seq): |
| 65 | + block_tables[i, j] = i * max_num_blocks_per_seq + j |
| 66 | + |
| 67 | + # Create output tensor |
| 68 | + out = torch.zeros_like(q) |
| 69 | + |
| 70 | + # Create workspace buffer |
| 71 | + workspace_size = 128 * 1024 * 1024 # 128MB |
| 72 | + workspace_buffer = torch.zeros(workspace_size, dtype=torch.uint8, device=device) |
| 73 | + |
| 74 | + # Attention parameters |
| 75 | + max_seq_len = cache_seqlens.max().item() |
| 76 | + bmm1_scale = 1.0 / (head_dim**0.5) |
| 77 | + bmm2_scale = 1.0 |
| 78 | + |
| 79 | + # Output scale factor parameters (not used for non-FP8) |
| 80 | + out_scale_factor = None # Optional tensor for FP8 output scaling |
| 81 | + o_sf_scale = -1.0 # Output scale factor scale (disabled when -1) |
| 82 | + o_sf_vec_size = -1 # Output scale factor vector size (disabled when -1) |
| 83 | + o_sf_start_index = -1 # Output scale factor start index (disabled when -1) |
| 84 | + |
| 85 | + # Attention window settings |
| 86 | + window_left = -1 # No sliding window (disabled when -1) |
| 87 | + |
| 88 | + # Device settings |
| 89 | + sm_count = torch.cuda.get_device_properties(device).multi_processor_count |
| 90 | + |
| 91 | + # PDL (Programmatic Dependent Launch) settings |
| 92 | + enable_pdl = False |
| 93 | + |
| 94 | + # Attention sinks (optional) |
| 95 | + attention_sinks = None |
| 96 | + |
| 97 | + # Return tuple matching trtllm_paged_attention_decode signature: |
| 98 | + # void trtllm_paged_attention_decode( |
| 99 | + # at::Tensor out, |
| 100 | + # std::optional<at::Tensor> out_scale_factor, |
| 101 | + # at::Tensor query, |
| 102 | + # at::Tensor key_cache, |
| 103 | + # at::Tensor value_cache, |
| 104 | + # at::Tensor workspace_buffer, |
| 105 | + # at::Tensor block_tables, |
| 106 | + # at::Tensor seq_lens, |
| 107 | + # int64_t max_kv_len, |
| 108 | + # double bmm1_scale, |
| 109 | + # double bmm2_scale, |
| 110 | + # double o_sf_scale, |
| 111 | + # int64_t o_sf_vec_size, |
| 112 | + # int64_t o_sf_start_index, |
| 113 | + # int64_t window_left, |
| 114 | + # int64_t sm_count, |
| 115 | + # bool enable_pdl, |
| 116 | + # int64_t workspace_size, |
| 117 | + # std::optional<at::Tensor> attention_sinks |
| 118 | + # ) |
| 119 | + |
| 120 | + args = ( |
| 121 | + out, # out |
| 122 | + out_scale_factor, # out_scale_factor (optional) |
| 123 | + q, # query |
| 124 | + k_cache_paged, # key_cache |
| 125 | + v_cache_paged, # value_cache |
| 126 | + workspace_buffer, # workspace_buffer |
| 127 | + block_tables, # block_tables |
| 128 | + cache_seqlens, # seq_lens |
| 129 | + max_seq_len, # max_kv_len |
| 130 | + bmm1_scale, # bmm1_scale |
| 131 | + bmm2_scale, # bmm2_scale |
| 132 | + o_sf_scale, # o_sf_scale |
| 133 | + o_sf_vec_size, # o_sf_vec_size |
| 134 | + o_sf_start_index, # o_sf_start_index |
| 135 | + window_left, # window_left |
| 136 | + sm_count, # sm_count |
| 137 | + enable_pdl, # enable_pdl |
| 138 | + workspace_size, # workspace_size |
| 139 | + attention_sinks, # attention_sinks (optional) |
| 140 | + ) |
| 141 | + return args |
0 commit comments