@@ -667,12 +667,6 @@ def calculate_scaling_factor_size_bytes(
667667 @staticmethod
668668 def get_cache_size_per_token(model_config: ModelConfigPython,
669669                              mapping: Mapping, **kwargs):
670- # get kv cache dtype bytes
671- mem_per_token = 2
672- quant_config = model_config.quant_config
673- if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
674- ):
675-     mem_per_token = 1
676670
677671 # get num key value heads
678672 config = model_config.pretrained_config
@@ -698,10 +692,20 @@ def get_cache_size_per_token(model_config: ModelConfigPython,
698692 # provide at least 1 layer to prevent division by zero cache size
699693 num_attention_layers = max(
700694     len(mapping.pp_layers(model_config.get_num_attention_layers())), 1)
701- mem_per_token *= num_attention_layers * head_dim
702-
703695 # K and V
704- mem_per_token *= kv_factor
696+ mem_per_token = kv_factor * num_attention_layers * head_dim
697+ # The data type bytes.
698+ quant_config = model_config.quant_config
699+ if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
700+ ):
701+     mem_per_token *= 1
702+ elif quant_config is not None and quant_config.quant_mode.has_fp4_kv_cache(
703+ ):
704+     # 1 byte for 2 elements, and SFs (fp8) per 16 elements.
705+     mem_per_token = math.ceil(mem_per_token / 2) + math.ceil(
706+         mem_per_token / 16)
707+ else:
708+     mem_per_token *= 2
705709 return mem_per_token
706710
707711 def get_cache_bytes_per_token (self ):
0 commit comments