Commit 7440f2b (parent a972cf7)

fix a bug with headDim 256 nvfp4-kv kernels

Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>

File tree: 1 file changed (+2, -1)


cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h

Lines changed: 2 additions & 1 deletion

@@ -685,7 +685,8 @@ struct KernelParams
     // The number of elements in 128B for Q.
     int32_t numEltsIn128BKv = (128 * 8) / get_size_in_bits(kernelMeta.mDataTypeKv);
     // The number of head elts (per token) in each block of shared memory (see above explanation).
-    int32_t numEltsInClampedHeadDimKv = std::min(numEltsIn128BKv, maxHeadDimKv);
+    // HeadDim will be split into multiple headDimStages (128) if maxHeadDimKv > 128.
+    int32_t numEltsInClampedHeadDimKv = std::min({numEltsIn128BKv, maxHeadDimKv, 128});

     // Do we have to transform K/V before MMA?
     bool const transformsKv{kernelMeta.mDataTypeKv != kernelMeta.mDataTypeQ};
