@@ -1113,35 +1113,12 @@ def get_kv_cache_config_from_groups(
             KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by)
         )
 
-    kv_cache_config = KVCacheConfig(
+    return KVCacheConfig(
         num_blocks=num_blocks,
         kv_cache_tensors=kv_cache_tensors,
         kv_cache_groups=kv_cache_groups,
     )
 
-    min_block_size = min([group.kv_cache_spec.block_size for group in kv_cache_groups])
-
-    # Print the KV cache size and maximum concurrency.
-    num_tokens = num_blocks // len(kv_cache_groups) * min_block_size
-    if vllm_config.parallel_config.decode_context_parallel_size > 1:
-        num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
-        logger.info(
-            "Multiplying the GPU KV cache size by the dcp_world_size %d.",
-            vllm_config.parallel_config.decode_context_parallel_size,
-        )
-    num_tokens_str = f"{num_tokens:,}"
-    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
-    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
-    max_concurrency = get_max_concurrency_for_kv_cache_config(
-        vllm_config, kv_cache_config
-    )
-    logger.info(
-        "Maximum concurrency for %s tokens per request: %.2fx",
-        max_model_len_str,
-        max_concurrency,
-    )
-    return kv_cache_config
-
 
 def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
     """
@@ -1265,6 +1242,45 @@ def generate_scheduler_kv_cache_config(
     return cfg
 
 
+def _report_kv_cache_config(
+    vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
+) -> None:
+    """
+    Log resolved KV cache configuration.
+
+    Args:
+        vllm_config: The global VllmConfig
+        kv_cache_config: The resolved KV cache configuration
+    """
+    min_block_size = min(
+        [group.kv_cache_spec.block_size for group in kv_cache_config.kv_cache_groups]
+    )
+
+    # Log the KV cache size and maximum concurrency.
+    num_tokens = (
+        kv_cache_config.num_blocks
+        // len(kv_cache_config.kv_cache_groups)
+        * min_block_size
+    )
+    if vllm_config.parallel_config.decode_context_parallel_size > 1:
+        num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
+        logger.info(
+            "Multiplying the GPU KV cache size by the dcp_world_size %d.",
+            vllm_config.parallel_config.decode_context_parallel_size,
+        )
+    num_tokens_str = f"{num_tokens:,}"
+    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
+    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
+    max_concurrency = get_max_concurrency_for_kv_cache_config(
+        vllm_config, kv_cache_config
+    )
+    logger.info(
+        "Maximum concurrency for %s tokens per request: %.2fx",
+        max_model_len_str,
+        max_concurrency,
+    )
+
+
 def get_kv_cache_configs(
     vllm_config: VllmConfig,
     kv_cache_specs: list[dict[str, KVCacheSpec]],
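
For orientation, the size logged by the new _report_kv_cache_config helper is num_blocks // len(kv_cache_groups) * min_block_size tokens, multiplied by the decode-context-parallel world size when that is greater than 1. Below is a minimal standalone sketch of that arithmetic; the block count and per-group block sizes are purely illustrative and not taken from any real configuration.

# Illustrative figures only: 8,192 blocks shared across two KV cache groups
# whose block sizes are 16 and 32 tokens.
num_blocks = 8192
group_block_sizes = [16, 32]
dcp_world_size = 1  # decode context parallelism disabled in this sketch

min_block_size = min(group_block_sizes)
num_tokens = num_blocks // len(group_block_sizes) * min_block_size
if dcp_world_size > 1:
    num_tokens *= dcp_world_size

print(f"GPU KV cache size: {num_tokens:,} tokens")  # GPU KV cache size: 65,536 tokens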
@@ -1284,7 +1300,8 @@ def get_kv_cache_configs(
     3. Generate the KV cache configs for each worker based on the KV cache
        grouping strategy. (This is reasonable because the layer ratio of
        different PP stages are similar.)
-    4. Change the num_blocks of each worker to the smallest among all workers.
+    4. Change the num_blocks of each worker to the smallest among all workers
+       and shrink tensor sizes proportionally to avoid allocating unused memory.
 
     Args:
         vllm_config: The global VllmConfig
@@ -1345,13 +1362,22 @@ def get_kv_cache_configs(
             )
         )
 
-    # Change the num_blocks of each rank to the smallest among all ranks. We
-    # do not need to shrink the tensor size because it is valid to only use the
-    # first `num_blocks` blocks of the tensor.
+    # Change the num_blocks of each rank to the smallest among all ranks.
+    # We also need to shrink the tensor size proportionally to avoid
+    # allocating unused memory.
     min_num_blocks = min(
         kv_cache_config.num_blocks for kv_cache_config in kv_cache_configs
     )
     for kv_cache_config in kv_cache_configs:
+        num_blocks_old = kv_cache_config.num_blocks
         kv_cache_config.num_blocks = min_num_blocks
 
+        # Shrink tensor size proportionally
+        for tensor in kv_cache_config.kv_cache_tensors:
+            assert tensor.size % num_blocks_old == 0
+            tensor.size = tensor.size // num_blocks_old * min_num_blocks
+
+        if len(kv_cache_config.kv_cache_groups) > 0:
+            _report_kv_cache_config(vllm_config, kv_cache_config)
+
     return kv_cache_configs
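
To make the reconciliation in step 4 concrete, here is a self-contained sketch using simplified stand-in types (not vLLM's actual KVCacheConfig and KVCacheTensor classes): every rank is clamped to the smallest num_blocks, and each tensor's byte size is shrunk by the same ratio so no rank allocates blocks it can never use.

from dataclasses import dataclass


@dataclass
class StubTensor:  # simplified stand-in for KVCacheTensor
    size: int  # bytes


@dataclass
class StubConfig:  # simplified stand-in for KVCacheConfig
    num_blocks: int
    tensors: list[StubTensor]


def clamp_to_min_blocks(configs: list[StubConfig]) -> None:
    """Clamp every rank to the smallest num_blocks and shrink its tensors
    proportionally, mirroring the logic added in this commit."""
    min_num_blocks = min(c.num_blocks for c in configs)
    for c in configs:
        num_blocks_old = c.num_blocks
        c.num_blocks = min_num_blocks
        for t in c.tensors:
            assert t.size % num_blocks_old == 0  # size is a whole number of blocks
            t.size = t.size // num_blocks_old * min_num_blocks


# Two ranks with different budgets: 4,096 vs. 3,000 blocks of 16 KiB each.
configs = [
    StubConfig(num_blocks=4096, tensors=[StubTensor(size=4096 * 16384)]),
    StubConfig(num_blocks=3000, tensors=[StubTensor(size=3000 * 16384)]),
]
clamp_to_min_blocks(configs)
assert all(c.num_blocks == 3000 for c in configs)
assert configs[0].tensors[0].size == 3000 * 16384  # shrunk from 4,096 blocks' worth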