diff --git a/benchmarks/bench_hopper_fp8_attention.py b/benchmarks/bench_hopper_fp8_attention.py index 89224af622..75b02024d6 100644 --- a/benchmarks/bench_hopper_fp8_attention.py +++ b/benchmarks/bench_hopper_fp8_attention.py @@ -1,3 +1,19 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + import numpy as np import torch @@ -8,39 +24,65 @@ ) -def bench_single_prefill(seq_len, num_heads, causal, head_dim): - num_qo_heads = num_kv_heads = num_heads - q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda") - k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") - v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") - - sm80_ms, sm90_ms = ( - np.median( - bench_gpu_time( - lambda: flashinfer.single_prefill_with_kv_cache_return_lse( - q, k, v, causal=causal, backend=backend - ), - dry_run_time_ms=100, - repeat_time_ms=1000, - ) - ) - for backend in ["fa2", "fa3"] +def per_head_symmetric_quant(x, quant_dtype): + """Per-head symmetric quantization to FP8.""" + o_min_val, o_max_val = ( + (-448.0, 448.0) if quant_dtype == torch.float8_e4m3fn else (-57344, 57344) + ) + x_max_val = x.abs().amax(dim=(0, 2)).to(dtype=torch.float32) + s_out = torch.clamp(x_max_val / o_max_val, min=1e-6) + s_out_broadcast = s_out.view(1, -1, 1) + q_x_out = torch.clamp(x / s_out_broadcast, min=o_min_val, max=o_max_val).to( + dtype=quant_dtype ) + return q_x_out, s_out - q = torch.randn( + +def bench_fp8_single_prefill( + seq_len, num_heads, causal, head_dim, dtype=torch.float8_e4m3fn +): + """Benchmark FP8 single prefill attention.""" + num_qo_heads = num_kv_heads = num_heads + + # Create FP16 tensors first, then quantize + q_fp16 = torch.randn( seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" - ).to(dtype=torch.float8_e4m3fn) - k = torch.randn( + ) + k_fp16 = torch.randn( seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" - ).to(dtype=torch.float8_e4m3fn) - v = torch.randn( + ) + v_fp16 = torch.randn( seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" - ).to(dtype=torch.float8_e4m3fn) + ) + + # Quantize to FP8 + q_fp8, s_q = per_head_symmetric_quant(q_fp16, dtype) + k_fp8, s_k = per_head_symmetric_quant(k_fp16, dtype) + v_fp8, s_v = per_head_symmetric_quant(v_fp16, dtype) - fp8_sm90_ms = np.median( + # FP16 baseline (fa3) + fp16_ms = np.median( bench_gpu_time( lambda: flashinfer.single_prefill_with_kv_cache_return_lse( - q, k, v, causal=causal, backend="fa3", o_dtype=torch.half + q_fp16, k_fp16, v_fp16, causal=causal, backend="fa3" + ), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ) + + # FP8 (fa3) + fp8_ms = np.median( + bench_gpu_time( + lambda: flashinfer.single_prefill_with_kv_cache_return_lse( + q_fp8, + k_fp8, + v_fp8, + causal=causal, + backend="fa3", + scale_q=s_q, + scale_k=s_k, + scale_v=s_v, ), dry_run_time_ms=100, repeat_time_ms=1000, @@ -59,7 +101,222 @@ def flops(ms): ) print( - f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, 
causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s, fa3-fp8: {flops(fp8_sm90_ms):.3f} TFLOPs/s" + f"bench_fp8_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), " + f"fp16: {flops(fp16_ms):.3f} TFLOPs/s ({fp16_ms:.3f}ms), " + f"fp8: {flops(fp8_ms):.3f} TFLOPs/s ({fp8_ms:.3f}ms), " + f"speedup: {fp16_ms / fp8_ms:.2f}x" + ) + + +def bench_fp8_batch_ragged_prefill( + batch_size, num_heads, seq_len, causal, head_dim, dtype=torch.float8_e4m3fn +): + """Benchmark FP8 batch ragged prefill attention.""" + num_qo_heads = num_kv_heads = num_heads + total_len = batch_size * seq_len + + # Create FP16 tensors first + q_fp16 = torch.randn( + total_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + ) + k_fp16 = torch.randn( + total_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + v_fp16 = torch.randn( + total_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + + # Quantize to FP8 + q_fp8, s_q = per_head_symmetric_quant(q_fp16, dtype) + k_fp8, s_k = per_head_symmetric_quant(k_fp16, dtype) + v_fp8, s_v = per_head_symmetric_quant(v_fp16, dtype) + + qo_indptr = torch.arange( + 0, total_len + 1, seq_len, dtype=torch.int32, device="cuda" + ) + kv_indptr = torch.arange( + 0, total_len + 1, seq_len, dtype=torch.int32, device="cuda" + ) + + # FP16 wrapper + fp16_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"), + kv_layout="NHD", + backend="fa3", + ) + fp16_wrapper.plan( + qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, causal=causal + ) + + # FP8 wrapper + fp8_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"), + kv_layout="NHD", + backend="fa3", + ) + fp8_wrapper.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + q_data_type=dtype, + kv_data_type=dtype, + o_data_type=torch.half, + causal=causal, + ) + + fp16_ms = np.median( + bench_gpu_time( + lambda: fp16_wrapper.run(q_fp16, k_fp16, v_fp16), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ) + + fp8_ms = np.median( + bench_gpu_time( + lambda: fp8_wrapper.run(q_fp8, k_fp8, v_fp8, s_q, s_k, s_v), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ) + + def flops(ms): + return attention_tflops_per_sec_with_actual_seq_lens( + torch.full((batch_size,), seq_len), + torch.full((batch_size,), seq_len), + head_dim, + head_dim, + num_qo_heads, + causal, + ms, + ) + + print( + f"bench_fp8_batch_ragged_prefill (batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), " + f"fp16: {flops(fp16_ms):.3f} TFLOPs/s ({fp16_ms:.3f}ms), " + f"fp8: {flops(fp8_ms):.3f} TFLOPs/s ({fp8_ms:.3f}ms), " + f"speedup: {fp16_ms / fp8_ms:.2f}x" + ) + + +def bench_fp8_batch_paged_prefill( + page_size, + batch_size, + num_heads, + seq_len, + causal, + head_dim, + dtype=torch.float8_e4m3fn, +): + """Benchmark FP8 batch paged prefill attention.""" + num_qo_heads = num_kv_heads = num_heads + total_qo_len = batch_size * seq_len + num_pages = batch_size * seq_len // page_size + + # Create FP16 tensors first + q_fp16 = torch.randn( + total_qo_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + ) + # Paged KV cache: (num_pages, page_size, num_heads, head_dim) + k_fp16 = torch.randn( + num_pages, page_size, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + v_fp16 = 
torch.randn( + num_pages, page_size, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + + # Quantize to FP8 + q_fp8, s_q = per_head_symmetric_quant(q_fp16, dtype) + # For paged KV, reshape to (total_tokens, num_heads, head_dim) for quantization + k_flat = k_fp16.view(-1, num_kv_heads, head_dim) + v_flat = v_fp16.view(-1, num_kv_heads, head_dim) + k_fp8_flat, s_k = per_head_symmetric_quant(k_flat, dtype) + v_fp8_flat, s_v = per_head_symmetric_quant(v_flat, dtype) + k_fp8 = k_fp8_flat.view(num_pages, page_size, num_kv_heads, head_dim) + v_fp8 = v_fp8_flat.view(num_pages, page_size, num_kv_heads, head_dim) + + qo_indptr = torch.arange( + 0, total_qo_len + 1, seq_len, dtype=torch.int32, device="cuda" + ) + kv_indptr = torch.arange( + 0, num_pages + 1, seq_len // page_size, dtype=torch.int32, device="cuda" + ) + kv_indices = torch.arange(0, num_pages, dtype=torch.int32, device="cuda") + last_page_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") * page_size + + # FP16 wrapper + fp16_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"), + kv_layout="NHD", + backend="fa3", + ) + fp16_wrapper.plan( + qo_indptr, + kv_indptr, + kv_indices, + last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=causal, + ) + + # FP8 wrapper + fp8_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"), + kv_layout="NHD", + backend="fa3", + ) + fp8_wrapper.plan( + qo_indptr, + kv_indptr, + kv_indices, + last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + q_data_type=dtype, + kv_data_type=dtype, + o_data_type=torch.half, + causal=causal, + ) + + fp16_ms = np.median( + bench_gpu_time( + lambda: fp16_wrapper.run(q_fp16, (k_fp16, v_fp16)), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ) + + fp8_ms = np.median( + bench_gpu_time( + lambda: fp8_wrapper.run(q_fp8, (k_fp8, v_fp8), s_q, s_k, s_v), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ) + + def flops(ms): + return attention_tflops_per_sec_with_actual_seq_lens( + torch.full((batch_size,), seq_len), + torch.full((batch_size,), seq_len), + head_dim, + head_dim, + num_qo_heads, + causal, + ms, + ) + + print( + f"bench_fp8_batch_paged_prefill (page_size={page_size}, batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), " + f"fp16: {flops(fp16_ms):.3f} TFLOPs/s ({fp16_ms:.3f}ms), " + f"fp8: {flops(fp8_ms):.3f} TFLOPs/s ({fp8_ms:.3f}ms), " + f"speedup: {fp16_ms / fp8_ms:.2f}x" ) @@ -70,8 +327,30 @@ def flops(ms): print("Current benchmark targets capability (9, 0). 
Returning...") exit() - for seq_len in [4096, 8192, 16384]: - for num_heads in [24, 32]: - for causal in [True, False]: - for head_dim in [64, 128, 256]: - bench_single_prefill(seq_len, num_heads, causal, head_dim) + # Skip single prefill for now due to compilation issues + # print("=" * 80) + # print("FP8 Single Prefill Benchmarks") + # print("=" * 80) + # for head_dim in [128, 256]: + # for seq_len in [1024, 4096, 8192]: + # bench_fp8_single_prefill(seq_len, 32, True, head_dim) + + print() + print("=" * 80) + print("FP8 Batch Ragged Prefill Benchmarks") + print("=" * 80) + for head_dim in [128, 256]: + bench_fp8_batch_ragged_prefill(128, 32, 1024, True, head_dim) + bench_fp8_batch_ragged_prefill(64, 32, 2048, True, head_dim) + bench_fp8_batch_ragged_prefill(32, 32, 4096, True, head_dim) + bench_fp8_batch_ragged_prefill(16, 32, 8192, True, head_dim) + + print() + print("=" * 80) + print("FP8 Batch Paged Prefill Benchmarks") + print("=" * 80) + for head_dim in [128, 256]: + bench_fp8_batch_paged_prefill(16, 128, 32, 1024, True, head_dim) + bench_fp8_batch_paged_prefill(16, 64, 32, 2048, True, head_dim) + bench_fp8_batch_paged_prefill(16, 32, 32, 4096, True, head_dim) + bench_fp8_batch_paged_prefill(16, 16, 32, 8192, True, head_dim) diff --git a/csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja b/csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja index 8225edbb00..317f498fd1 100644 --- a/csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja +++ b/csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja @@ -1 +1,15 @@ -// TODO: Not implemented yet +#include +#include "batch_prefill_sm90_config.inc" + +namespace flashinfer { + +{% for same_scheduler_for_all_heads in ["true", "false"] %} +template cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched + <{{ head_dim_qk }}, + {{ mask_mode }}, + /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }}, + {{ variant_name }}, RaggedParams>(RaggedParams& params, bool enable_pdl, cudaStream_t stream); +{% endfor %} + +}; // namespace flashinfer diff --git a/csrc/batch_prefill_fp8_sm90.cu b/csrc/batch_prefill_fp8_sm90.cu index 7c8680dc0b..a2ef83cd3e 100644 --- a/csrc/batch_prefill_fp8_sm90.cu +++ b/csrc/batch_prefill_fp8_sm90.cu @@ -29,6 +29,11 @@ template +cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched(Params& params, bool enable_pdl, + cudaStream_t stream); + } // namespace flashinfer using namespace flashinfer; @@ -78,7 +83,94 @@ void BatchPrefillWithRaggedKVCacheSM90Run(ffi::TensorView float_workspace_buffer int64_t window_left, bool enable_pdl // placeholder ADDITIONAL_FUNC_PARAMS) { - return; // TODO: Implement this function + PrefillPlanSM90Info plan_info; + plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); + + if (maybe_lse.has_value()) { + const auto& lse = maybe_lse.value(); + TVM_FFI_ICHECK_EQ(lse.size(0), q.size(0)); + TVM_FFI_ICHECK_EQ(lse.size(1), q.size(1)); + } + + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); + + int64_t head_dim_qk = q.size(2); + int64_t head_dim_vo = v.size(2); + + QKVLayout kv_layout = static_cast(layout); + + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); + const MaskMode mask_mode = static_cast(mask_mode_code); + bool use_swa = window_left != -1; + + DISPATCH_context( + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, + 
USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, [&] { + RaggedParams params; + + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(k.data_ptr()); + params.v_ptr = static_cast(v.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? static_cast(maybe_lse.value().data_ptr()) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) { + params.k_stride_n = k.stride(0); + params.k_stride_h = k.stride(1); + params.v_stride_n = v.stride(0); + params.v_stride_h = v.stride(1); + } else { + params.k_stride_h = k.stride(0); + params.k_stride_n = k.stride(1); + params.v_stride_h = v.stride(0); + params.v_stride_n = v.stride(1); + } + params.nnz_qo = q.size(0); + params.nnz_kv = k.size(0); + params.num_qo_heads = q.size(1); + params.num_kv_heads = k.size(1); + params.group_size = params.num_qo_heads / params.num_kv_heads; + params.window_left = window_left; + params.causal = mask_mode_code == 1; + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.qo_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_indptr_offset); + params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); + params.qo_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_len_offset); + params.kv_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); + params.head_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); + params.work_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + params.batch_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.batch_indices_offset); + + ADDITIONAL_PARAMS_SETTER + + // Not support various head_dim for now + static_assert(HEAD_DIM_QK == HEAD_DIM_VO, "head_dim_qk and head_dim_vo should be the same"); + // Currently only support same quantization precision + static_assert(std::is_same_v); + + bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads; + DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] { + cudaError_t status = + BatchFP8PrefillWithRaggedKVCacheDispatched(params, enable_pdl, + stream); + + TVM_FFI_ICHECK(status == cudaSuccess) + << "BatchPrefillWithRaggedKVCacheSM90Run failed with error: " + << cudaGetErrorString(status); + return true; + }); + }); } void BatchPrefillWithPagedKVCacheSM90Run( @@ -136,12 +228,18 @@ void BatchPrefillWithPagedKVCacheSM90Run( params.k_stride_h = paged_k_cache.stride(2); params.v_stride_n = paged_v_cache.stride(1); params.v_stride_h = paged_v_cache.stride(2); + // For sparse paged KV cache, store the stride between pages + params.k_page_stride = paged_k_cache.stride(0); + params.v_page_stride = paged_v_cache.stride(0); } else { // (num_pages, num_heads, page_size, head_dim) params.k_stride_h = paged_k_cache.stride(1); params.k_stride_n = paged_k_cache.stride(2); params.v_stride_h = paged_v_cache.stride(1); params.v_stride_n = paged_v_cache.stride(2); + // For sparse paged KV cache, store the stride between pages + params.k_page_stride = paged_k_cache.stride(0); + params.v_page_stride = paged_v_cache.stride(0); } params.nnz_qo = q.size(0); params.num_qo_heads = q.size(1); diff --git a/csrc/batch_prefill_sm90.cu b/csrc/batch_prefill_sm90.cu index 1cf78bab59..56a97b565b 100644 --- a/csrc/batch_prefill_sm90.cu +++ b/csrc/batch_prefill_sm90.cu @@ -218,13 
+218,24 @@ void BatchPrefillWithPagedKVCacheSM90Run( params.k_stride_h = paged_k_cache.stride(2); params.v_stride_n = paged_v_cache.stride(1); params.v_stride_h = paged_v_cache.stride(2); + // For sparse paged KV cache, store the stride between pages + params.k_page_stride = paged_k_cache.stride(0); + params.v_page_stride = paged_v_cache.stride(0); } else { // (num_pages, num_heads, page_size, head_dim) params.k_stride_h = paged_k_cache.stride(1); params.k_stride_n = paged_k_cache.stride(2); params.v_stride_h = paged_v_cache.stride(1); params.v_stride_n = paged_v_cache.stride(2); + // For sparse paged KV cache, store the stride between pages + params.k_page_stride = paged_k_cache.stride(0); + params.v_page_stride = paged_v_cache.stride(0); } + // Sparse mainloop assumes K and V have same strides for efficiency + TVM_FFI_ICHECK_EQ(params.k_page_stride, params.v_page_stride) + << "K and V must have same page stride for sparse attention"; + TVM_FFI_ICHECK_EQ(params.k_stride_n, params.v_stride_n) + << "K and V must have same stride_n for sparse attention"; params.nnz_qo = q.size(0); params.num_qo_heads = q.size(1); params.num_kv_heads = num_kv_heads; diff --git a/csrc/batch_prefill_sm90_customize_config.jinja b/csrc/batch_prefill_sm90_customize_config.jinja index b37ecac60d..640637c7df 100644 --- a/csrc/batch_prefill_sm90_customize_config.jinja +++ b/csrc/batch_prefill_sm90_customize_config.jinja @@ -104,6 +104,11 @@ struct PagedParams { int64_t o_stride_h; int64_t nnz_qo; + // NOTE: For sparse paged KV cache, we need the stride between pages + // This is paged_k_cache.stride(0), not the layout stride + int64_t k_page_stride; // Stride between pages for K + int64_t v_page_stride; // Stride between pages for V + int head_dim; int num_qo_heads; int num_kv_heads; diff --git a/csrc/flashinfer_page_binding.cu b/csrc/flashinfer_page_binding.cu index dbab4f5cb8..97105712f7 100644 --- a/csrc/flashinfer_page_binding.cu +++ b/csrc/flashinfer_page_binding.cu @@ -27,12 +27,5 @@ void append_paged_mla_kv_cache(TensorView append_ckv, TensorView append_kpe, TensorView kpe_cache, TensorView kv_indices, TensorView kv_indptr, TensorView kv_last_page_len); -void block_sparse_indices_to_vector_sparse_offsets( - TensorView block_sparse_indices, TensorView block_sparse_indptr, - TensorView vector_sparse_offsets, TensorView vector_sparse_indptr, TensorView kv_len_arr, - int64_t stride_block, int64_t stride_n, int64_t batch_size, int64_t block_size); - TVM_FFI_DLL_EXPORT_TYPED_FUNC(append_paged_kv_cache, append_paged_kv_cache); TVM_FFI_DLL_EXPORT_TYPED_FUNC(append_paged_mla_kv_cache, append_paged_mla_kv_cache); -TVM_FFI_DLL_EXPORT_TYPED_FUNC(block_sparse_indices_to_vector_sparse_offsets, - block_sparse_indices_to_vector_sparse_offsets); diff --git a/csrc/page.cu b/csrc/page.cu index 614fc96640..65bee2bec6 100644 --- a/csrc/page.cu +++ b/csrc/page.cu @@ -112,31 +112,6 @@ void append_paged_kv_cache(TensorView append_key, TensorView append_value, Tenso << paged_k_cache.dtype(); } -void block_sparse_indices_to_vector_sparse_offsets( - TensorView block_sparse_indices, TensorView block_sparse_indptr, - TensorView vector_sparse_offsets, TensorView vector_sparse_indptr, TensorView kv_len_arr, - int64_t stride_block, int64_t stride_n, int64_t batch_size, int64_t block_size) { - CHECK_INPUT(block_sparse_indices); - CHECK_INPUT(block_sparse_indptr); - CHECK_INPUT(vector_sparse_offsets); - CHECK_INPUT(vector_sparse_indptr); - CHECK_INPUT(kv_len_arr); - - cudaSetDevice(block_sparse_indices.device().device_id); - const cudaStream_t 
stream = get_stream(block_sparse_indices.device()); - - cudaError_t status = BlockSparseIndicesToVectorSparseOffset( - static_cast(block_sparse_indices.data_ptr()), - static_cast(block_sparse_indptr.data_ptr()), - static_cast(vector_sparse_offsets.data_ptr()), - static_cast(vector_sparse_indptr.data_ptr()), - static_cast(kv_len_arr.data_ptr()), stride_block, stride_n, batch_size, block_size, - stream); - - TVM_FFI_ICHECK(status == cudaSuccess) - << "BlockSparseIndicesToVectorSparseOffset failed with error: " << cudaGetErrorString(status); -} - void append_paged_mla_kv_cache(TensorView append_ckv, TensorView append_kpe, TensorView batch_indices, TensorView positions, TensorView ckv_cache, TensorView kpe_cache, TensorView kv_indices, TensorView kv_indptr, diff --git a/flashinfer/page.py b/flashinfer/page.py index 069303e501..b8b82792b9 100644 --- a/flashinfer/page.py +++ b/flashinfer/page.py @@ -34,42 +34,6 @@ def get_page_module(): return gen_page_module().build_and_load() -def block_sparse_indices_to_vector_sparse_offsets( - block_sparse_indices: torch.Tensor, - block_sparse_indptr: torch.Tensor, - vector_sparse_offsets: torch.Tensor, - vector_sparse_indptr: torch.Tensor, - kv_lens: torch.Tensor, - stride_block: int, - stride_n: int, - block_size: int, -) -> torch.Tensor: - if block_size == 1: - if stride_block == 1: - return block_sparse_indices - else: - return block_sparse_indices * stride_block - - assert block_sparse_indices.dtype == torch.int32 - assert block_sparse_indptr.dtype == torch.int32 - assert vector_sparse_offsets.dtype == torch.int32 - assert vector_sparse_indptr.dtype == torch.int32 - assert kv_lens.dtype == torch.int32 - batch_size = block_sparse_indptr.size(0) - 1 - get_page_module().block_sparse_indices_to_vector_sparse_offsets( - block_sparse_indices, - block_sparse_indptr, - vector_sparse_offsets, - vector_sparse_indptr, - kv_lens, - stride_block, - stride_n, - batch_size, - block_size, - ) - return vector_sparse_offsets - - @register_custom_op( "flashinfer::append_paged_mla_kv_cache", mutates_args=("ckv_cache", "kpe_cache"), diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 6b4353011f..a2a1cb73e8 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -33,7 +33,7 @@ gen_trtllm_gen_fmha_module, ) from .cudnn import cudnn_batch_prefill_with_kv_cache -from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens +from .page import get_seq_lens from .quantization import packbits, segment_packbits from .utils import ( FP4Tensor, @@ -413,7 +413,13 @@ def ragged_run( rope_scale: float, rope_theta: float, token_pos_in_items_len: int, + scale_q: Optional[torch.Tensor] = None, + scale_k: Optional[torch.Tensor] = None, + scale_v: Optional[torch.Tensor] = None, ) -> None: + # Check if FP8 by presence of scale tensors + is_fp8 = scale_q is not None + if backend == "fa2": ragged_run_func( float_workspace_buffer, @@ -439,10 +445,34 @@ def ragged_run( logits_soft_cap, sm_scale, 1.0 / rope_scale, # rope_rcp_scale - 1.0 / rope_theta, # rope_rcp_theta + 1.0 / rope_theta, # rope_rcp_theta, token_pos_in_items_len, ) + elif is_fp8: + # FA3 FP8: scale_q, scale_k, scale_v, sm_scale + ragged_run_func( + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + k, + v, + qo_indptr, + kv_indptr, + o, + maybe_lse, + mask_mode, + layout, + window_left, + enable_pdl, + scale_q, + scale_k, + scale_v, + sm_scale, + ) else: + # FA3 FP16: maybe_prefix_len_ptr, maybe_token_pos_in_items_ptr, + # maybe_max_item_len_ptr, logits_soft_cap, sm_scale, 
token_pos_in_items_len ragged_run_func( float_workspace_buffer, int_workspace_buffer, @@ -1424,16 +1454,6 @@ def __init__( * self._float_workspace_buffer.element_size() ) self.device = float_workspace_buffer.device - self._vector_sparse_indptr_buffer: Optional[torch.Tensor] = None - if backend in ["fa3", "auto", "trtllm-gen"]: - # NOTE(Zihao): assume maximum accumulate kv length is 16M - self._vector_sparse_indices_buffer = torch.empty( - (16 * 1024 * 1024,), dtype=torch.int32, device=self.device - ) - # NOTE(Zihao): assume maximum batch size is 32768 - self._vector_sparse_indptr_buffer = torch.empty( - (32768,), dtype=torch.int32, device=self.device - ) self._kv_lens_buffer = torch.empty( (32768,), dtype=torch.int32, device=self.device @@ -1543,6 +1563,7 @@ def plan( rope_theta: Optional[float] = None, q_data_type: Union[str, torch.dtype] = "float16", kv_data_type: Optional[Union[str, torch.dtype]] = None, + o_data_type: Optional[Union[str, torch.dtype]] = None, non_blocking: bool = True, prefix_len_ptr: Optional[torch.Tensor] = None, token_pos_in_items_ptr: Optional[torch.Tensor] = None, @@ -1627,6 +1648,9 @@ def plan( The data type of the query tensor, defaults torch.float16. kv_data_type : Optional[Union[str, torch.dtype]] The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`. + o_data_type : Optional[Union[str, torch.dtype]] + The data type of the output tensor. If None, will be set to :attr:`q_data_type`. + For FP8 inputs, this should typically be set to torch.float16. non_blocking : bool Whether to copy the input tensors to the device asynchronously, defaults to ``True``. prefix_len_ptr :Optional[torch.Tensor] @@ -1678,6 +1702,9 @@ def plan( if kv_data_type is None: kv_data_type = q_data_type kv_data_type = canonicalize_torch_dtype(kv_data_type) + if o_data_type is None: + o_data_type = q_data_type + o_data_type = canonicalize_torch_dtype(o_data_type) if logits_soft_cap is None: logits_soft_cap = 0.0 @@ -1808,6 +1835,7 @@ def plan( self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type + self._cached_o_data_type = o_data_type if self._jit_module is not None: self._cached_module = self._jit_module @@ -1825,7 +1853,7 @@ def plan( get_module_args = ( q_data_type, kv_data_type, - q_data_type, + o_data_type, paged_kv_indptr.dtype, head_dim_qk, head_dim_vo, @@ -1839,22 +1867,6 @@ def plan( self._backend, *get_module_args ) - if self._backend == "fa3" or self._backend == "trtllm-gen": - if page_size != 1: - vector_sparse_indptr_host = torch.cat( - [ - torch.tensor( - [0], dtype=torch.int32, device=kv_lens_arr_host.device - ), - torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32), - ], - dim=0, - ) - self._vector_sparse_indptr_buffer[ - : len(vector_sparse_indptr_host) - ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking) - paged_kv_indptr_host = vector_sparse_indptr_host - self._block_tables = block_tables if self._backend == "trtllm-gen": assert logits_soft_cap == 0.0 @@ -2042,13 +2054,10 @@ def run( q, k_cache, self._cached_q_data_type, self._cached_kv_data_type ) - stride_block = k_cache.stride(0) if self._kv_layout == "NHD": page_size = k_cache.shape[1] - stride_n = k_cache.stride(1) else: page_size = k_cache.shape[2] - stride_n = k_cache.stride(2) window_left = self._window_left if window_left is None else window_left if self._backend != "trtllm-gen": # NOTE(Siyuan): since window_left is appeared in the plan function, we need to make sure it is the same as the one in the plan function. 
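
Note: the new FP8 entry points threaded through `ragged_run` above (the `o_data_type` plan argument and the trailing `scale_q`/`scale_k`/`scale_v` run arguments) are exercised end to end by `bench_fp8_batch_ragged_prefill` at the top of this diff. A minimal usage sketch, assuming an SM90 device and the same per-head symmetric scheme as the benchmark; shapes and the workspace size are illustrative only:

    import torch
    import flashinfer

    def per_head_quant(x, dtype=torch.float8_e4m3fn):
        # same per-head symmetric scheme as the benchmark's per_head_symmetric_quant
        s = torch.clamp(x.abs().amax(dim=(0, 2)).float() / 448.0, min=1e-6)
        return (x / s.view(1, -1, 1)).clamp(-448.0, 448.0).to(dtype), s

    num_heads, head_dim, seq_len, batch_size = 32, 128, 1024, 8
    total = batch_size * seq_len
    q16 = torch.randn(total, num_heads, head_dim, dtype=torch.half, device="cuda")
    k16, v16 = torch.randn_like(q16), torch.randn_like(q16)
    (q8, s_q), (k8, s_k), (v8, s_v) = map(per_head_quant, (q16, k16, v16))

    indptr = torch.arange(0, total + 1, seq_len, dtype=torch.int32, device="cuda")
    wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"),
        kv_layout="NHD",
        backend="fa3",
    )
    wrapper.plan(
        indptr, indptr, num_heads, num_heads, head_dim,
        q_data_type=torch.float8_e4m3fn,
        kv_data_type=torch.float8_e4m3fn,
        o_data_type=torch.half,  # FP8 inputs, FP16 output
        causal=True,
    )
    out = wrapper.run(q8, k8, v8, s_q, s_k, s_v)  # scales ride through *args to the FA3 FP8 kernel
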
@@ -2081,12 +2090,15 @@ def run( ) if out is None: + # Use cached output data type if available (for FP8 attention with FP16 output) + out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype out = torch.empty( - q.shape[:-1] + v_cache.shape[-1:], dtype=q.dtype, device=q.device + q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device ) else: + out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype check_shape_dtype_device( - out, q.shape[:-1] + v_cache.shape[-1:], q.dtype, q.device, "out" + out, q.shape[:-1] + v_cache.shape[-1:], out_dtype, q.device, "out" ) # Convert NHD layout to HND for trtllm-gen backend @@ -2106,24 +2118,6 @@ def run( if self._prefix_len_ptr is not None: mask_mode = MaskMode.MULTIITEMSCORING.value - if self._backend == "fa3": - # NOTE(Zihao): we divide both stride_block and stride_n by stride_n - # because we will multiply stride_n back in the kernel - sparse_indices = block_sparse_indices_to_vector_sparse_offsets( - self._paged_kv_indices_buf, - self._paged_kv_indptr_buf, - self._vector_sparse_indices_buffer, # output - self._vector_sparse_indptr_buffer, - self._kv_lens_buffer, - stride_block // stride_n, - 1, # stride_n // stride_n - page_size, - ) - sparse_indptr = self._vector_sparse_indptr_buffer - else: - sparse_indices = self._paged_kv_indices_buf - sparse_indptr = self._paged_kv_indptr_buf - if self._backend == "cudnn": if self._seq_lens_q is not None and self._seq_lens_q.dim() == 1: self._seq_lens_q = self._seq_lens_q.reshape(self._batch_size, 1, 1, 1) @@ -2160,8 +2154,8 @@ def run( k_cache, v_cache, self._qo_indptr_buf, - sparse_indptr, - sparse_indices, + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, out, lse, @@ -2173,6 +2167,14 @@ def run( if self._jit_module is not None: run_args.extend(list(args)) else: + # Extract FP8 scale tensors from *args if q is FP8 + fp8_scale_q = None + fp8_scale_k = None + fp8_scale_v = None + if is_float8(q) and len(args) >= 3: + fp8_scale_q = args[0] + fp8_scale_k = args[1] + fp8_scale_v = args[2] run_args += [ self._custom_mask_buf, self._mask_indptr_buf, @@ -2182,9 +2184,9 @@ def run( self._max_item_len_ptr, logits_soft_cap, sm_scale, - None, # scale_q, not supported yet - None, # scale_k - None, # scale_v + fp8_scale_q, + fp8_scale_k, + fp8_scale_v, rope_scale, rope_theta, self._token_pos_in_items_len, @@ -2198,7 +2200,7 @@ def run( self._max_kv_len, self._batch_size, self._qo_indptr_buf, - self._vector_sparse_indptr_buffer, + self._paged_kv_indptr_buf, sinks, ] @@ -2513,6 +2515,7 @@ def plan( rope_theta: Optional[float] = None, q_data_type: Union[str, torch.dtype] = "float16", kv_data_type: Optional[Union[str, torch.dtype]] = None, + o_data_type: Optional[Union[str, torch.dtype]] = None, non_blocking: bool = True, prefix_len_ptr: Optional[torch.Tensor] = None, token_pos_in_items_ptr: Optional[torch.Tensor] = None, @@ -2587,6 +2590,9 @@ def plan( The data type of the query tensor, defaults to torch.float16. kv_data_type : Optional[Union[str, torch.dtype]] The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`. + o_data_type : Optional[Union[str, torch.dtype]] + The data type of the output tensor. If None, will be set to :attr:`q_data_type`. + For FP8 inputs, this should typically be set to torch.float16. non_blocking : bool Whether to copy the input tensors to the device asynchronously, defaults to ``True``. 
prefix_len_ptr :Optional[torch.Tensor] @@ -2627,6 +2633,9 @@ def plan( if kv_data_type is None: kv_data_type = q_data_type kv_data_type = canonicalize_torch_dtype(kv_data_type) + if o_data_type is None: + o_data_type = q_data_type + o_data_type = canonicalize_torch_dtype(o_data_type) if head_dim_vo is None: head_dim_vo = head_dim_qk if fixed_split_size is None: @@ -2699,6 +2708,7 @@ def plan( self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type + self._cached_o_data_type = o_data_type kv_len_arr = kv_indptr_host[1:] - kv_indptr_host[:-1] self._prefix_len_ptr = prefix_len_ptr @@ -2722,7 +2732,7 @@ def plan( get_module_args = ( q_data_type, kv_data_type, - q_data_type, + o_data_type, kv_indptr.dtype, head_dim_qk, head_dim_vo, @@ -2909,11 +2919,17 @@ def run( ) if out is None: out = torch.empty( - q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=q.device + q.shape[:-1] + v.shape[-1:], + dtype=self._cached_o_data_type, + device=q.device, ) else: check_shape_dtype_device( - out, q.shape[:-1] + v.shape[-1:], q.dtype, q.device, "out" + out, + q.shape[:-1] + v.shape[-1:], + self._cached_o_data_type, + q.device, + "out", ) if self._backend == "cutlass": out, lse = fmha_varlen( @@ -2931,7 +2947,9 @@ def run( ) return (out, lse) if return_lse else out - if is_float8(q): + # Skip FP8->FP16 conversion for FA3 backend with FP8 support + # The JIT module will handle FP8 natively + if is_float8(q) and self._backend != "fa3": logging.warning( "Our current prefill kernel implementation needs f16 input, the f8 inputs " " are casted to f16, which could result in performance degradation." @@ -2980,6 +2998,9 @@ def run( rope_theta, self._token_pos_in_items_len, ] + # For FP8, append scale tensors + if is_float8(q): + run_args.extend(list(args)) # scale_q, scale_k, scale_v assert self._cached_module is not None, "cached module is not initialized" self._cached_module.ragged_run(*run_args) diff --git a/flashinfer/sparse.py b/flashinfer/sparse.py index 37a3d444b7..e22efdb518 100644 --- a/flashinfer/sparse.py +++ b/flashinfer/sparse.py @@ -20,7 +20,6 @@ import torch from .decode import get_batch_decode_module -from .page import block_sparse_indices_to_vector_sparse_offsets from .prefill import _compute_page_mask_indptr, get_batch_prefill_module from .quantization import segment_packbits from .utils import ( @@ -133,16 +132,6 @@ def __init__( self._int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) - if backend in ["fa3", "auto"]: - # NOTE(Zihao): assume maximum accumulate kv length is 128M - # NOTE(Yilong): 128M is required by video DiT models - self._vector_sparse_indices_buffer = torch.empty( - (128 * 1024 * 1024,), dtype=torch.int32, device=self.device - ) - # NOTE(Zihao): assume maximum batch size is 32768 - self._vector_sparse_indptr_buffer = torch.empty( - (32768,), dtype=torch.int32, device=self.device - ) self._kv_lens_buffer = torch.empty( (32768,), dtype=torch.int32, device=self.device @@ -171,8 +160,6 @@ def reset_workspace_buffer( self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor, - vector_sparse_indices_buffer: Optional[torch.Tensor] = None, - vector_sparse_indptr_buffer: Optional[torch.Tensor] = None, ) -> None: r"""Reset the workspace buffer. 
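
Note: with `block_sparse_indices_to_vector_sparse_offsets` removed, the FA3 path no longer pre-expands page indices into per-token "vector sparse" offsets on the Python side; the SM90 sparse mainloop now receives the raw page indices plus `k_page_stride`/`v_page_stride`/`page_size` and computes each token's address in-kernel via a fastdiv by `page_size`. The sketch below is an illustrative PyTorch model of that arithmetic, not kernel code, and the helper name is hypothetical:

    import torch

    def per_token_kv_offsets(page_indices, kv_len, page_size, page_stride, stride_n):
        # token i lives in page page_indices[i // page_size] at slot i % page_size,
        # so its element offset is page_idx * page_stride + slot * stride_n
        tok = torch.arange(kv_len, dtype=torch.int64)
        page_iter, entry = tok // page_size, tok % page_size
        return page_indices[page_iter].to(torch.int64) * page_stride + entry * stride_n

    # e.g. an NHD cache of shape (num_pages, page_size=16, num_heads=8, head_dim=128)
    k_cache = torch.zeros(32, 16, 8, 128)
    pages = torch.tensor([7, 2, 5], dtype=torch.int32)           # pages owned by one request
    offs = per_token_kv_offsets(pages, kv_len=40, page_size=16,
                                page_stride=k_cache.stride(0),   # elements between pages
                                stride_n=k_cache.stride(1))      # elements between tokens
    # k_cache.view(-1)[offs[t] + h * k_cache.stride(2) + d] addresses token t, head h, dim d
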
@@ -197,12 +184,6 @@ def reset_workspace_buffer( pin_memory=True, ) - # Enable user-defined size - if vector_sparse_indices_buffer is not None: - self._vector_sparse_indices_buffer = vector_sparse_indices_buffer - if vector_sparse_indptr_buffer is not None: - self._vector_sparse_indptr_buffer = vector_sparse_indptr_buffer - def plan( self, indptr: torch.Tensor, @@ -438,20 +419,6 @@ def plan( kv_lens_arr_host, ) - if self._backend == "fa3": - if self.C != 1: - vector_sparse_indptr_host = torch.cat( - [ - torch.tensor([0], dtype=torch.int32), - torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32), - ], - dim=0, - ) - self._vector_sparse_indptr_buffer[ - : len(vector_sparse_indptr_host) - ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking) - kv_indptr_host = vector_sparse_indptr_host - args = [ self._float_workspace_buffer, self._int_workspace_buffer, @@ -582,9 +549,6 @@ def run( k = k.reshape(-1, self.C, *k.shape[-2:]) v = v.reshape(-1, self.C, *v.shape[-2:]) - stride_block = k.stride(0) - stride_n = k.stride(1) - if return_lse: if lse is None: lse = torch.empty( @@ -613,30 +577,6 @@ def run( scale_v = torch.ones(v.shape[1], dtype=torch.float32, device=q.device) if self._use_tensor_cores: - if self._backend == "fa3": - if ( - self._vector_sparse_indices_buffer.numel() - <= self._paged_kv_indices_buf.numel() * self.C - ): - raise ValueError( - "_vector_sparse_indices_buffer is not large enough. Please increase the size." - ) - - sparse_indices = block_sparse_indices_to_vector_sparse_offsets( - self._paged_kv_indices_buf, - self._paged_kv_indptr_buf, - self._vector_sparse_indices_buffer, # output - self._vector_sparse_indptr_buffer, - self._kv_lens_buffer, - stride_block // stride_n, - 1, # stride_n // stride_n - self.C, # block_size - ) - sparse_indptr = self._vector_sparse_indptr_buffer - else: - sparse_indices = self._paged_kv_indices_buf - sparse_indptr = self._paged_kv_indptr_buf - self._cached_module.paged_run( self._float_workspace_buffer, self._int_workspace_buffer, @@ -645,8 +585,8 @@ def run( k, v, self._qo_indptr, - sparse_indptr, - sparse_indices, + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, self._paged_kv_last_page_len, out, lse, @@ -761,13 +701,6 @@ def __init__( self._int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) - if backend in ["fa3", "auto"]: - self._vector_sparse_indices_buffer = torch.empty( - (128 * 1024 * 1024,), dtype=torch.int32, device=self.device - ) - self._vector_sparse_indptr_buffer = torch.empty( - (32768,), dtype=torch.int32, device=self.device - ) self._kv_lens_buffer = torch.empty( (32768,), dtype=torch.int32, device=self.device @@ -790,8 +723,6 @@ def reset_workspace_buffer( self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor, - vector_sparse_indices_buffer: Optional[torch.Tensor] = None, - vector_sparse_indptr_buffer: Optional[torch.Tensor] = None, ) -> None: r"""Reset the workspace buffer. 
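
Note: the per-head `scale_q`/`scale_k`/`scale_v` tensors forwarded to `run` (and the `torch.ones` fallback created above, which simply means "values already at real scale") each hold one FP32 scale per head under the symmetric scheme defined by `per_head_symmetric_quant` in the benchmark. A quick roundtrip check of that scheme, as a sketch:

    import torch

    x = torch.randn(1024, 32, 128, dtype=torch.half, device="cuda")       # (tokens, heads, head_dim)
    s = torch.clamp(x.abs().amax(dim=(0, 2)).float() / 448.0, min=1e-6)   # 448 = e4m3 max magnitude
    x_fp8 = (x / s.view(1, -1, 1)).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    x_rec = x_fp8.to(torch.float32) * s.view(1, -1, 1)                    # dequantize with the same scale
    print((x.float() - x_rec).abs().max())                                # small, bounded by e4m3 resolution
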
@@ -816,12 +747,6 @@ def reset_workspace_buffer( pin_memory=True, ) - # Enable user-defined size - if vector_sparse_indices_buffer is not None: - self._vector_sparse_indices_buffer = vector_sparse_indices_buffer - if vector_sparse_indptr_buffer is not None: - self._vector_sparse_indptr_buffer = vector_sparse_indptr_buffer - def plan( self, block_mask_map: torch.Tensor, @@ -1034,14 +959,6 @@ def _block_mask_map_to_expanded_indices( kv_lens_arr_host, ) - if self._backend == "fa3": - if self._vector_sparse_indptr_buffer.numel() <= kv_indptr.numel(): - raise ValueError( - "_vector_sparse_indptr_buffer is not large enough. Please increase the buffer size." - ) - self._vector_sparse_indptr_buffer[: len(kv_indptr)].copy_( - kv_indptr, non_blocking=non_blocking - ) args = [ self._float_workspace_buffer, self._int_workspace_buffer, @@ -1176,9 +1093,6 @@ def run( "num_kv_heads kv_len head_dim -> (num_kv_heads kv_len) 1 1 head_dim", ).contiguous() - stride_block = k.stride(0) - stride_n = k.stride(1) - if return_lse: if lse is None: lse = torch.empty( @@ -1194,30 +1108,6 @@ def run( else: check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out") - if self._backend == "fa3": - if ( - self._vector_sparse_indices_buffer.numel() - <= self._paged_kv_indices_buf.numel() - ): - raise ValueError( - "_vector_sparse_indices_buffer is not large enough. Please increase the buffer size." - ) - - sparse_indices = block_sparse_indices_to_vector_sparse_offsets( - self._paged_kv_indices_buf, - self._paged_kv_indptr_buf, - self._vector_sparse_indices_buffer, # output - self._vector_sparse_indptr_buffer, - self._kv_lens_buffer, - stride_block // stride_n, - 1, # stride_n // stride_n - 1, # block_size - ) - sparse_indptr = self._vector_sparse_indptr_buffer - else: - sparse_indices = self._paged_kv_indices_buf - sparse_indptr = self._paged_kv_indptr_buf - self._cached_module.paged_run( self._float_workspace_buffer, self._int_workspace_buffer, @@ -1226,8 +1116,8 @@ def run( k, v, self._qo_indptr, - sparse_indptr, - sparse_indices, + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, self._paged_kv_last_page_len, out, lse, diff --git a/flashinfer/triton/kernels/cascade.py b/flashinfer/triton/kernels/cascade.py index 0439dc0440..88a9450010 100644 --- a/flashinfer/triton/kernels/cascade.py +++ b/flashinfer/triton/kernels/cascade.py @@ -148,8 +148,9 @@ def variable_length_merge_states_kernel( for head_idx in tl.range(bdy): o, m, d = 0.0, -5e4, 1.0 for iter in tl.range(tl.load(indptr + pos), tl.load(indptr + pos + 1)): - s = tl.load(s_ptr + iter * num_heads + head_idx) - v = tl.load(v_ptr + (iter * num_heads + head_idx) * head_dim + tx) + iter_i64 = iter.to(tl.int64) + s = tl.load(s_ptr + iter_i64 * num_heads + head_idx) + v = tl.load(v_ptr + (iter_i64 * num_heads + head_idx) * head_dim + tx) o, m, d = state_merge(o, m, d, v, s, 1) o, m, d = state_normalize(o, m, d) tl.store(v_merged_ptr + (pos * num_heads + head_idx) * head_dim + tx, o) diff --git a/include/flashinfer/attention/hopper/default_params.cuh b/include/flashinfer/attention/hopper/default_params.cuh index f2b9d2e33e..bb2a33ac2c 100644 --- a/include/flashinfer/attention/hopper/default_params.cuh +++ b/include/flashinfer/attention/hopper/default_params.cuh @@ -154,6 +154,11 @@ struct BatchPrefillPagedParams { int64_t o_stride_h; int64_t nnz_qo; + // NOTE: For sparse paged KV cache, we need the stride between pages + // This is paged_k_cache.stride(0), not the layout stride + int64_t k_page_stride; // Stride between pages for K + int64_t 
v_page_stride; // Stride between pages for V + int num_qo_heads; int num_kv_heads; int group_size; diff --git a/include/flashinfer/attention/hopper/epilogue.cuh b/include/flashinfer/attention/hopper/epilogue.cuh index 81e43bd9a7..4950aa8be0 100644 --- a/include/flashinfer/attention/hopper/epilogue.cuh +++ b/include/flashinfer/attention/hopper/epilogue.cuh @@ -168,7 +168,7 @@ struct CollectiveEpilogue { /*id=*/static_cast(NamedBarriers::kValueEmpty)); cute::copy(smem_tiled_copy_O, tOrO_retile, tOsO); cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA - cutlass::arch::NamedBarrier::arrive(NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS, + cutlass::arch::NamedBarrier::arrive(NUM_MMA_THREADS, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.lse_ptr), epilogue_params.layout_LSE); @@ -194,11 +194,10 @@ struct CollectiveEpilogue { } } + cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + int write_warp_idx = NUM_WARPS - 1; - if (cutlass::canonical_warp_idx_sync() == write_warp_idx) { - cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS, - cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - } TiledCopyO gmem_tiled_copy_O; write_O(epilogue_params.O_ptr, gmem_tiled_copy_O, epilogue_params.layout_O, select<0, 1>(TileShape_PDV{}), sO, thread_idx, qo_tile_idx, diff --git a/include/flashinfer/attention/hopper/prefill_sm90.cuh b/include/flashinfer/attention/hopper/prefill_sm90.cuh index c9bee9466c..f1e441a53b 100644 --- a/include/flashinfer/attention/hopper/prefill_sm90.cuh +++ b/include/flashinfer/attention/hopper/prefill_sm90.cuh @@ -379,7 +379,11 @@ cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched(Params& params, params.v_ptr, get_gmem_layout(/*nnz=*/0, params.num_kv_heads, KernelTraits::HEAD_DIM_VO, params.v_stride_n, params.v_stride_h), // layout_V - params.kv_indices, params.window_left, params.additional_params}); + params.kv_indices, params.window_left, + params.k_page_stride, // Stride between pages for K + params.v_page_stride, // Stride between pages for V + static_cast(params.page_size), // Page size + params.additional_params}); typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments({ params.o_ptr, diff --git a/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh b/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh index 1b17842443..4ab9ed4032 100644 --- a/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh +++ b/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh @@ -99,10 +99,13 @@ struct FP8SparseCollectiveMainloop { DTypeQ const* Q_ptr; LayoutT layout_Q; DTypeKV const* K_ptr; - LayoutT layout_K; + int64_t k_stride_n; // Stride between consecutive KV tokens + int64_t k_page_stride; // Stride between pages DTypeKV const* V_ptr; - LayoutT layout_V; + int64_t v_stride_n; // Stride between consecutive KV tokens + int64_t v_page_stride; // Stride between pages IdType const* kv_indices; + uint32_t page_size; // Size of each page int window_left; AdditionalParams additional_params; }; @@ -110,12 +113,15 @@ struct FP8SparseCollectiveMainloop { // Device side kernel params struct Params { LayoutT layout_Q; - LayoutT layout_K; - LayoutT layout_V; TMA_Q tma_load_Q; DTypeKV* K_ptr; + int64_t k_stride_n; + int64_t k_page_stride; DTypeKV* V_ptr; + int64_t v_stride_n; + 
int64_t v_page_stride; IdType* kv_indices; + uint_fastdiv page_size; // Size of each page (as fastdiv for efficient divmod) int window_left; AdditionalParams additional_params; using DTypeKV = typename Ktraits::DTypeKV; @@ -125,15 +131,10 @@ struct FP8SparseCollectiveMainloop { Tensor mQ = make_tensor(make_gmem_ptr(args.Q_ptr), args.layout_Q); TMA_Q tma_load_Q = make_tma_copy(GmemTiledCopyQ{}, mQ, SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{}); - return {args.layout_Q, - args.layout_K, - args.layout_V, - tma_load_Q, - const_cast(args.K_ptr), - const_cast(args.V_ptr), - const_cast(args.kv_indices), - args.window_left, - args.additional_params}; + return {args.layout_Q, tma_load_Q, const_cast(args.K_ptr), + args.k_stride_n, args.k_page_stride, const_cast(args.V_ptr), + args.v_stride_n, args.v_page_stride, const_cast(args.kv_indices), + args.page_size, args.window_left, args.additional_params}; } CUTLASS_DEVICE @@ -208,43 +209,118 @@ struct FP8SparseCollectiveMainloop { constexpr int HEAD_DIM = get<2>(TileShape_QKD{}); constexpr int CTA_KV = get<1>(TileShape_QKD{}); - auto indexed_gather = BlockSparseIndexedGather(mainloop_params.kv_indices + kv_indptr); + IdType const* kv_indices_ptr = mainloop_params.kv_indices + kv_indptr; - Tensor mK = make_block_sparse_tensor( // (kv_len, D) - make_gmem_ptr(mainloop_params.K_ptr + kv_head_idx * stride<2>(mainloop_params.layout_K)), - make_shape(kv_len, HEAD_DIM), stride<0>(mainloop_params.layout_K), indexed_gather); - Tensor mV = make_block_sparse_tensor( // (kv_len, D) - make_gmem_ptr(mainloop_params.V_ptr + kv_head_idx * stride<2>(mainloop_params.layout_V)), - make_shape(kv_len, HEAD_DIM), stride<0>(mainloop_params.layout_V), indexed_gather); - - Tensor gK = local_tile(mK, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{})); // (KV, D, kv) - Tensor gV = local_tile(mV, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{})); // (KV, D, kv) - Tensor cKV = cute::make_identity_tensor(gK.shape()); + // Setup for manual K/V loading with page table + DTypeKV* k_base_ptr = mainloop_params.K_ptr; + DTypeKV* v_base_ptr = mainloop_params.V_ptr; + int64_t k_stride_n = mainloop_params.k_stride_n; + int64_t k_page_stride = mainloop_params.k_page_stride; + int64_t v_stride_n = mainloop_params.v_stride_n; + int64_t v_page_stride = mainloop_params.v_page_stride; GmemTiledCopyKV gmem_tiled_copy_kv; auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx); - Tensor tKgK = gmem_thr_copy_kv.partition_S(gK); // (CPY, CPY_KV, CPY_D, kv) - Tensor tKsK = gmem_thr_copy_kv.partition_D(sK); // (CPY, CPY_KV, CPY_D, PIPE) - Tensor tVgV = gmem_thr_copy_kv.partition_S(gV); // (CPY, CPY_KV, CPY_D, kv) - Tensor tVsV = gmem_thr_copy_kv.partition_D(sV); // (CPY, CPY_KV, CPY_D, PIPE) + // Create coordinate tensors for partitioning + Tensor cKV = cute::make_identity_tensor(make_shape(CTA_KV, HEAD_DIM)); Tensor tKVcKV = gmem_thr_copy_kv.partition_D(cKV); // (CPY, CPY_KV, CPY_D) Tensor tKVcKVGroup = flatten_1(tKVcKV); // (CPY, (CPY_KV, CPY_D)) + Tensor tKsK = gmem_thr_copy_kv.partition_D(sK); // (CPY, CPY_KV, CPY_D, PIPE) + Tensor tVsV = gmem_thr_copy_kv.partition_D(sV); // (CPY, CPY_KV, CPY_D, PIPE) - int valid_last_kv_tile_size = std::min(kv_len - kv_tile_idx * CTA_KV, CTA_KV); - auto predicate_fn = [&](auto coords) { - auto s_coords = tKVcKVGroup(_0{}, coords); - return elem_less(get<0>(s_coords), valid_last_kv_tile_size); + // FA3-style prefetch offset optimization: pre-compute page offsets and share via shuffle + // This reduces redundant page table lookups and address 
calculations + int64_t my_kv_offset[2]; // Rolling buffer: page_idx * page_stride + entry_idx * stride_n + + // Group organization based on partition strategy (same as FP16 sparse_mainloop) + // For FP8 with cp.async: AlignmentKV=16 (128bits/8bits), NUM_PRODUCER_THREADS=128 + // The simt gmem tiled copy partitions threads as: (thread_stride_M, thread_stride_K) + // where thread_stride_M = threads / (CTA_KV / AlignmentKV) for column-major + // NUM_KV_PER_ITER = number of KV elements each thread handles per iteration + // + // The tiled copy arrangement: + // - Each thread loads AlignmentKV (16) elements contiguously in the D dimension + // - Threads are spread across the (KV, D) tile + // For column-major: threads stride by (D/AlignmentKV) in the KV dimension + // D_stride = HEAD_DIM / AlignmentKV (e.g., 128/16=8 or 256/16=16) + // Thread arrangement: threads = KV_stride * D_stride + // So KV_stride = NUM_COPY_THREADS / D_stride = NUM_COPY_THREADS * AlignmentKV / HEAD_DIM + // NUM_KV_PER_ITER = CTA_KV / KV_stride = CTA_KV * HEAD_DIM / (NUM_COPY_THREADS * AlignmentKV) + static constexpr int NUM_COPY_THREADS = Ktraits::NUM_PRODUCER_THREADS; + constexpr int NUM_KV_PER_ITER = CTA_KV * HEAD_DIM / (NUM_COPY_THREADS * AlignmentKV); + constexpr int KV_STRIDE = CTA_KV / NUM_KV_PER_ITER; + constexpr int NUM_GROUPS = KV_STRIDE; + constexpr int THREADS_PER_GROUP = NUM_COPY_THREADS / NUM_GROUPS; + constexpr int NUM_ITERS_PER_GROUP = NUM_KV_PER_ITER; + + int group_id = thread_idx / THREADS_PER_GROUP; + int thread_in_group = thread_idx % THREADS_PER_GROUP; + + // Prefetch: compute page_idx * page_stride + entry_idx * stride_n + auto prefetch_kv_offset = [&](int kv_tile_idx, int64_t stride_n, int64_t page_stride, + bool use_predicate) { + int kv_base_idx = kv_tile_idx * CTA_KV; + int buf_idx = kv_tile_idx % 2; + + int kv_idx_read = kv_base_idx + group_id + thread_in_group * KV_STRIDE; + bool valid_read = + thread_in_group < NUM_ITERS_PER_GROUP && (!use_predicate || kv_idx_read < kv_len); + + if (valid_read) { + // Use divmod to find page and offset within page + uint32_t page_iter, entry_idx; + mainloop_params.page_size.divmod(kv_idx_read, page_iter, entry_idx); + IdType page_idx = kv_indices_ptr[page_iter]; + // Pre-compute: page_idx * page_stride + entry_idx * stride_n + my_kv_offset[buf_idx] = page_idx * page_stride + entry_idx * stride_n; + } else { + my_kv_offset[buf_idx] = 0; + } }; - // load last k-tile + // Load K/V with pre-computed offsets using shuffle + auto load_kv_with_prefetch = [&](DTypeKV* base_ptr, auto& tXsX, int tile_idx, int pipe_idx, + bool use_predicate) { + using Vec = AlignmentTypeKV; + constexpr int VecSize = AlignmentKV; + + int kv_base_idx = tile_idx * CTA_KV; + int buf_idx = tile_idx % 2; + + auto dst = recast(flatten(tXsX(_, _, _, pipe_idx))); + auto c = flatten(tKVcKV); + + constexpr unsigned FULL_MASK = 0xffffffff; + + // Load using FA3-style shuffle with pre-computed offsets + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(dst); ++i) { + auto coord = c(VecSize * i); + int kv_offset = get<0>(coord); + int d_idx = get<1>(coord); + int kv_idx = kv_base_idx + kv_offset; + bool guard = !use_predicate || kv_idx < kv_len; + + // Shuffle the pre-computed offset (page_idx * page_stride + entry_idx * stride_n) + int src_thread = group_id * THREADS_PER_GROUP + kv_offset / KV_STRIDE; + int64_t base_offset = __shfl_sync(FULL_MASK, my_kv_offset[buf_idx], src_thread); + + // Final address: base_ptr + base_offset + d_idx + Vec const* src_ptr = reinterpret_cast(base_ptr + base_offset + 
d_idx); + cutlass::arch::cp_async_zfill( + &dst(i), src_ptr, guard); + } + }; + + int valid_last_kv_tile_size = std::min(kv_len - kv_tile_idx * CTA_KV, CTA_KV); + + // load last k-tile with prefetch optimization // all threads are issuing as TMA is disabled { + prefetch_kv_offset(kv_tile_idx, k_stride_n, k_page_stride, true); pipeline_k.producer_acquire(smem_pipe_write); - Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) - Tensor tKsKiGroup = - flatten_1(tKsK(_, _, _, smem_pipe_write.index())); // (CPY, (CPY_KV, CPY_D)) - copy_if(gmem_tiled_copy_kv, predicate_fn, tKgKiGroup, tKsKiGroup); + load_kv_with_prefetch(k_base_ptr, tKsK, kv_tile_idx, smem_pipe_write.index(), true); pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); } @@ -264,12 +340,9 @@ struct FP8SparseCollectiveMainloop { shared_storage.barrier_O.wait((work_idx + 1) % 2); if (kv_tile_idx == swa_begin_kv_tile_idx) { - // first tile is the last tile + // first tile is the last tile, reuse kv_tile_idx prefetch for V pipeline_v.producer_acquire(smem_pipe_write); - Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) - Tensor tVsViGroup = - flatten_1(tVsV(_, _, _, smem_pipe_write.index())); // (CPY, (CPY_KV, CPY_D)) - copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup); + load_kv_with_prefetch(v_base_ptr, tVsV, kv_tile_idx, smem_pipe_write.index(), true); pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); // Transpose V @@ -282,11 +355,12 @@ struct FP8SparseCollectiveMainloop { ++smem_pipe_write; // update state, as K is loaded 1 step faster } else { // load second last k-tile and last v-tile + // Prefetch for next K tile (kv_tile_idx - 1) + prefetch_kv_offset(kv_tile_idx - 1, k_stride_n, k_page_stride, false); + + // Load V using prefetch from last K load (kv_tile_idx) pipeline_v.producer_acquire(smem_pipe_write); - Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) - Tensor tVsViGroup = - flatten_1(tVsV(_, _, _, smem_pipe_write.index())); // (CPY, (CPY_KV, CPY_D)) - copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup); + load_kv_with_prefetch(v_base_ptr, tVsV, kv_tile_idx, smem_pipe_write.index(), true); pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); // Transpose V @@ -298,10 +372,9 @@ struct FP8SparseCollectiveMainloop { ++smem_pipe_read; ++smem_pipe_write; // update state, as K is loaded 1 step faster + // Load K (kv_tile_idx - 1) using prefetched offset pipeline_k.producer_acquire(smem_pipe_write); - Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) - Tensor tKsKi = tKsK(_, _, _, smem_pipe_write.index()); // (CPY, CPY_KV, CPY_D) - copy(gmem_tiled_copy_kv, tKgKi, tKsKi); + load_kv_with_prefetch(k_base_ptr, tKsK, kv_tile_idx - 1, smem_pipe_write.index(), false); pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); --kv_tile_idx; @@ -309,10 +382,12 @@ struct FP8SparseCollectiveMainloop { // load remaining k/v tiles #pragma unroll 2 for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) { + // Prefetch for next K tile + prefetch_kv_offset(kv_tile_idx - 1, k_stride_n, k_page_stride, false); + + // Load V using prefetch from previous K prefetch pipeline_v.producer_acquire(smem_pipe_write); - Tensor tVgVi = tVgV(_, _, _, kv_tile_idx); // (CPY, CPY_KV, CPY_D) - Tensor tVsVi = tVsV(_, _, _, smem_pipe_write.index()); // (CPY, CPY_KV, CPY_D) - 
copy(gmem_tiled_copy_kv, tVgVi, tVsVi); + load_kv_with_prefetch(v_base_ptr, tVsV, kv_tile_idx, smem_pipe_write.index(), false); pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); // Transpose V @@ -324,20 +399,18 @@ struct FP8SparseCollectiveMainloop { ++smem_pipe_read; ++smem_pipe_write; // update state, as K is loaded 1 step faster + // Load K using prefetched offset pipeline_k.producer_acquire(smem_pipe_write); - Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) - Tensor tKsKi = tKsK(_, _, _, smem_pipe_write.index()); // (CPY, CPY_KV, CPY_D) - copy(gmem_tiled_copy_kv, tKgKi, tKsKi); + load_kv_with_prefetch(k_base_ptr, tKsK, kv_tile_idx - 1, smem_pipe_write.index(), false); pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); } scheduler.prefetch_next_work(scheduler_params, work_tile_info); // load first v tile { + prefetch_kv_offset(0, v_stride_n, v_page_stride, false); pipeline_v.producer_acquire(smem_pipe_write); - Tensor tVgVi = tVgV(_, _, _, 0); // (CPY, (CPY_KV, CPY_D)) - Tensor tVsVi = tVsV(_, _, _, smem_pipe_write.index()); // (CPY, (CPY_KV, CPY_D)) - copy(gmem_tiled_copy_kv, tVgVi, tVsVi); + load_kv_with_prefetch(v_base_ptr, tVsV, 0, smem_pipe_write.index(), false); pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); // Transpose V diff --git a/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh b/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh index 24e416b61b..b6bde27e3c 100644 --- a/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh +++ b/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh @@ -334,13 +334,14 @@ cudaError_t BatchFP8PrefillWithPagedKVCacheKernelTraitsDispatched(Params& params params.q_stride_n, params.q_stride_h), // layout_Q params.k_ptr, - // NOTE(Zihao): nnz was useless here, we can just pass 0 - get_gmem_layout(/*nnz=*/0, params.num_kv_heads, KernelTraits::HEAD_DIM, params.k_stride_n, - params.k_stride_h), // layout_K + params.k_stride_n, // k_stride_n + params.k_page_stride, // k_page_stride params.v_ptr, - get_gmem_layout(/*nnz=*/0, params.num_kv_heads, KernelTraits::HEAD_DIM, params.v_stride_n, - params.v_stride_h), // layout_V - params.kv_indices, params.window_left, params.additional_params}); + params.v_stride_n, // v_stride_n + params.v_page_stride, // v_page_stride + params.kv_indices, + static_cast(params.page_size), // page_size + params.window_left, params.additional_params}); typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments({ params.o_ptr, @@ -458,9 +459,127 @@ cudaError_t BatchFP8PrefillWithPagedKVCacheDispatched(Params& params, bool enabl LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); } else { // HEAD_DIM == 256; - // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 256, need to optimize later + // NOTE: Use smaller CTA_KV=64 for sparse paged loading to reduce page table lookup overhead + // (FP8 transpose requires minimum 64x64 blocks, so CTA_KV cannot be smaller than 64) BatchFP8PrefillWithPagedKVCacheKernelTraitsDispatched< FP8AttentionKernelTraits, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + } + cudaError_t status = cudaGetLastError(); + return status; +}; + +template +cudaError_t BatchFP8PrefillWithRaggedKVCacheKernelTraitsDispatched(Params& params, + cudaStream_t stream) { + using DTypeQ = typename KernelTraits::DTypeQ; + using DTypeKV = typename 
KernelTraits::DTypeKV; + using DTypeO = typename KernelTraits::DTypeO; + using IdType = typename KernelTraits::IdType; + + using CollectiveMainloop = + FP8CollectiveMainloop; + using CollectiveEpilogue = FP8CollectiveEpilogue; + using Scheduler = + std::conditional_t, + BatchPrefillPersistentTileScheduler>; + typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments( + {params.q_ptr, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, KernelTraits::HEAD_DIM, + params.q_stride_n, + params.q_stride_h), // layout_Q + params.k_ptr, + // NOTE(Zihao): nnz was useless here, we can just pass 0 + get_gmem_layout(params.nnz_kv, params.num_kv_heads, KernelTraits::HEAD_DIM, + params.k_stride_n, + params.k_stride_h), // layout_K + params.v_ptr, + get_gmem_layout(params.nnz_kv, params.num_kv_heads, KernelTraits::HEAD_DIM, + params.v_stride_n, + params.v_stride_h), // layout_V + params.window_left, params.additional_params}); + typename CollectiveEpilogue::Params epilogue_params = + CollectiveEpilogue::to_underlying_arguments({ + params.o_ptr, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, KernelTraits::HEAD_DIM, + params.o_stride_n, + params.o_stride_h), // layout_O + params.lse_ptr, get_lse_gmem_layout(params.nnz_qo, params.num_qo_heads), // layout_LSE + }); + + // NOTE(Zihao): add support for kv head-major later + typename Scheduler::Arguments scheduler_args = { + params.work_indptr, + params.head_indices, + params.qo_tile_indices, + params.qo_indptr, + params.kv_indptr, + params.qo_lens, + params.kv_lens, + params.batch_indices, + cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads), + params.num_qo_heads}; + typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); + + // Get the ptr to kernel function. + auto kernel = + (void*)FP8PrefillWithKVCacheKernel; + int smem_size = sizeof(typename KernelTraits::SharedStorage); + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + int device; + cudaGetDevice(&device); + int multiprocessor_count; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device)); + dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count); + static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32; + dim3 block_dims(ctaSize); + void* args[] = {&mainloop_params, &epilogue_params, &scheduler_params}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel(kernel, grid_dims, block_dims, args, smem_size, stream)); + + return cudaSuccess; +} + +template +cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched(Params& params, bool enable_pdl, + cudaStream_t stream) { + static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256); + if (MASK_MODE == MaskMode::kCustom) { + return cudaErrorNotSupported; // Not supported yet. 
+ } + constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; + if constexpr (HEAD_DIM == 64) { + BatchFP8PrefillWithRaggedKVCacheKernelTraitsDispatched< + FP8AttentionKernelTraits, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + } else if constexpr (HEAD_DIM == 128) { + BatchFP8PrefillWithRaggedKVCacheKernelTraitsDispatched< + FP8AttentionKernelTraits, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + } else { + // HEAD_DIM == 256; + BatchFP8PrefillWithRaggedKVCacheKernelTraitsDispatched< + FP8AttentionKernelTraits #include +#include "../../fastdiv.cuh" #include "../../math.cuh" -#include "block_sparse_gather.cuh" #include "cute/tensor.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/pipeline/pipeline.hpp" @@ -107,6 +107,9 @@ struct SparseCollectiveMainloop { LayoutT layout_V; IdType const* kv_indices; int window_left; + int64_t k_page_stride; // Stride between pages for K (paged_k.stride(0)) + int64_t v_page_stride; // Stride between pages for V (paged_v.stride(0)) + uint32_t page_size; // Size of each page AdditionalParams additional_params; }; @@ -120,6 +123,9 @@ struct SparseCollectiveMainloop { DTypeKV* V_ptr; IdType* kv_indices; int window_left; + int64_t k_page_stride; // Stride between pages for K + int64_t v_page_stride; // Stride between pages for V + uint_fastdiv page_size; // Size of each page (as fastdiv for efficient divmod) AdditionalParams additional_params; }; @@ -135,6 +141,9 @@ struct SparseCollectiveMainloop { const_cast(args.V_ptr), const_cast(args.kv_indices), args.window_left, + args.k_page_stride, // Use stride from arguments + args.v_page_stride, // Use stride from arguments + uint_fastdiv(args.page_size), // Convert page_size to fastdiv args.additional_params}; } @@ -203,45 +212,47 @@ struct SparseCollectiveMainloop { constexpr int HEAD_DIM_QK = get<2>(TileShape_QKD{}); constexpr int HEAD_DIM_VO = get<1>(TileShape_PDV{}); constexpr int CTA_KV = get<1>(TileShape_QKD{}); - auto indexed_gather = BlockSparseIndexedGather(mainloop_params.kv_indices + kv_indptr); - - Tensor mK = make_block_sparse_tensor( // (kv_len, D_K) - make_gmem_ptr(mainloop_params.K_ptr + kv_head_idx * stride<2>(mainloop_params.layout_K)), - make_shape(kv_len, HEAD_DIM_QK), stride<0>(mainloop_params.layout_K), indexed_gather); - Tensor mV = make_block_sparse_tensor( // (kv_len, D_V) - make_gmem_ptr(mainloop_params.V_ptr + kv_head_idx * stride<2>(mainloop_params.layout_V)), - make_shape(kv_len, HEAD_DIM_VO), stride<0>(mainloop_params.layout_V), indexed_gather); + // Store base pointers and indices for manual page table lookup + DTypeKV* K_ptr_base = mainloop_params.K_ptr + kv_head_idx * stride<2>(mainloop_params.layout_K); + DTypeKV* V_ptr_base = mainloop_params.V_ptr + kv_head_idx * stride<2>(mainloop_params.layout_V); + IdType const* kv_indices_ptr = mainloop_params.kv_indices + kv_indptr; + // Use the page stride (stride between pages) and stride within page + int64_t k_page_stride = mainloop_params.k_page_stride; + int64_t v_page_stride = mainloop_params.v_page_stride; + int64_t k_stride_n = + stride<0>(mainloop_params.layout_K); // Stride within page (between tokens) + int64_t v_stride_n = stride<0>(mainloop_params.layout_V); + + // Create dummy tensors for partitioning with contiguous column-major layout + // NOTE: We use a virtual contiguous layout for correct partitioning, + // actual addressing uses page table lookup Tensor gK = - local_tile(mK, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{})); 
// (KV, D_K, kv) + make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(CTA_KV, HEAD_DIM_QK), + make_stride(HEAD_DIM_QK, _1{})); // Column-major: (KV, D) + Tensor gK_tiled = + local_tile(gK, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{})); // (KV, D_K, kv) Tensor gV = - local_tile(mV, select<2, 1>(TileShape_PDV{}), make_coord(_, _0{})); // (KV, D_V, kv) - Tensor cK = cute::make_identity_tensor(gK.shape()); - Tensor cV = cute::make_identity_tensor(gV.shape()); + make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(CTA_KV, HEAD_DIM_VO), + make_stride(HEAD_DIM_VO, _1{})); // Column-major: (KV, D) + Tensor gV_tiled = + local_tile(gV, select<2, 1>(TileShape_PDV{}), make_coord(_, _0{})); // (KV, D_V, kv) + Tensor cK = cute::make_identity_tensor(gK_tiled.shape()); + Tensor cV = cute::make_identity_tensor(gV_tiled.shape()); GmemTiledCopyK gmem_tiled_copy_k; GmemTiledCopyV gmem_tiled_copy_v; auto gmem_thr_copy_k = gmem_tiled_copy_k.get_slice(thread_idx); auto gmem_thr_copy_v = gmem_tiled_copy_v.get_slice(thread_idx); - Tensor tKgK = gmem_thr_copy_k.partition_S(gK); // (CPY, CPY_KV, CPY_D, kv) - Tensor tKsK = gmem_thr_copy_k.partition_D(sK); // (CPY, CPY_KV, CPY_D, PIPE) - Tensor tVgV = gmem_thr_copy_v.partition_S(gV); // (CPY, CPY_KV, CPY_D, kv) - Tensor tVsV = gmem_thr_copy_v.partition_D(sV); // (CPY, CPY_KV, CPY_D, PIPE) - Tensor tKcK = gmem_thr_copy_k.partition_D(cK); // (CPY, CPY_KV, CPY_D) - Tensor tKcKGroup = flatten_1(tKcK); // (CPY, (CPY_KV, CPY_D)) - Tensor tVcV = gmem_thr_copy_v.partition_D(cV); // (CPY, CPY_KV, CPY_D) - Tensor tVcVGroup = flatten_1(tVcV); // (CPY, (CPY_KV, CPY_D)) + Tensor tKgK = gmem_thr_copy_k.partition_S(gK_tiled); // (CPY, CPY_KV, CPY_D, kv) + Tensor tKsK = gmem_thr_copy_k.partition_D(sK); // (CPY, CPY_KV, CPY_D, PIPE) + Tensor tVgV = gmem_thr_copy_v.partition_S(gV_tiled); // (CPY, CPY_KV, CPY_D, kv) + Tensor tVsV = gmem_thr_copy_v.partition_D(sV); // (CPY, CPY_KV, CPY_D, PIPE) + Tensor tKcK = gmem_thr_copy_k.partition_D(cK); // (CPY, CPY_KV, CPY_D, kv) + Tensor tVcV = gmem_thr_copy_v.partition_D(cV); // (CPY, CPY_KV, CPY_D, kv) int valid_last_kv_tile_size = std::min(kv_len - kv_tile_idx * CTA_KV, CTA_KV); - auto k_predicate_fn = [&](auto coords) { - auto s_coords = tKcKGroup(_0{}, coords); - return elem_less(get<0>(s_coords), valid_last_kv_tile_size); - }; - auto v_predicate_fn = [&](auto coords) { - auto s_coords = tVcVGroup(_0{}, coords); - return elem_less(get<0>(s_coords), valid_last_kv_tile_size); - }; auto kv_tile_idx_decrement = [&](int kv_tile_idx) { int result = kv_tile_idx - 1; if constexpr (MULTIITEMSCORING) { @@ -253,14 +264,81 @@ struct SparseCollectiveMainloop { return result; }; + // FA3-style cooperative loading: store pre-computed base offset for each KV position + int64_t my_kv_offset[2]; // Rolling buffer: page_idx * page_stride + entry_idx * stride_n + + // Group organization based on partition strategy + constexpr int NUM_KV_PER_ITER = decltype(size<1>(tKcK))::value; // e.g., 12 + constexpr int KV_STRIDE = CTA_KV / NUM_KV_PER_ITER; // 96/12 = 8 + constexpr int NUM_GROUPS = KV_STRIDE; // 8 groups (one per lane) + constexpr int THREADS_PER_GROUP = NUM_COPY_THREADS / NUM_GROUPS; // 128/8 = 16 + constexpr int NUM_ITERS_PER_GROUP = NUM_KV_PER_ITER; // 12 iterations per group + + int group_id = thread_idx / THREADS_PER_GROUP; // 0-7 + int thread_in_group = thread_idx % THREADS_PER_GROUP; // 0-15 + + // Prefetch: compute page_idx * page_stride + entry_idx * stride_n + // NOTE: Assumes K and V have same strides (asserted on host side) 
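+    // Worked example of the layout above (CTA_KV = 96, 128 copy threads, KV_STRIDE = 8,
+    // i.e. 8 groups of 16 threads): prefetch_kv_offset (below) has the thread with
+    // (group_id, thread_in_group) = (g, j), j < NUM_ITERS_PER_GROUP, cache the gather
+    // offset of tile row g + 8 * j, so the offset of any row r with r % 8 == g can later
+    // be recovered by a __shfl_sync from thread g * 16 + r / 8 in load_kv_with_gather.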
+ auto prefetch_kv_offset = [&](int kv_tile_idx, bool use_predicate) { + int kv_base_idx = kv_tile_idx * CTA_KV; + int buf_idx = kv_tile_idx % 2; + + int kv_idx_read = kv_base_idx + group_id + thread_in_group * KV_STRIDE; + bool valid_read = + thread_in_group < NUM_ITERS_PER_GROUP && (!use_predicate || kv_idx_read < kv_len); + + if (valid_read) { + // Use divmod to find page and offset within page + uint32_t page_iter, entry_idx; + mainloop_params.page_size.divmod(kv_idx_read, page_iter, entry_idx); + IdType page_idx = kv_indices_ptr[page_iter]; + // Pre-compute: page_idx * page_stride + entry_idx * stride_n + my_kv_offset[buf_idx] = page_idx * k_page_stride + entry_idx * k_stride_n; + } else { + my_kv_offset[buf_idx] = 0; + } + }; + + // Unified helper lambda to load K or V with pre-computed offsets + auto load_kv_with_gather = [&](auto&& tXsX, auto&& tXcX, DTypeKV* base_ptr, int kv_tile_idx, + int stage_idx, bool use_predicate) { + using Vec = AlignmentTypeKV; + constexpr int VecSize = sizeof(Vec) / sizeof(DTypeKV); + + int kv_base_idx = kv_tile_idx * CTA_KV; + int buf_idx = kv_tile_idx % 2; + + auto dst = recast(flatten(tXsX(_, _, _, stage_idx))); + auto c = flatten(tXcX(_, _, _, kv_tile_idx)); + + constexpr unsigned FULL_MASK = 0xffffffff; + + // Load using FA3-style shuffle with pre-computed offsets + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(dst); ++i) { + auto coord = c(VecSize * i); + int kv_offset = get<0>(coord); + int d_idx = get<1>(coord); + int kv_idx = kv_base_idx + kv_offset; + bool guard = !use_predicate || kv_idx < kv_len; + + // Shuffle the pre-computed offset (page_idx * page_stride + entry_idx * stride_n) + int src_thread = group_id * THREADS_PER_GROUP + kv_offset / KV_STRIDE; + int64_t base_offset = __shfl_sync(FULL_MASK, my_kv_offset[buf_idx], src_thread); + + // Final address: base_ptr + base_offset + d_idx + // where base_offset = page_idx * page_stride + entry_idx * stride_n + Vec const* src_ptr = reinterpret_cast(base_ptr + base_offset + d_idx); + cutlass::arch::cp_async_zfill( + &dst(i), src_ptr, guard); + } + }; + // load last k-tile { + prefetch_kv_offset(kv_tile_idx, true); pipeline_k.producer_acquire(smem_pipe_write_k); - Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) - Tensor tKsKiGroup = - flatten_1(tKsK(_, _, _, smem_pipe_write_k.index())); // (CPY, (CPY_KV, CPY_D)) - copy_if(gmem_tiled_copy_k, k_predicate_fn, tKgKiGroup, tKsKiGroup); - + load_kv_with_gather(tKsK, tKcK, K_ptr_base, kv_tile_idx, smem_pipe_write_k.index(), true); pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); ++smem_pipe_write_k; } @@ -284,30 +362,23 @@ struct SparseCollectiveMainloop { shared_storage.barrier_O.wait((work_idx + 1) % 2); if (kv_tile_idx == swa_begin_kv_tile_idx) { + // kv_tile_idx already prefetched above, reuse it for V pipeline_v.producer_acquire(smem_pipe_write_v); - Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) - Tensor tVsViGroup = - flatten_1(tVsV(_, _, _, smem_pipe_write_v.index())); // (CPY, (CPY_KV, CPY_D)) - copy_if(gmem_tiled_copy_v, v_predicate_fn, tVgViGroup, tVsViGroup); - + load_kv_with_gather(tVsV, tVcV, V_ptr_base, kv_tile_idx, smem_pipe_write_v.index(), true); pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); ++smem_pipe_write_v; } else { // load second last k-tile and last v-tile + int kv_tile_k = kv_tile_idx_decrement(kv_tile_idx); + prefetch_kv_offset(kv_tile_k, false); 
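+      // The divmod and kv_indices lookup for the next K tile are issued before
+      // producer_acquire so the page-table read can overlap with waiting on the stage;
+      // the my_kv_offset entry filled earlier for kv_tile_idx is kept and reused by the
+      // V load below.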
pipeline_k.producer_acquire(smem_pipe_write_k); - Tensor tKgKi = tKgK(_, _, _, kv_tile_idx_decrement(kv_tile_idx)); // (CPY, CPY_KV, CPY_D) - Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) - copy(gmem_tiled_copy_k, tKgKi, tKsKi); - + load_kv_with_gather(tKsK, tKcK, K_ptr_base, kv_tile_k, smem_pipe_write_k.index(), false); pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); ++smem_pipe_write_k; + // kv_tile_idx already prefetched above, reuse it for V pipeline_v.producer_acquire(smem_pipe_write_v); - Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) - Tensor tVsViGroup = - flatten_1(tVsV(_, _, _, smem_pipe_write_v.index())); // (CPY, (CPY_KV, CPY_D)) - copy_if(gmem_tiled_copy_v, v_predicate_fn, tVgViGroup, tVsViGroup); - + load_kv_with_gather(tVsV, tVcV, V_ptr_base, kv_tile_idx, smem_pipe_write_v.index(), true); pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx); ++smem_pipe_write_v; @@ -316,20 +387,16 @@ struct SparseCollectiveMainloop { #pragma unroll 2 for (; kv_tile_idx > swa_begin_kv_tile_idx; kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx)) { + int kv_tile_k = kv_tile_idx_decrement(kv_tile_idx); + prefetch_kv_offset(kv_tile_k, false); pipeline_k.producer_acquire(smem_pipe_write_k); - - Tensor tKgKi = tKgK(_, _, _, kv_tile_idx_decrement(kv_tile_idx)); // (CPY, CPY_KV, CPY_D) - Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) - copy(gmem_tiled_copy_k, tKgKi, tKsKi); - + load_kv_with_gather(tKsK, tKcK, K_ptr_base, kv_tile_k, smem_pipe_write_k.index(), false); pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); ++smem_pipe_write_k; + // kv_tile_idx already prefetched in previous iteration, reuse pipeline_v.producer_acquire(smem_pipe_write_v); - Tensor tVgVi = tVgV(_, _, _, kv_tile_idx); // (CPY, CPY_KV, CPY_D) - Tensor tVsVi = tVsV(_, _, _, smem_pipe_write_v.index()); // (CPY, CPY_KV, CPY_D) - copy(gmem_tiled_copy_v, tVgVi, tVsVi); - + load_kv_with_gather(tVsV, tVcV, V_ptr_base, kv_tile_idx, smem_pipe_write_v.index(), false); pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); ++smem_pipe_write_v; } @@ -337,10 +404,9 @@ struct SparseCollectiveMainloop { // load first v tile { + prefetch_kv_offset(0, false); pipeline_v.producer_acquire(smem_pipe_write_v); - Tensor tVgVi = tVgV(_, _, _, 0); // (CPY, (CPY_KV, CPY_D)) - Tensor tVsVi = tVsV(_, _, _, smem_pipe_write_v.index()); // (CPY, (CPY_KV, CPY_D)) - copy(gmem_tiled_copy_v, tVgVi, tVsVi); + load_kv_with_gather(tVsV, tVcV, V_ptr_base, 0, smem_pipe_write_v.index(), false); pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); ++smem_pipe_write_v; } diff --git a/include/flashinfer/page.cuh b/include/flashinfer/page.cuh index 1f5d328da8..efc224b4e1 100644 --- a/include/flashinfer/page.cuh +++ b/include/flashinfer/page.cuh @@ -283,55 +283,6 @@ __global__ void AppendPagedKVCacheKernel(paged_kv_t paged_kv, } } -template -__global__ void BlockSparseIndicesToVectorSparseOffsetsKernel( - IdType* __restrict__ block_sparse_indices, IdType* __restrict__ block_sparse_indptr, - IdType* __restrict__ vector_sparse_offsets, IdType* __restrict__ vector_sparse_indptr, - IdType* __restrict__ kv_lens, const uint32_t stride_block, const uint32_t stride_n, - const uint32_t batch_size, const uint_fastdiv block_size) { -#pragma unroll 1 - for (int b = 
blockIdx.x; b < batch_size; ++b) { -#pragma unroll 2 - for (int pos = threadIdx.x; pos < kv_lens[b]; pos += blockDim.x) { - uint32_t q, r; - block_size.divmod(pos, q, r); - vector_sparse_offsets[vector_sparse_indptr[b] + pos] = - block_sparse_indices[block_sparse_indptr[b] + q] * stride_block + r * stride_n; - } - } -} - -template -cudaError_t BlockSparseIndicesToVectorSparseOffset( - IdType* block_sparse_indices, IdType* block_sparse_indptr, IdType* vector_sparse_offsets, - IdType* vector_sparse_indptr, IdType* kv_lens, const int64_t stride_block, - const int64_t stride_n, const int64_t batch_size, const uint32_t block_size, - cudaStream_t stream = nullptr) { - int dev_id = 0; - int num_sms = 0; - FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); - - uint32_t num_threads = 512; - - uint_fastdiv block_size_fastdiv(block_size); - - auto kernel = BlockSparseIndicesToVectorSparseOffsetsKernel; - void* args[] = {(void*)&block_sparse_indices, - (void*)&block_sparse_indptr, - (void*)&vector_sparse_offsets, - (void*)&vector_sparse_indptr, - (void*)&kv_lens, - (void*)&stride_block, - (void*)&stride_n, - (void*)&batch_size, - (void*)&block_size_fastdiv}; - - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, num_sms, num_threads, args, 0, stream)); - - return cudaSuccess; -} - /*! * \brief Append new keys/values to the paged key-value cache in the decode phase * \tparam DType The data type of the key-value cache diff --git a/scripts/task_jit_run_tests_part2.sh b/scripts/task_jit_run_tests_part2.sh index b4bb6bf17c..c85d9d7e28 100755 --- a/scripts/task_jit_run_tests_part2.sh +++ b/scripts/task_jit_run_tests_part2.sh @@ -11,10 +11,10 @@ if [ "$SKIP_INSTALL" = "0" ]; then fi # Run each test file separately to isolate CUDA memory issues -pytest -s tests/utils/test_block_sparse.py pytest -s tests/utils/test_jit_example.py pytest -s tests/utils/test_jit_warmup.py pytest -s tests/utils/test_norm.py +pytest -s tests/attention/test_block_sparse.py pytest -s tests/attention/test_rope.py pytest -s tests/attention/test_mla_page.py pytest -s tests/utils/test_quantization.py diff --git a/tests/attention/test_batch_prefill_kernels.py b/tests/attention/test_batch_prefill_kernels.py index f067a70c62..0242c28ea9 100644 --- a/tests/attention/test_batch_prefill_kernels.py +++ b/tests/attention/test_batch_prefill_kernels.py @@ -144,13 +144,17 @@ def test_batch_prefill_with_paged_kv_cache( logits_soft_cap=logits_soft_cap, ) if return_lse: - o, _ = wrapper.run(q, kv_data, return_lse=True) + o, lse = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) # test with pre-allocated output o_buffer = torch.empty_like(o) - wrapper.run(q, kv_data, out=o_buffer) + if return_lse: + lse_buffer = torch.empty_like(lse) + wrapper.run(q, kv_data, out=o_buffer, lse=lse_buffer, return_lse=True) + else: + wrapper.run(q, kv_data, out=o_buffer) torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3) else: q_indptr_buffer = torch.empty( diff --git a/tests/utils/test_block_sparse.py b/tests/attention/test_block_sparse.py similarity index 100% rename from tests/utils/test_block_sparse.py rename to tests/attention/test_block_sparse.py diff --git a/tests/attention/test_hopper.py b/tests/attention/test_hopper.py index 0a1b6fe8a7..22be2f8ea6 100644 --- a/tests/attention/test_hopper.py +++ b/tests/attention/test_hopper.py @@ -207,7 +207,7 @@ def test_deepseek_prefill( @pytest.mark.parametrize("batch_size", [1, 4, 8, 16]) 
@pytest.mark.parametrize("seq_len", [11, 12, 99, 1763, 9999, 32767]) -@pytest.mark.parametrize("page_size", [1]) # [1, 16]) +@pytest.mark.parametrize("page_size", [1, 16]) @pytest.mark.parametrize("num_qo_heads", [1, 4, 8]) @pytest.mark.parametrize("num_kv_heads", [1, 4, 8]) @pytest.mark.parametrize("causal", [False, True]) @@ -267,8 +267,7 @@ def test_batch_paged_prefill( kv_indptr = torch.arange( 0, batch_size * num_pages_per_request + 1, num_pages_per_request ).int() - # NOTE(Zihao): pad 256 elements to avoid out-of-bound because we didn't check the boundary in the kernel - kv_indices = torch.arange(0, batch_size * num_pages_per_request + 256).int() + kv_indices = torch.arange(0, batch_size * num_pages_per_request).int() last_page_len = torch.full((batch_size,), last_page_len, dtype=torch.int32) wrapper_sm80.plan( diff --git a/tests/attention/test_hopper_fp8_attention.py b/tests/attention/test_hopper_fp8_attention.py index 35f102b21c..565af428d8 100644 --- a/tests/attention/test_hopper_fp8_attention.py +++ b/tests/attention/test_hopper_fp8_attention.py @@ -183,7 +183,248 @@ def test_block_sparse_attention( assert mse < 1.0, f"Block sparse MSE too high: {mse.item()}" +# Test batch prefill with ragged KV cache: MSE should be below threshold +@pytest.mark.parametrize("batch_size", [2, 4]) +@pytest.mark.parametrize("num_heads", [8, 32]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +def test_batch_prefill_ragged(batch_size, num_heads, head_dim, causal, dtype): + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("SM90A is not supported") + + print( + f"Testing FP8 batch prefill ragged with batch_size={batch_size}, num_heads={num_heads}, " + f"head_dim={head_dim}, causal={causal}, dtype={dtype}" + ) + + # Setup + o_dtype = torch.half + num_qo_heads = num_kv_heads = num_heads + + # Create variable length sequences + torch.manual_seed(0) + qo_lens = [128 * (i + 1) for i in range(batch_size)] + kv_lens = [128 * (i + 1) for i in range(batch_size)] + + # Build ragged tensors + qo_indptr = torch.tensor( + [0] + [sum(qo_lens[: i + 1]) for i in range(batch_size)], + dtype=torch.int32, + device="cuda", + ) + kv_indptr = torch.tensor( + [0] + [sum(kv_lens[: i + 1]) for i in range(batch_size)], + dtype=torch.int32, + device="cuda", + ) + + total_qo_len = sum(qo_lens) + total_kv_len = sum(kv_lens) + + # Create input tensors (fp16) + q_fp16 = torch.randn( + total_qo_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + ) + k_fp16 = torch.randn( + total_kv_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + v_fp16 = torch.randn( + total_kv_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + + # Get reference output using fp16 + wrapper_fp16 = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda"), + "NHD", + backend="fa3", + ) + wrapper_fp16.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + head_dim, + causal=causal, + ) + o_ref = wrapper_fp16.run(q_fp16, k_fp16, v_fp16) + + # Quantize to FP8 + q_fp8, s_q = per_head_symmetric_quant(q_fp16, quant_dtype=dtype) + k_fp8, s_k = per_head_symmetric_quant(k_fp16, quant_dtype=dtype) + v_fp8, s_v = per_head_symmetric_quant(v_fp16, quant_dtype=dtype) + + # Run FP8 batch prefill with ragged KV cache + wrapper_fp8 = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(128 * 1024 * 1024, 
dtype=torch.uint8, device="cuda"), + "NHD", + backend="fa3", + ) + wrapper_fp8.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + head_dim, + q_data_type=dtype, + kv_data_type=dtype, + o_data_type=o_dtype, + causal=causal, + ) + o_fp8 = wrapper_fp8.run(q_fp8, k_fp8, v_fp8, s_q, s_k, s_v) + + # Compute MSE + mse = torch.mean((o_ref.float() - o_fp8.float()) ** 2) + assert mse < 1.0, f"MSE too high: {mse.item()}" + + +# Test batch prefill with paged KV cache: MSE should be below threshold +@pytest.mark.parametrize("batch_size", [2, 4]) +@pytest.mark.parametrize("num_heads", [8, 32]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +def test_batch_prefill_paged(batch_size, num_heads, head_dim, causal, dtype): + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("SM90A is not supported") + + print( + f"Testing FP8 batch prefill paged with batch_size={batch_size}, num_heads={num_heads}, " + f"head_dim={head_dim}, causal={causal}, dtype={dtype}" + ) + + # Setup + o_dtype = torch.half + num_qo_heads = num_kv_heads = num_heads + page_size = 16 + + # Create variable length sequences + torch.manual_seed(0) + qo_lens = [128 * (i + 1) for i in range(batch_size)] + kv_lens = [128 * (i + 1) for i in range(batch_size)] + + # Build indptr for Q + qo_indptr = torch.tensor( + [0] + [sum(qo_lens[: i + 1]) for i in range(batch_size)], + dtype=torch.int32, + device="cuda", + ) + + total_qo_len = sum(qo_lens) + + # Compute number of pages needed for each sequence + kv_page_counts = [(kv_len + page_size - 1) // page_size for kv_len in kv_lens] + total_pages = sum(kv_page_counts) + + # Build paged KV indptr and indices + kv_indptr = torch.tensor( + [0] + [sum(kv_page_counts[: i + 1]) for i in range(batch_size)], + dtype=torch.int32, + device="cuda", + ) + # Simple page indices: sequential allocation + kv_indices = torch.arange(total_pages, dtype=torch.int32, device="cuda") + kv_last_page_len = torch.tensor( + [ + kv_len % page_size if kv_len % page_size != 0 else page_size + for kv_len in kv_lens + ], + dtype=torch.int32, + device="cuda", + ) + + # Create input tensors (fp16) + q_fp16 = torch.randn( + total_qo_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + ) + # Paged KV cache: (num_pages, page_size, num_heads, head_dim) + paged_k_fp16 = torch.randn( + total_pages, page_size, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + paged_v_fp16 = torch.randn( + total_pages, page_size, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ) + + # Get reference output using fp16 + wrapper_fp16 = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda"), + "NHD", + backend="fa3", + ) + wrapper_fp16.plan( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=causal, + ) + o_ref = wrapper_fp16.run(q_fp16, (paged_k_fp16, paged_v_fp16)) + + # Quantize to FP8 + q_fp8, s_q = per_head_symmetric_quant(q_fp16, quant_dtype=dtype) + # For paged KV, reshape to (total_tokens, num_heads, head_dim) for quantization + k_flat = paged_k_fp16.view(-1, num_kv_heads, head_dim) + v_flat = paged_v_fp16.view(-1, num_kv_heads, head_dim) + k_fp8_flat, s_k = per_head_symmetric_quant(k_flat, quant_dtype=dtype) + v_fp8_flat, s_v = per_head_symmetric_quant(v_flat, quant_dtype=dtype) + paged_k_fp8 = k_fp8_flat.view(total_pages, 
page_size, num_kv_heads, head_dim) + paged_v_fp8 = v_fp8_flat.view(total_pages, page_size, num_kv_heads, head_dim) + + # Run FP8 batch prefill with paged KV cache + wrapper_fp8 = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda"), + "NHD", + backend="fa3", + ) + wrapper_fp8.plan( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + q_data_type=dtype, + kv_data_type=dtype, + o_data_type=o_dtype, + causal=causal, + ) + o_fp8 = wrapper_fp8.run(q_fp8, (paged_k_fp8, paged_v_fp8), s_q, s_k, s_v) + + # Compute MSE + mse = torch.mean((o_ref.float() - o_fp8.float()) ** 2) + assert mse < 1.0, f"MSE too high: {mse.item()}" + + if __name__ == "__main__": + # Test batch prefill paged + for batch_size in [2]: + for num_heads in [8]: + for head_dim in [128, 256]: + for causal in [True, False]: + for dtype in [torch.float8_e4m3fn]: + test_batch_prefill_paged( + batch_size, num_heads, head_dim, causal, dtype + ) + + # Test batch prefill ragged + for batch_size in [2]: + for num_heads in [8]: + for head_dim in [128]: + for causal in [True, False]: + for dtype in [torch.float8_e4m3fn]: + test_batch_prefill_ragged( + batch_size, num_heads, head_dim, causal, dtype + ) + + # Test block sparse attention for R in [4]: for C in [1]: for M in [1024]: diff --git a/tests/utils/test_block_sparse_indices_to_vector_sparse_offsets.py b/tests/utils/test_block_sparse_indices_to_vector_sparse_offsets.py deleted file mode 100644 index cf2ef003cc..0000000000 --- a/tests/utils/test_block_sparse_indices_to_vector_sparse_offsets.py +++ /dev/null @@ -1,84 +0,0 @@ -""" -Copyright (c) 2023 by FlashInfer team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-""" - -import pytest -import torch - -import flashinfer.page - - -@pytest.mark.parametrize("batch_size", [1, 7, 19, 128, 517]) -@pytest.mark.parametrize("kv_len", [97, 199, 2049, 31791]) -@pytest.mark.parametrize("block_size", [1, 3, 7, 16, 64, 79, 128]) -@pytest.mark.parametrize("stride_block", [128]) -@pytest.mark.parametrize("stride_n", [1]) -def test_block_sparse_indices_to_vector_sparse_offsets( - batch_size, kv_len, block_size, stride_block, stride_n -): - if batch_size * kv_len > 1048576: - pytest.skip("skip large test") - num_blocks_per_row = (kv_len + block_size - 1) // block_size - - block_sparse_indices = torch.arange( - batch_size * num_blocks_per_row, device="cuda", dtype=torch.int32 - ) - block_sparse_indptr = torch.arange( - 0, - batch_size * num_blocks_per_row + 1, - num_blocks_per_row, - device="cuda", - dtype=torch.int32, - ) - vector_sparse_offsets_buf = torch.zeros( - batch_size * kv_len, device="cuda", dtype=torch.int32 - ) - vector_sparse_indptr = torch.arange( - 0, batch_size * kv_len + 1, kv_len, device="cuda", dtype=torch.int32 - ) - kv_lens = torch.full((batch_size,), kv_len, device="cuda", dtype=torch.int32) - - vector_sparse_offsets = ( - flashinfer.page.block_sparse_indices_to_vector_sparse_offsets( - block_sparse_indices, - block_sparse_indptr, - vector_sparse_offsets_buf, - vector_sparse_indptr, - kv_lens, - stride_block, - stride_n, - block_size, - ) - ) - - # Check that the output is correct - for i in range(batch_size): - indices_i = block_sparse_indices[ - i * num_blocks_per_row : (i + 1) * num_blocks_per_row - ].cpu() - output_i = vector_sparse_offsets[ - vector_sparse_indptr[i] : vector_sparse_indptr[i + 1] - ].cpu() - - output_ref_i = ( - indices_i[torch.arange(0, kv_len, dtype=torch.int32) // block_size] - * stride_block - + (torch.arange(0, kv_len, dtype=torch.int32) % block_size) * stride_n - ) - torch.testing.assert_close(output_i, output_ref_i) - - -if __name__ == "__main__": - pass