@@ -446,13 +446,15 @@ def __init__(
446446 executor_fail_callback : Optional [Callable ] = None ,
447447 ):
448448 self .vllm_config = vllm_config
449- self .vllm_config .cache_config .gpu_memory_utilization = (
450- self .vllm_config .cache_config .gpu_memory_utilization - 0.1 )
451449
452450 self .output_queue = queue .Queue [Union [tuple [int , EngineCoreOutputs ],
453451 bytes ]]()
454452
455453 self .devices = jax .devices ()
454+ device_kind = self .devices [0 ].device_kind
455+ if device_kind != 'TPU7x' :
456+ self .vllm_config .cache_config .gpu_memory_utilization = (
457+ self .vllm_config .cache_config .gpu_memory_utilization - 0.1 )
456458 prefill_slice_sizes , decode_slice_sizes , slice_sizes = _get_slice_sizes (
457459 self .devices )
458460
@@ -597,7 +599,6 @@ def __init__(
597599 # engine core to be executed, instead we create other instance of
598600 # engine cores and let them do the work.
599601 self .vllm_config = vllm_config
600- self .vllm_config .cache_config .gpu_memory_utilization = self .vllm_config .cache_config .gpu_memory_utilization - 0.1
601602
602603 # We should be taking the input from the client, the code below is forked from
603604 # vllm.v1.engine.core.EngineCoreProc.
@@ -610,6 +611,10 @@ def __init__(
610611 self .engines_running = False
611612
612613 self .devices = jax .devices ()
614+ device_kind = self .devices [0 ].device_kind
615+ if device_kind != 'TPU7x' :
616+ self .vllm_config .cache_config .gpu_memory_utilization = (
617+ self .vllm_config .cache_config .gpu_memory_utilization - 0.1 )
613618 prefill_slice_sizes , decode_slice_sizes , slice_sizes = _get_slice_sizes (
614619 self .devices )
615620
0 commit comments