Commit 389b73c

[None][fix] Remove FP8 K/V buffer from TRTLLM sparse MLA attention kernel (#9529)
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
1 parent bf84d9c

1 file changed: +34 -8 lines


cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 34 additions & 8 deletions
@@ -778,8 +778,19 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     if (mEnableContextFMHA && mFP8ContextMLA && mFmhaDispatcher->isSeparateQAndKvInput())
     {
         fp8_q_buf_size = max_num_tokens * static_cast<size_t>(total_q_dim_all_heads);
-        fp8_k_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast<size_t>(total_k_dim_all_heads);
-        fp8_v_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast<size_t>(total_v_dim_all_heads);
+
+        if (useSparseMLA())
+        {
+            // Sparse MLA (absorption mode): K and V are stored directly in KV cache during MLA RoPE kernel.
+            // No separate FP8 buffers needed for K/V since they're read from paged KV cache (Q_PAGED_KV layout).
+            fp8_k_buf_size = 0;
+            fp8_v_buf_size = 0;
+        }
+        else
+        {
+            fp8_k_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast<size_t>(total_k_dim_all_heads);
+            fp8_v_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast<size_t>(total_v_dim_all_heads);
+        }
     }
 
     size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
@@ -1436,8 +1447,19 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     if (mEnableContextFMHA && mFP8ContextMLA && mFmhaDispatcher->isSeparateQAndKvInput())
     {
         fp8_q_buf_size = params.num_tokens * static_cast<size_t>(total_q_dim_all_heads);
-        fp8_k_buf_size = params.total_kv_len * static_cast<size_t>(total_k_dim_all_heads);
-        fp8_v_buf_size = params.total_kv_len * static_cast<size_t>(total_v_dim_all_heads);
+
+        if (useSparseMLA())
+        {
+            // Sparse MLA (absorption mode): K and V are stored directly in KV cache during MLA RoPE kernel.
+            // No separate FP8 buffers needed for K/V since they're read from paged KV cache (Q_PAGED_KV layout).
+            fp8_k_buf_size = 0;
+            fp8_v_buf_size = 0;
+        }
+        else
+        {
+            fp8_k_buf_size = params.total_kv_len * static_cast<size_t>(total_k_dim_all_heads);
+            fp8_v_buf_size = params.total_kv_len * static_cast<size_t>(total_v_dim_all_heads);
+        }
     }
     size_t const padding_offset_size
         = mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length;
@@ -1805,11 +1827,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
             TLLM_CHECK_WITH_INFO(
                 mFmhaDispatcher->isSeparateQAndKvInput(), "Separate QKV input is required for fp8 context MLA");
             TLLM_CHECK_WITH_INFO(fp8_q_buf != nullptr, "FP8 q buffer is required for fp8 context MLA");
-            TLLM_CHECK_WITH_INFO(fp8_k_buf != nullptr, "FP8 k buffer is required for fp8 context MLA");
-            TLLM_CHECK_WITH_INFO(fp8_v_buf != nullptr, "FP8 v buffer is required for fp8 context MLA");
+            // In sparse MLA (absorption mode), K and V are stored in KV cache, not as separate FP8 buffers
+            TLLM_CHECK_WITH_INFO(useSparseMLA() || fp8_k_buf != nullptr,
+                "FP8 k buffer is required for fp8 context MLA in non-sparse mode");
+            TLLM_CHECK_WITH_INFO(useSparseMLA() || fp8_v_buf != nullptr,
+                "FP8 v buffer is required for fp8 context MLA in non-sparse mode");
+
             fmhaParams.qPtr = reinterpret_cast<void const*>(fp8_q_buf);
-            fmhaParams.kPtr = reinterpret_cast<void const*>(fp8_k_buf);
-            fmhaParams.vPtr = reinterpret_cast<void const*>(fp8_v_buf);
+            fmhaParams.kPtr = useSparseMLA() ? nullptr : reinterpret_cast<void const*>(fp8_k_buf);
+            fmhaParams.vPtr = useSparseMLA() ? nullptr : reinterpret_cast<void const*>(fp8_v_buf);
         }
         else
         {
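
For reference, the net effect of the two sizing hunks above can be summarized in one place. The sketch below is a minimal standalone illustration, not TensorRT-LLM's actual API: computeFp8MlaBufferSizes, Fp8MlaWorkspaceSizes, and the parameter names are hypothetical stand-ins for the members used in the diff (useSparseMLA(), max_num_tokens / params.num_tokens, total_*_dim_all_heads).

// Minimal sketch of the sizing rule introduced by this commit (hypothetical helper,
// not part of attentionOp.cpp).
#include <cstddef>

struct Fp8MlaWorkspaceSizes
{
    std::size_t q = 0;
    std::size_t k = 0;
    std::size_t v = 0;
};

// sparse_mla : corresponds to useSparseMLA() in the diff (MLA absorption mode).
// num_tokens : number of query tokens staged in FP8.
// kv_len     : K/V length to stage; only used on the dense (non-sparse) path.
Fp8MlaWorkspaceSizes computeFp8MlaBufferSizes(bool sparse_mla, std::size_t num_tokens, std::size_t kv_len,
    std::size_t q_dim_all_heads, std::size_t k_dim_all_heads, std::size_t v_dim_all_heads)
{
    Fp8MlaWorkspaceSizes sizes;
    // Q is always quantized into a scratch buffer before the context FMHA call.
    sizes.q = num_tokens * q_dim_all_heads;
    if (sparse_mla)
    {
        // Absorption mode: K and V already live in the paged KV cache (Q_PAGED_KV layout),
        // so no FP8 staging buffers are reserved and the kernel later receives null K/V pointers.
        sizes.k = 0;
        sizes.v = 0;
    }
    else
    {
        // Dense path: unchanged behavior, stage K and V in FP8 scratch buffers.
        sizes.k = kv_len * k_dim_all_heads;
        sizes.v = kv_len * v_dim_all_heads;
    }
    return sizes;
}

The same useSparseMLA() condition gates the TLLM_CHECK_WITH_INFO assertions and the fmhaParams.kPtr / fmhaParams.vPtr assignments in enqueueContext, so the sparse path never dereferences the staging buffers it no longer reserves.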
