11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3- from typing import TYPE_CHECKING , Any
3+ from collections .abc import Iterable
4+ from dataclasses import dataclass
5+ from typing import TYPE_CHECKING , Any , Optional
46
57import torch
68from lmcache .integration .vllm .vllm_v1_adapter import (
79 LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl ,
810)
911
1012from vllm .config import VllmConfig
13+ from vllm .distributed .kv_events import BlockStored , KVCacheEvent
1114from vllm .distributed .kv_transfer .kv_connector .v1 .base import (
1215 KVConnectorBase_V1 ,
1316 KVConnectorMetadata ,
1417 KVConnectorRole ,
1518)
19+ from vllm .distributed .kv_transfer .kv_connector .v1 .metrics import KVConnectorStats
1620from vllm .logger import init_logger
1721from vllm .v1 .core .sched .output import SchedulerOutput
22+ from vllm .v1 .outputs import KVConnectorOutput
1823
1924if TYPE_CHECKING :
2025 from vllm .attention .backends .abstract import AttentionMetadata
2631logger = init_logger (__name__ )
2732
2833
34+ @dataclass
35+ class LMCacheKVEvents (KVConnectorStats ):
36+ """
37+ Maintain a list of KV events
38+ """
39+
40+ def aggregate (self , other : "KVConnectorStats" ) -> "LMCacheKVEvents" :
41+ if not other and not isinstance (other , LMCacheKVEvents ):
42+ raise TypeError ("Can only aggregate with another LMCacheKVEvents" )
43+
44+ if other .is_empty ():
45+ return self
46+
47+ if self .is_empty ():
48+ self .data ["kv_events" ] = []
49+
50+ other_events = other .get_kv_events ()
51+ for other_event in other_events :
52+ self .data ["kv_events" ].append (other_event )
53+
54+ return self
55+
56+ def reset (self ):
57+ self .data .clear ()
58+
59+ def reduce (self ) -> dict [str , int | float ]:
60+ return {
61+ "kv_events" : 0 ,
62+ }
63+
64+ def add_kv_event (self , event : BlockStored ):
65+ if self .is_empty ():
66+ self .data ["kv_events" ] = []
67+ self .data ["kv_events" ].append (event )
68+
69+ def get_kv_events (self ) -> list [BlockStored ] | None :
70+ if self .is_empty ():
71+ return None
72+ return self .data ["kv_events" ]
73+
74+ def is_empty (self ) -> bool :
75+ return not self .data or self .data .get ("kv_events" , 0 ) == 0
76+
77+
2978class LMCacheConnectorV1 (KVConnectorBase_V1 ):
3079 def __init__ (
3180 self ,
@@ -54,6 +103,8 @@ def __init__(
54103
55104 self ._lmcache_engine = cls (vllm_config , role , self )
56105
106+ self ._kv_events : list [KVCacheEvent ] = []
107+
57108 # ==============================
58109 # Worker-side methods
59110 # ==============================
@@ -136,6 +187,32 @@ def get_finished(
136187 """
137188 return self ._lmcache_engine .get_finished (finished_req_ids )
138189
190+ def get_kv_connector_stats (self ) -> Optional ["KVConnectorStats" ]:
191+ """
192+ Get the KV connector stats collected during the last interval.
193+ """
194+ assert self ._lmcache_engine is not None
195+
196+ events = self ._lmcache_engine .get_kv_events ()
197+ if not events :
198+ return None
199+
200+ lmcache_kv_events : LMCacheKVEvents | None = None
201+ for event in events :
202+ if lmcache_kv_events is None :
203+ lmcache_kv_events = LMCacheKVEvents ()
204+ block = BlockStored (
205+ block_hashes = event .block_hashes ,
206+ parent_block_hash = event .parent_block_hash ,
207+ token_ids = event .token_ids ,
208+ lora_id = event .lora_id ,
209+ block_size = event .block_size ,
210+ medium = event .medium ,
211+ )
212+ lmcache_kv_events .add_kv_event (block )
213+
214+ return lmcache_kv_events
215+
139216 # ==============================
140217 # Scheduler-side methods
141218 # ==============================
@@ -183,6 +260,25 @@ def build_connector_meta(
183260 """
184261 return self ._lmcache_engine .build_connector_meta (scheduler_output )
185262
263+ def update_connector_output (self , connector_output : KVConnectorOutput ):
264+ """
265+ Update KVConnector state from worker-side connectors output.
266+
267+ Args:
268+ connector_output (KVConnectorOutput): the worker-side
269+ connectors output.
270+ """
271+ # Get the KV events
272+ kv_events = connector_output .kv_connector_stats
273+ if (
274+ not kv_events
275+ or not isinstance (kv_events , LMCacheKVEvents )
276+ or kv_events .is_empty ()
277+ ):
278+ return
279+ self ._kv_events = kv_events .get_kv_events ()
280+ return
281+
186282 def request_finished (
187283 self ,
188284 request : "Request" ,
@@ -199,3 +295,20 @@ def request_finished(
199295 returned by the engine.
200296 """
201297 return self ._lmcache_engine .request_finished (request , block_ids )
298+
299+ def take_events (self ) -> Iterable ["KVCacheEvent" ]:
300+ """
301+ Take the KV cache events from the connector.
302+
303+ Yields:
304+ New KV cache events since the last call.
305+ """
306+ if self ._kv_events is not None :
307+ yield from self ._kv_events
308+ self ._kv_events .clear ()
309+
310+ @classmethod
311+ def build_kv_connector_stats (
312+ cls , data : dict [str , Any ] | None = None
313+ ) -> KVConnectorStats | None :
314+ return LMCacheKVEvents (data = data ) if data is not None else LMCacheKVEvents ()
0 commit comments