Skip to content

Commit 9e6bfa4

Browse files
committed
fix get_cache_size_per_token for nvfp4 kv
Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>
1 parent 3113146 commit 9e6bfa4

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -667,12 +667,6 @@ def calculate_scaling_factor_size_bytes(
667667
@staticmethod
668668
def get_cache_size_per_token(model_config: ModelConfigPython,
669669
mapping: Mapping, **kwargs):
670-
# get kv cache dtype bytes
671-
mem_per_token = 2
672-
quant_config = model_config.quant_config
673-
if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
674-
):
675-
mem_per_token = 1
676670

677671
# get num key value heads
678672
config = model_config.pretrained_config
@@ -698,10 +692,20 @@ def get_cache_size_per_token(model_config: ModelConfigPython,
698692
# provide at least 1 layer to prevent division by zero cache size
699693
num_attention_layers = max(
700694
len(mapping.pp_layers(model_config.get_num_attention_layers())), 1)
701-
mem_per_token *= num_attention_layers * head_dim
702-
703695
# K and V
704-
mem_per_token *= kv_factor
696+
mem_per_token = kv_factor * num_attention_layers * head_dim
697+
# Account for the data-type width (bytes per element) of the KV cache.
698+
quant_config = model_config.quant_config
699+
if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
700+
):
701+
mem_per_token *= 1
702+
elif quant_config is not None and quant_config.quant_mode.has_fp4_kv_cache(
703+
):
704+
# NVFP4: 1 byte packs 2 elements, plus one FP8 scaling factor (SF) per 16 elements.
705+
mem_per_token = math.ceil(mem_per_token / 2) + math.ceil(
706+
mem_per_token / 16)
707+
else:
708+
mem_per_token *= 2
705709
return mem_per_token
706710

707711
def get_cache_bytes_per_token(self):

0 commit comments

Comments
 (0)