File tree Expand file tree Collapse file tree 1 file changed +2
-1
lines changed
cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha Expand file tree Collapse file tree 1 file changed +2
-1
lines changed Original file line number Diff line number Diff line change @@ -685,7 +685,8 @@ struct KernelParams
685685 // The number of elements in 128B for K/V.
686686 int32_t numEltsIn128BKv = (128 * 8 ) / get_size_in_bits (kernelMeta.mDataTypeKv );
687687 // The number of head elts (per token) in each block of shared memory (see above explanation).
688- int32_t numEltsInClampedHeadDimKv = std::min (numEltsIn128BKv, maxHeadDimKv);
688+ // If maxHeadDimKv > 128, the head dimension is split into multiple headDimStages of at most 128 elements each.
689+ int32_t numEltsInClampedHeadDimKv = std::min ({numEltsIn128BKv, maxHeadDimKv, 128 });
689690
690691 // Do we have to transform K/V before MMA?
691692 bool const transformsKv{kernelMeta.mDataTypeKv != kernelMeta.mDataTypeQ };
You can’t perform that action at this time.
0 commit comments