Skip to content

Commit 15e7aa8

Browse files
committed
Retrieve KV events from LMCache
Updates to the LMCache connector to enable events generated by LMCache to be retrieved and published by the scheduler. Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
1 parent ca90f50 commit 15e7aa8

File tree

1 file changed

+114
-1
lines changed

1 file changed

+114
-1
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
from typing import TYPE_CHECKING, Any
3+
from collections.abc import Iterable
4+
from dataclasses import dataclass
5+
from typing import TYPE_CHECKING, Any, Optional
46

57
import torch
68
from lmcache.integration.vllm.vllm_v1_adapter import (
79
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
810
)
911

1012
from vllm.config import VllmConfig
13+
from vllm.distributed.kv_events import BlockStored, KVCacheEvent
1114
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
1215
KVConnectorBase_V1,
1316
KVConnectorMetadata,
1417
KVConnectorRole,
1518
)
19+
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
1620
from vllm.logger import init_logger
1721
from vllm.v1.core.sched.output import SchedulerOutput
22+
from vllm.v1.outputs import KVConnectorOutput
1823

1924
if TYPE_CHECKING:
2025
from vllm.attention.backends.abstract import AttentionMetadata
@@ -26,6 +31,50 @@
2631
logger = init_logger(__name__)
2732

2833

34+
@dataclass
class LMCacheKVEvents(KVConnectorStats):
    """
    Maintain a list of KV events.

    The events are stored in ``self.data["kv_events"]``. The key is created
    on first use, so an instance whose ``data`` has no such key, or whose
    list is empty, is considered empty.
    """

    def aggregate(self, other: "KVConnectorStats") -> "LMCacheKVEvents":
        """
        Append the events held by ``other`` to this instance.

        Args:
            other: another ``LMCacheKVEvents`` whose events are merged in.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            TypeError: if ``other`` is not an ``LMCacheKVEvents``.
        """
        # A plain isinstance check also rejects None; combining it with a
        # truthiness test would let truthy objects of the wrong type through.
        if not isinstance(other, LMCacheKVEvents):
            raise TypeError("Can only aggregate with another LMCacheKVEvents")

        other_events = other.get_kv_events()
        if not other_events:
            return self

        if self.is_empty():
            self.data["kv_events"] = []
        # extend() snapshots its argument's length first, so aggregating an
        # instance with itself terminates.
        self.data["kv_events"].extend(other_events)

        return self

    def reset(self):
        """Drop everything stored in ``self.data``."""
        self.data.clear()

    def reduce(self) -> dict[str, int | float]:
        """
        Return the numeric summary of this object.

        NOTE(review): this always reports 0 regardless of how many events
        are stored -- presumably a placeholder; confirm with the consumer.
        """
        return {
            "kv_events": 0,
        }

    def add_kv_event(self, event: BlockStored):
        """Append a single event, creating the backing list if needed."""
        if self.is_empty():
            self.data["kv_events"] = []
        self.data["kv_events"].append(event)

    def get_kv_events(self) -> list[BlockStored] | None:
        """Return the stored event list, or ``None`` when there are none."""
        if self.is_empty():
            return None
        return self.data["kv_events"]

    def is_empty(self) -> bool:
        """Return True when no events are stored (missing key or empty list)."""
        # Truthiness covers a missing key, None, 0 and an empty list alike;
        # comparing the list against 0 would treat [] as non-empty.
        return not self.data.get("kv_events")
76+
77+
2978
class LMCacheConnectorV1(KVConnectorBase_V1):
3079
def __init__(
3180
self,
@@ -54,6 +103,8 @@ def __init__(
54103

55104
self._lmcache_engine = cls(vllm_config, role, self)
56105

106+
self._kv_events: list[KVCacheEvent] = []
107+
57108
# ==============================
58109
# Worker-side methods
59110
# ==============================
@@ -136,6 +187,32 @@ def get_finished(
136187
"""
137188
return self._lmcache_engine.get_finished(finished_req_ids)
138189

190+
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
    """
    Wrap the KV events reported by the LMCache engine.

    Each event returned by ``self._lmcache_engine.get_kv_events()`` is
    copied field by field into a ``BlockStored`` and collected in an
    ``LMCacheKVEvents`` object.

    Returns:
        The populated ``LMCacheKVEvents``, or ``None`` when the engine
        reports no events.
    """
    assert self._lmcache_engine is not None

    engine_events = self._lmcache_engine.get_kv_events()
    if not engine_events:
        return None

    stored_blocks = [
        BlockStored(
            block_hashes=engine_event.block_hashes,
            parent_block_hash=engine_event.parent_block_hash,
            token_ids=engine_event.token_ids,
            lora_id=engine_event.lora_id,
            block_size=engine_event.block_size,
            medium=engine_event.medium,
        )
        for engine_event in engine_events
    ]
    # A truthy iterable may still yield nothing; report "no stats" then.
    if not stored_blocks:
        return None

    collected = LMCacheKVEvents()
    for stored_block in stored_blocks:
        collected.add_kv_event(stored_block)
    return collected
215+
139216
# ==============================
140217
# Scheduler-side methods
141218
# ==============================
@@ -183,6 +260,25 @@ def build_connector_meta(
183260
"""
184261
return self._lmcache_engine.build_connector_meta(scheduler_output)
185262

263+
def update_connector_output(self, connector_output: KVConnectorOutput):
    """
    Update KVConnector state from worker-side connectors output.

    Events carried in ``connector_output.kv_connector_stats`` are appended
    to the connector's own pending list, from which ``take_events`` later
    drains them. Output that carries no ``LMCacheKVEvents``, or an empty
    one, is ignored.

    Args:
        connector_output (KVConnectorOutput): the worker-side
            connectors output.
    """
    # Get the KV events
    kv_events = connector_output.kv_connector_stats
    if not isinstance(kv_events, LMCacheKVEvents) or kv_events.is_empty():
        return

    new_events = kv_events.get_kv_events()
    if not new_events:
        return

    # Copy into our own list instead of aliasing the stats object's list:
    # take_events() empties the pending list, which must not mutate the
    # worker output. Extending also keeps events that were not yet taken.
    if self._kv_events is None:
        self._kv_events = []
    self._kv_events.extend(new_events)
281+
186282
def request_finished(
187283
self,
188284
request: "Request",
@@ -199,3 +295,20 @@ def request_finished(
199295
returned by the engine.
200296
"""
201297
return self._lmcache_engine.request_finished(request, block_ids)
298+
299+
def take_events(self) -> Iterable["KVCacheEvent"]:
    """
    Take the KV cache events from the connector.

    The pending list is detached eagerly, at call time, so the events are
    handed out exactly once even if the caller does not iterate the
    result to the end.

    Returns:
        New KV cache events since the last call (possibly empty).
    """
    pending = self._kv_events
    # Swap in a fresh list rather than clearing in place, so the returned
    # events stay intact for the caller.
    self._kv_events = []
    return pending if pending is not None else []
309+
310+
@classmethod
def build_kv_connector_stats(
    cls, data: dict[str, Any] | None = None
) -> KVConnectorStats | None:
    """
    Build an ``LMCacheKVEvents`` object.

    Args:
        data: optional dictionary used as the object's ``data``; when it
            is ``None`` the object is built with its default.

    Returns:
        A new ``LMCacheKVEvents`` instance.
    """
    if data is None:
        return LMCacheKVEvents()
    return LMCacheKVEvents(data=data)

0 commit comments

Comments
 (0)