Commit 3198532

[Disagg] Fixes for vllm model impl disagg support (#1066)
1 parent a14b215 commit 3198532

File tree

2 files changed: +20 -3 lines changed

tpu_inference/core/core_tpu.py

Lines changed: 8 additions & 3 deletions

@@ -446,13 +446,15 @@ def __init__(
         executor_fail_callback: Optional[Callable] = None,
     ):
         self.vllm_config = vllm_config
-        self.vllm_config.cache_config.gpu_memory_utilization = (
-            self.vllm_config.cache_config.gpu_memory_utilization - 0.1)

         self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
                                               bytes]]()

         self.devices = jax.devices()
+        device_kind = self.devices[0].device_kind
+        if device_kind != 'TPU7x':
+            self.vllm_config.cache_config.gpu_memory_utilization = (
+                self.vllm_config.cache_config.gpu_memory_utilization - 0.1)
         prefill_slice_sizes, decode_slice_sizes, slice_sizes = _get_slice_sizes(
             self.devices)

@@ -597,7 +599,6 @@ def __init__(
         # engine core to be executed, instead we create other instance of
         # engine cores and let them do the work.
         self.vllm_config = vllm_config
-        self.vllm_config.cache_config.gpu_memory_utilization = self.vllm_config.cache_config.gpu_memory_utilization - 0.1

         # We should be taking the input from the client, the code below is forked from
         # vllm.v1.engine.core.EngineCoreProc.

@@ -610,6 +611,10 @@ def __init__(
         self.engines_running = False

         self.devices = jax.devices()
+        device_kind = self.devices[0].device_kind
+        if device_kind != 'TPU7x':
+            self.vllm_config.cache_config.gpu_memory_utilization = (
+                self.vllm_config.cache_config.gpu_memory_utilization - 0.1)
         prefill_slice_sizes, decode_slice_sizes, slice_sizes = _get_slice_sizes(
             self.devices)
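Note: the core_tpu.py change gates the 0.1 gpu_memory_utilization reduction on the device kind reported by JAX, so the extra headroom is only reserved on devices other than TPU7x. Below is a minimal standalone sketch of that pattern, assuming JAX is installed; the helper name and the SimpleNamespace stand-in for vLLM's CacheConfig are illustrative, not part of the commit.

    # Sketch: reserve 0.1 of memory-utilization headroom unless running on TPU7x.
    from types import SimpleNamespace

    import jax


    def adjust_memory_utilization(cache_config, devices=None):
        # Mirrors the commit's check: inspect the first device's kind and only
        # lower gpu_memory_utilization on non-TPU7x hardware.
        devices = devices or jax.devices()
        if devices[0].device_kind != 'TPU7x':
            cache_config.gpu_memory_utilization = (
                cache_config.gpu_memory_utilization - 0.1)
        return cache_config


    if __name__ == '__main__':
        cfg = SimpleNamespace(gpu_memory_utilization=0.9)
        print(adjust_memory_utilization(cfg).gpu_memory_utilization)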

tpu_inference/models/vllm/vllm_model_wrapper.py

Lines changed: 12 additions & 0 deletions

@@ -81,9 +81,21 @@ def __init__(self, vllm_config: VllmConfig, rng: PRNGKey, mesh: Mesh):

     def load_weights(self):
         # Set up to load the model into CPU first.
+        # Cache device slice config since device config cannot be deepcopied
+        modified_slice_config = False
+        if hasattr(
+                self.vllm_config.device_config,
+                'slice') and self.vllm_config.device_config.slice is not None:
+            slice_config = self.vllm_config.device_config.slice
+            modified_slice_config = True
+            self.vllm_config.device_config.slice = None
         vllm_config_for_load = copy.deepcopy(self.vllm_config)
+        if modified_slice_config:
+            self.vllm_config.device_config.slice = slice_config
         assert self.vllm_config.model_config.dtype in TORCH_DTYPE_TO_JAX, "The model_config.dtype must be a PyTorch dtype."
         vllm_config_for_load.device_config.device = "cpu"
+        # Clearing the cached compilation config, otherwise vllm model init will fail
+        vllm_config_for_load.compilation_config.static_forward_context.clear()

         # When expert parallelism is enabled, vLLM loads weight in sharding
         # aware manner. Since tpu-inference has its own sharding logic, this
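Note: in vllm_model_wrapper.py the fix stashes device_config.slice before copy.deepcopy (the slice object cannot be deep-copied), restores it afterwards, and clears the cached static_forward_context before vLLM model init. Below is a minimal sketch of the stash/copy/restore part only, with SimpleNamespace standing in for vLLM's DeviceConfig and a threading.Lock playing the role of the non-copyable slice object; both stand-ins are assumptions for illustration.

    # Sketch: temporarily remove an attribute that breaks copy.deepcopy,
    # take the copy, then restore the attribute on the original object.
    import copy
    import threading
    from types import SimpleNamespace

    device_config = SimpleNamespace(device='tpu', slice=threading.Lock())

    # Stash the non-copyable attribute and null it out on the original.
    saved_slice, device_config.slice = device_config.slice, None
    config_for_load = copy.deepcopy(device_config)
    # Restore the original config once the copy exists; the copy keeps slice=None.
    device_config.slice = saved_slice

    config_for_load.device = 'cpu'
    print(config_for_load.device, device_config.device)  # cpu tpu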
