@@ -778,8 +778,19 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
778778 if (mEnableContextFMHA && mFP8ContextMLA && mFmhaDispatcher ->isSeparateQAndKvInput ())
779779 {
780780 fp8_q_buf_size = max_num_tokens * static_cast <size_t >(total_q_dim_all_heads);
781- fp8_k_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast <size_t >(total_k_dim_all_heads);
782- fp8_v_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast <size_t >(total_v_dim_all_heads);
781+
782+ if (useSparseMLA ())
783+ {
784+ // Sparse MLA (absorption mode): K and V are stored directly in KV cache during MLA RoPE kernel.
785+ // No separate FP8 buffers needed for K/V since they're read from paged KV cache (Q_PAGED_KV layout).
786+ fp8_k_buf_size = 0 ;
787+ fp8_v_buf_size = 0 ;
788+ }
789+ else
790+ {
791+ fp8_k_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast <size_t >(total_k_dim_all_heads);
792+ fp8_v_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast <size_t >(total_v_dim_all_heads);
793+ }
783794 }
784795
785796 size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof (int ) * max_num_tokens;
@@ -1436,8 +1447,19 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
14361447 if (mEnableContextFMHA && mFP8ContextMLA && mFmhaDispatcher ->isSeparateQAndKvInput ())
14371448 {
14381449 fp8_q_buf_size = params.num_tokens * static_cast <size_t >(total_q_dim_all_heads);
1439- fp8_k_buf_size = params.total_kv_len * static_cast <size_t >(total_k_dim_all_heads);
1440- fp8_v_buf_size = params.total_kv_len * static_cast <size_t >(total_v_dim_all_heads);
1450+
1451+ if (useSparseMLA ())
1452+ {
1453+ // Sparse MLA (absorption mode): K and V are stored directly in KV cache during MLA RoPE kernel.
1454+ // No separate FP8 buffers needed for K/V since they're read from paged KV cache (Q_PAGED_KV layout).
1455+ fp8_k_buf_size = 0 ;
1456+ fp8_v_buf_size = 0 ;
1457+ }
1458+ else
1459+ {
1460+ fp8_k_buf_size = params.total_kv_len * static_cast <size_t >(total_k_dim_all_heads);
1461+ fp8_v_buf_size = params.total_kv_len * static_cast <size_t >(total_v_dim_all_heads);
1462+ }
14411463 }
14421464 size_t const padding_offset_size
14431465 = mEnableContextFMHA ? 0 : sizeof (int ) * params.batch_size * params.input_seq_length ;
@@ -1805,11 +1827,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
18051827 TLLM_CHECK_WITH_INFO (
18061828 mFmhaDispatcher ->isSeparateQAndKvInput (), " Separate QKV input is required for fp8 context MLA" );
18071829 TLLM_CHECK_WITH_INFO (fp8_q_buf != nullptr , " FP8 q buffer is required for fp8 context MLA" );
1808- TLLM_CHECK_WITH_INFO (fp8_k_buf != nullptr , " FP8 k buffer is required for fp8 context MLA" );
1809- TLLM_CHECK_WITH_INFO (fp8_v_buf != nullptr , " FP8 v buffer is required for fp8 context MLA" );
1830+ // In sparse MLA (absorption mode), K and V are stored in KV cache, not as separate FP8 buffers
1831+ TLLM_CHECK_WITH_INFO (useSparseMLA () || fp8_k_buf != nullptr ,
1832+ " FP8 k buffer is required for fp8 context MLA in non-sparse mode" );
1833+ TLLM_CHECK_WITH_INFO (useSparseMLA () || fp8_v_buf != nullptr ,
1834+ " FP8 v buffer is required for fp8 context MLA in non-sparse mode" );
1835+
18101836 fmhaParams.qPtr = reinterpret_cast <void const *>(fp8_q_buf);
1811- fmhaParams.kPtr = reinterpret_cast <void const *>(fp8_k_buf);
1812- fmhaParams.vPtr = reinterpret_cast <void const *>(fp8_v_buf);
1837+ fmhaParams.kPtr = useSparseMLA () ? nullptr : reinterpret_cast <void const *>(fp8_k_buf);
1838+ fmhaParams.vPtr = useSparseMLA () ? nullptr : reinterpret_cast <void const *>(fp8_v_buf);
18131839 }
18141840 else
18151841 {
0 commit comments