Commit 7376e6f

Fix the wrong KVAggregator finished count causing a dead loop; adopt vLLM changes (#1009)
1 parent f454815 commit 7376e6f

File tree

2 files changed (+26 -6):

tpu_inference/distributed/tpu_connector.py
tpu_inference/executors/ray_distributed_executor.py

tpu_inference/distributed/tpu_connector.py

Lines changed: 14 additions & 2 deletions
```diff
@@ -190,6 +190,10 @@ def request_finished(
         assert self.connector_scheduler is not None
         return self.connector_scheduler.request_finished(request, block_ids)
 
+    def get_finished_count(self) -> int:
+        assert self.connector_scheduler is not None
+        return self.connector_scheduler.get_finished_count()
+
     ############################################################
     # Worker Side Methods
     ############################################################
@@ -280,7 +284,7 @@ def get_num_new_matched_tokens(
         because TPU pulls KV cache in a blocking way.
 
         """
-        if self.is_producer:
+        if self.is_producer or not request.kv_transfer_params:
             return 0, False
 
         assert num_computed_tokens % self.block_size == 0
@@ -345,7 +349,9 @@ def update_state_after_alloc(self, request: "Request",
                 remote_host=params["remote_host"],
                 remote_port=params["remote_port"],
             )
-        logger.info(f"Scheduler --> reqs_to_load={self.reqs_to_load}")
+        logger.info(
+            f"TPUConnector Scheduler update_state_after_alloc --> reqs_to_load={self.reqs_to_load}"
+        )
 
     def build_connector_meta(self) -> TPUConnectorMetadata:
         """
@@ -365,6 +371,12 @@ def build_connector_meta(self) -> TPUConnectorMetadata:
 
         return meta
 
+    def get_finished_count(self) -> int:
+        """
+        Return how many workers need pull the kv cache and report back.
+        """
+        return len(self.kv_ip) if isinstance(self.kv_ip, list) else 1
+
     def request_finished(
         self,
         request: "Request",
```
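For context on why the new `get_finished_count()` hook matters: the second diff in this commit removes the executor's own `KVOutputAggregator(world_size)` and instead lets the connector report how many workers are expected to pull the KV cache and report back. The sketch below only illustrates that counting pattern; it is a simplified stand-in, not the vLLM `KVOutputAggregator` implementation, and the `FinishedCountTracker` name and its methods are hypothetical.

```python
from collections import defaultdict


class FinishedCountTracker:
    """Hypothetical sketch of per-request finished-count aggregation.

    A request is only released once `expected` workers have reported that
    their KV-cache transfer finished. If `expected` is larger than the
    number of workers that actually report (e.g. world_size instead of the
    connector's count), the threshold is never reached and the request
    hangs -- the dead loop this commit's title refers to.
    """

    def __init__(self, expected: int):
        # In this commit, the expected value would come from
        # TPUConnector.get_finished_count(): len(kv_ip) when kv_ip is a
        # list of worker endpoints, otherwise 1.
        self.expected = expected
        self._reports: dict[str, int] = defaultdict(int)

    def report_finished(self, req_id: str) -> bool:
        """Record one worker's report; return True once all have reported."""
        self._reports[req_id] += 1
        if self._reports[req_id] >= self.expected:
            del self._reports[req_id]
            return True
        return False


# Example: two KV-pulling workers must both report before req-1 is done.
tracker = FinishedCountTracker(expected=2)
assert tracker.report_finished("req-1") is False  # first worker reports
assert tracker.report_finished("req-1") is True   # second worker reports
```

The commit title and the removal of `KVOutputAggregator(self.parallel_config.world_size)` suggest the old world-size-based count could exceed the number of workers that actually report back, which is what stalled the aggregation.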

tpu_inference/executors/ray_distributed_executor.py

Lines changed: 12 additions & 4 deletions
```diff
@@ -6,13 +6,13 @@
 import vllm.envs as envs
 from ray.util.placement_group import PlacementGroup
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
-from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.platforms import current_platform
 from vllm.ray.ray_env import get_env_vars_to_copy
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
 from vllm.utils.network_utils import (get_distributed_init_method, get_ip,
                                       get_open_port)
+from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.executor.ray_distributed_executor import \
     RayDistributedExecutor as RayDistributedExecutorV1
 from vllm.v1.executor.ray_executor import RayWorkerMetaData
@@ -101,10 +101,10 @@ def _init_executor(self) -> None:
 
         self.pp_locks: Optional[List[asyncio.Lock]] = None
 
+        self.scheduler_output: SchedulerOutput | None = None
+
         # KV connector setup
         self.has_connector = self.vllm_config.kv_transfer_config is not None
-        self.kv_output_aggregator = KVOutputAggregator(
-            self.parallel_config.world_size)
         if self.has_connector:
             ip_port = self.collective_rpc("get_node_kv_ip_port")
             for item in ip_port:
@@ -229,7 +229,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
         for each, ip in zip(worker_metadata, worker_ips):
             each.ip = ip
 
-        logger.debug("workers: %s", worker_metadata)
+        logger.debug(f"Initialized worker_metadata: {worker_metadata}")
 
         ip_counts: Dict[str, int] = {}
         for ip in worker_ips:
@@ -256,6 +256,9 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
         start_rank = 0
         for i, item in enumerate(sorted_worker_metadata):
            item.adjusted_rank = i + start_rank
+        logger.info(
+            f"Initialized sorted worker_metadata: {sorted_worker_metadata}")
+
         self.workers = [item.worker for item in sorted_worker_metadata]
         rerank_mapping = {
             item.created_rank: item.adjusted_rank
@@ -353,3 +356,8 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
                 assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                 assert pp_rank < len(self.pp_tp_workers)
                 self.pp_tp_workers[pp_rank].append(self.workers[rank])
+
+    # Ray executor do not need handshake metadata
+    # as we pass the kv_parameters through proxy server
+    def get_kv_connector_handshake_metadata(self) -> None:
+        pass
```
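To tie the two diffs together, here is a small, hedged illustration of the counting rule added in `TPUConnector.get_finished_count()`: when `kv_ip` holds a list of worker endpoints (for example, endpoints assembled from the `collective_rpc("get_node_kv_ip_port")` results shown above), each of those workers is expected to report back; otherwise only one report is awaited. The standalone helper name below is made up for illustration.

```python
# Standalone restatement of the rule in get_finished_count() above.
# `kv_ip` here is a stand-in for the connector attribute of the same name.

def expected_finished_count(kv_ip) -> int:
    """Mirror of `len(self.kv_ip) if isinstance(self.kv_ip, list) else 1`."""
    return len(kv_ip) if isinstance(kv_ip, list) else 1


# Multi-host case: every listed endpoint pulls KV and must report back.
assert expected_finished_count(["10.0.0.1", "10.0.0.2", "10.0.0.3"]) == 3

# Single-endpoint case (a plain string): only one report is expected.
assert expected_finished_count("10.0.0.1") == 1
```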
