Skip to content

Commit f2241f4

Browse files
committed
Send kv events from worker side to scheduler side
This is required for when worker side operations like CPU offloading generate KV cache events. This commit enables theses events to be passed to the scheduler side so that they can be published by the engine. Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
1 parent d4acf51 commit f2241f4

File tree

4 files changed

+81
-2
lines changed

4 files changed

+81
-2
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
if TYPE_CHECKING:
5050
from vllm.attention.backends.abstract import AttentionMetadata
5151
from vllm.config import VllmConfig
52-
from vllm.distributed.kv_events import KVCacheEvent
52+
from vllm.distributed.kv_events import KVCacheEvent, KVEventBatch
5353
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
5454
KVConnectorPromMetrics,
5555
KVConnectorStats,
@@ -350,6 +350,12 @@ def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
350350
"""
351351
return None
352352

353+
def get_kv_connector_kv_cache_events(self) -> Optional["KVEventBatch"]:
354+
"""
355+
Get the KV connector kv cache events collected during the last interval.
356+
"""
357+
return None
358+
353359
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
354360
"""
355361
Get the KVConnector handshake metadata for this connector.

vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
from typing import TYPE_CHECKING, Any
3+
import time
4+
from collections.abc import Iterable
5+
from typing import TYPE_CHECKING, Any, Optional
46

57
import torch
68
from lmcache.integration.vllm.vllm_v1_adapter import (
79
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
810
)
911

1012
from vllm.config import VllmConfig
13+
from vllm.distributed.kv_events import BlockStored, KVCacheEvent, KVEventBatch
1114
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
1215
KVConnectorBase_V1,
1316
KVConnectorMetadata,
1417
KVConnectorRole,
1518
)
1619
from vllm.logger import init_logger
1720
from vllm.v1.core.sched.output import SchedulerOutput
21+
from vllm.v1.outputs import KVConnectorOutput
1822

1923
if TYPE_CHECKING:
2024
from vllm.attention.backends.abstract import AttentionMetadata
@@ -54,6 +58,8 @@ def __init__(
5458

5559
self._lmcache_engine = cls(vllm_config, role, self)
5660

61+
self._kv_events: list[KVCacheEvent] = []
62+
5763
# ==============================
5864
# Worker-side methods
5965
# ==============================
@@ -151,6 +157,30 @@ def get_block_ids_with_load_errors(self) -> set[int]:
151157
# Fallback for older versions that don't support this method
152158
return set()
153159

160+
def get_kv_connector_kv_cache_events(self) -> Optional["KVEventBatch"]:
161+
"""
162+
Get the KV connector kv cache events collected during the last interval.
163+
"""
164+
events = self._lmcache_engine.get_kv_events()
165+
if not events:
166+
return None
167+
168+
lmcache_kv_events: KVEventBatch | None = None
169+
for event in events:
170+
if lmcache_kv_events is None:
171+
lmcache_kv_events = KVEventBatch(ts=time.time(), events=[])
172+
block = BlockStored(
173+
block_hashes=event.block_hashes,
174+
parent_block_hash=event.parent_block_hash,
175+
token_ids=event.token_ids,
176+
lora_id=event.lora_id,
177+
block_size=event.block_size,
178+
medium=event.medium,
179+
)
180+
lmcache_kv_events.events.append(block)
181+
182+
return lmcache_kv_events
183+
154184
# ==============================
155185
# Scheduler-side methods
156186
# ==============================
@@ -198,6 +228,25 @@ def build_connector_meta(
198228
"""
199229
return self._lmcache_engine.build_connector_meta(scheduler_output)
200230

231+
def update_connector_output(self, connector_output: KVConnectorOutput):
232+
"""
233+
Update KVConnector state from worker-side connectors output.
234+
235+
Args:
236+
connector_output (KVConnectorOutput): the worker-side
237+
connectors output.
238+
"""
239+
# Get the KV events
240+
kv_events = connector_output.kv_cache_events
241+
if (
242+
not kv_events
243+
or not isinstance(kv_events, KVEventBatch)
244+
or not kv_events.events
245+
):
246+
return
247+
self._kv_events = kv_events.events
248+
return
249+
201250
def request_finished(
202251
self,
203252
request: "Request",
@@ -214,3 +263,14 @@ def request_finished(
214263
returned by the engine.
215264
"""
216265
return self._lmcache_engine.request_finished(request, block_ids)
266+
267+
def take_events(self) -> Iterable["KVCacheEvent"]:
268+
"""
269+
Take the KV cache events from the connector.
270+
271+
Yields:
272+
New KV cache events since the last call.
273+
"""
274+
if self._kv_events is not None:
275+
yield from self._kv_events
276+
self._kv_events.clear()

vllm/v1/outputs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
from vllm.v1.core.sched.output import SchedulerOutput
1212

1313
if TYPE_CHECKING:
14+
from vllm.distributed.kv_events import KVEventBatch
1415
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
1516
else:
1617
KVConnectorStats = object
18+
KVEventBatch = object
1719

1820

1921
class LogprobsLists(NamedTuple):
@@ -119,6 +121,7 @@ class KVConnectorOutput:
119121
finished_sending: set[str] | None = None
120122
finished_recving: set[str] | None = None
121123
kv_connector_stats: KVConnectorStats | None = None
124+
kv_cache_events: KVEventBatch | None = None
122125
# IDs of externally computed KV blocks that failed to load.
123126
# Requests referencing these blocks should be rescheduled to recompute them
124127
invalid_block_ids: set[int] = field(default_factory=set)
@@ -134,6 +137,7 @@ def is_empty(self):
134137
not self.finished_sending
135138
and not self.finished_recving
136139
and not self.kv_connector_stats
140+
and not self.kv_cache_events
137141
and not self.invalid_block_ids
138142
)
139143

vllm/v1/worker/kv_connector_model_runner_mixin.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,19 @@ def _get_kv_connector_output(
135135
output.kv_connector_stats = (
136136
KVConnectorModelRunnerMixin.get_kv_connector_stats()
137137
)
138+
output.kv_cache_events = (
139+
KVConnectorModelRunnerMixin.get_kv_connector_kv_cache_events()
140+
)
138141
kv_connector.clear_connector_metadata()
139142

140143
@staticmethod
141144
def get_kv_connector_stats() -> KVConnectorStats | None:
142145
if has_kv_transfer_group():
143146
return get_kv_transfer_group().get_kv_connector_stats()
144147
return None
148+
149+
@staticmethod
150+
def get_kv_connector_kv_cache_events() -> KVConnectorStats | None:
151+
if has_kv_transfer_group():
152+
return get_kv_transfer_group().get_kv_connector_kv_cache_events()
153+
return None

0 commit comments

Comments
 (0)