Skip to content

Commit fdf5784

Browse files
committed
Handle duplicate kv events from workers
Workers will generate duplicate KV events in LMCache. This commit adds the capability to aggregate events from workers by returning only those that were emitted by all workers. It also provides an abstract class, KVConnectorKVEvents, that connectors implement to control how they emit events. Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
1 parent 6973123 commit fdf5784

File tree

5 files changed

+216
-19
lines changed

5 files changed

+216
-19
lines changed

vllm/distributed/kv_events.py

Lines changed: 110 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import threading
66
import time
77
from abc import ABC, abstractmethod
8-
from collections import deque
8+
from collections import Counter, deque
99
from collections.abc import Callable
1010
from dataclasses import asdict
1111
from itertools import count
@@ -54,11 +54,26 @@ class BlockStored(KVCacheEvent):
5454
lora_id: int | None
5555
medium: str | None
5656

57+
def __hash__(self) -> int:
58+
return hash(
59+
(
60+
tuple(self.block_hashes),
61+
self.parent_block_hash,
62+
tuple(self.token_ids),
63+
self.block_size,
64+
self.lora_id,
65+
self.medium,
66+
)
67+
)
68+
5769

5870
class BlockRemoved(KVCacheEvent):
5971
block_hashes: list[ExternalBlockHash]
6072
medium: str | None
6173

74+
def __hash__(self) -> int:
75+
return hash((tuple(self.block_hashes), self.medium))
76+
6277

6378
class AllBlocksCleared(KVCacheEvent):
6479
pass
@@ -67,16 +82,102 @@ class AllBlocksCleared(KVCacheEvent):
6782
class KVEventBatch(EventBatch):
6883
events: list[BlockStored | BlockRemoved | AllBlocksCleared]
6984

70-
def combine_unique_ordered_events(self, other: "KVEventBatch") -> "KVEventBatch":
85+
86+
class KVEventAggregator:
87+
"""
88+
Aggregates KV events across multiple workers.
89+
Tracks how many times each event appears and returns only those
90+
that were emitted by all workers.
91+
"""
92+
93+
__slots__ = ("_event_counter", "_num_workers")
94+
95+
def __init__(self, num_workers: int) -> None:
96+
if num_workers <= 0:
97+
raise ValueError("num_workers must be greater than zero.")
98+
self._event_counter: Counter[KVCacheEvent] = Counter()
99+
self._num_workers: int = num_workers
100+
101+
def add_events(self, events: list[KVCacheEvent]) -> None:
102+
"""
103+
Add events from a worker batch.
104+
105+
:param events: List of KVCacheEvent objects.
106+
"""
107+
if not isinstance(events, list):
108+
raise TypeError("events must be a list of KVCacheEvent.")
109+
self._event_counter.update(events)
110+
111+
def get_common_events(self) -> list[KVCacheEvent]:
112+
"""
113+
Return events that appeared in all workers.
114+
115+
:return: List of events present in all workers.
116+
"""
117+
return [
118+
event
119+
for event, count in self._event_counter.items()
120+
if count == self._num_workers
121+
]
122+
123+
def get_all_events(self) -> list[KVCacheEvent]:
124+
"""
125+
Return all events for all workers.
126+
127+
:return: List of events for all workers.
128+
"""
129+
return list(self._event_counter.elements())
130+
131+
def clear_events(self) -> None:
71132
"""
72-
Combine non duplicated events with another `KVEventBatch` object.
133+
Clear all tracked events.
73134
"""
74-
checked_events = set(self.events)
75-
for item in other.events:
76-
if item not in checked_events:
77-
self.events.append(item)
78-
checked_events.add(item)
79-
return self
135+
self._event_counter.clear()
136+
137+
def increment_workers(self, count: int = 1) -> None:
138+
"""
139+
Increment the number of workers contributing events.
140+
141+
:param count: Number of workers to add.
142+
"""
143+
if count <= 0:
144+
raise ValueError("count must be positive.")
145+
self._num_workers += count
146+
147+
def reset_workers(self) -> None:
148+
"""
149+
Reset the number of workers to 1.
150+
"""
151+
self._num_workers = 1
152+
153+
def __repr__(self) -> str:
154+
return (
155+
f"<KVEventAggregator workers={self._num_workers}, "
156+
f"events={len(self._event_counter)}>"
157+
)
158+
159+
160+
class KVConnectorKVEvents(ABC):
161+
"""
162+
Abstract base class for KV events.
163+
Acts as a container for KV events from the connector.
164+
"""
165+
166+
@abstractmethod
167+
def add_events(self, events: list[KVCacheEvent]) -> None:
168+
raise NotImplementedError
169+
170+
@abstractmethod
171+
def aggregate(self) -> "KVConnectorKVEvents":
172+
raise NotImplementedError
173+
174+
@abstractmethod
175+
def increment_workers(self, count: int = 1) -> None:
176+
raise NotImplementedError
177+
178+
@abstractmethod
179+
def get_all_events(self) -> list[KVCacheEvent]:
180+
raise NotImplementedError
80181

81182

82183
class EventPublisher(ABC):

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,17 +211,21 @@ def update_finished_set(
211211
combined_kv_cache_events,
212212
type(kv_cache_events),
213213
)
214-
combined_kv_cache_events = (
215-
combined_kv_cache_events.combine_unique_ordered_events(
216-
kv_cache_events
217-
)
218-
)
214+
worker_kv_cache_events = kv_cache_events.get_all_events()
215+
combined_kv_cache_events.add_events(worker_kv_cache_events)
216+
combined_kv_cache_events.increment_workers()
219217

220218
invalid_block_ids |= kv_output.invalid_block_ids
221219

222220
# select output of the worker specified by output_rank
223221
output = outputs[output_rank]
224222

223+
# Aggregate the events across workers.
224+
# This operation needs to be done post worker processing so that we have all
225+
# events for all workers.
226+
if combined_kv_cache_events is not None:
227+
combined_kv_cache_events = combined_kv_cache_events.aggregate()
228+
225229
assert output is not None
226230
output.kv_connector_output = KVConnectorOutput(
227231
finished_sending=finished_sending or None,

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
if TYPE_CHECKING:
5050
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
5151
from vllm.config import VllmConfig
52-
from vllm.distributed.kv_events import KVCacheEvent, KVEventBatch
52+
from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents
5353
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
5454
KVConnectorPromMetrics,
5555
KVConnectorStats,
@@ -379,7 +379,7 @@ def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
379379
"""
380380
return None
381381

382-
def get_kv_connector_kv_cache_events(self) -> Optional["KVEventBatch"]:
382+
def get_kv_connector_kv_cache_events(self) -> Optional["KVConnectorKVEvents"]:
383383
"""
384384
Get the KV connector kv cache events collected during the last interval.
385385
This function should be called by the model runner every time after the

vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from collections.abc import Iterable
34
from typing import TYPE_CHECKING, Any
45

56
import torch
@@ -8,13 +9,20 @@
89
)
910

1011
from vllm.config import VllmConfig
12+
from vllm.distributed.kv_events import (
13+
BlockStored,
14+
KVCacheEvent,
15+
KVConnectorKVEvents,
16+
KVEventAggregator,
17+
)
1118
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
1219
KVConnectorBase_V1,
1320
KVConnectorMetadata,
1421
KVConnectorRole,
1522
)
1623
from vllm.logger import init_logger
1724
from vllm.v1.core.sched.output import SchedulerOutput
25+
from vllm.v1.outputs import KVConnectorOutput
1826

1927
if TYPE_CHECKING:
2028
from vllm.attention.backends.abstract import AttentionMetadata
@@ -26,6 +34,37 @@
2634
logger = init_logger(__name__)
2735

2836

37+
class LMCacheKVEvents(KVConnectorKVEvents):
38+
"""
39+
Concrete implementation of KVConnectorKVEvents using KVEventAggregator.
40+
"""
41+
42+
def __init__(self, num_workers: int) -> None:
43+
self._aggregator = KVEventAggregator(num_workers)
44+
45+
def add_events(self, events: list[KVCacheEvent]) -> None:
46+
self._aggregator.add_events(events)
47+
48+
def aggregate(self) -> "LMCacheKVEvents":
49+
"""
50+
Aggregate KV events and retain only common events.
51+
"""
52+
common_events = self._aggregator.get_common_events()
53+
self._aggregator.clear_events()
54+
self._aggregator.add_events(common_events)
55+
self._aggregator.reset_workers()
56+
return self
57+
58+
def increment_workers(self, count: int = 1) -> None:
59+
self._aggregator.increment_workers(count)
60+
61+
def get_all_events(self) -> list[KVCacheEvent]:
62+
return self._aggregator.get_all_events()
63+
64+
def __repr__(self) -> str:
65+
return f"<LMCacheKVEvents events={self.get_all_events()}>"
66+
67+
2968
class LMCacheConnectorV1(KVConnectorBase_V1):
3069
def __init__(
3170
self,
@@ -54,6 +93,8 @@ def __init__(
5493

5594
self._lmcache_engine = cls(vllm_config, role, self)
5695

96+
self._kv_cache_events: list[KVCacheEvent] = []
97+
5798
# ==============================
5899
# Worker-side methods
59100
# ==============================
@@ -151,6 +192,31 @@ def get_block_ids_with_load_errors(self) -> set[int]:
151192
# Fallback for older versions that don't support this method
152193
return set()
153194

195+
def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None:
196+
"""
197+
Get the KV connector kv cache events collected during the last interval.
198+
"""
199+
200+
events = self._lmcache_engine.get_kv_events() # type: ignore [attr-defined]
201+
if not events:
202+
return None
203+
204+
blocks: list[BlockStored] = [
205+
BlockStored(
206+
block_hashes=e.block_hashes,
207+
parent_block_hash=e.parent_block_hash,
208+
token_ids=e.token_ids,
209+
lora_id=e.lora_id,
210+
block_size=e.block_size,
211+
medium=e.medium,
212+
)
213+
for e in events
214+
]
215+
216+
lmcache_kv_events = LMCacheKVEvents(num_workers=1)
217+
lmcache_kv_events.add_events(blocks)
218+
return lmcache_kv_events
219+
154220
# ==============================
155221
# Scheduler-side methods
156222
# ==============================
@@ -198,6 +264,21 @@ def build_connector_meta(
198264
"""
199265
return self._lmcache_engine.build_connector_meta(scheduler_output)
200266

267+
def update_connector_output(self, connector_output: KVConnectorOutput):
268+
"""
269+
Update KVConnector state from worker-side connectors output.
270+
271+
Args:
272+
connector_output (KVConnectorOutput): the worker-side
273+
connectors output.
274+
"""
275+
# Get the KV events
276+
kv_cache_events = connector_output.kv_cache_events
277+
if not kv_cache_events or not isinstance(kv_cache_events, LMCacheKVEvents):
278+
return
279+
self._kv_cache_events.extend(kv_cache_events.get_all_events())
280+
return
281+
201282
def request_finished(
202283
self,
203284
request: "Request",
@@ -214,3 +295,14 @@ def request_finished(
214295
returned by the engine.
215296
"""
216297
return self._lmcache_engine.request_finished(request, block_ids)
298+
299+
def take_events(self) -> Iterable["KVCacheEvent"]:
300+
"""
301+
Take the KV cache events from the connector.
302+
303+
Yields:
304+
New KV cache events since the last call.
305+
"""
306+
if self._kv_cache_events is not None:
307+
yield from self._kv_cache_events
308+
self._kv_cache_events.clear()

vllm/v1/outputs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
from vllm.v1.core.sched.output import SchedulerOutput
1212

1313
if TYPE_CHECKING:
14-
from vllm.distributed.kv_events import KVEventBatch
14+
from vllm.distributed.kv_events import KVConnectorKVEvents
1515
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
1616
else:
1717
KVConnectorStats = object
18-
KVEventBatch = object
18+
KVConnectorKVEvents = object
1919

2020

2121
class LogprobsLists(NamedTuple):
@@ -121,7 +121,7 @@ class KVConnectorOutput:
121121
finished_sending: set[str] | None = None
122122
finished_recving: set[str] | None = None
123123
kv_connector_stats: KVConnectorStats | None = None
124-
kv_cache_events: KVEventBatch | None = None
124+
kv_cache_events: KVConnectorKVEvents | None = None
125125
# IDs of externally computed KV blocks that failed to load.
126126
# Requests referencing these blocks should be rescheduled to recompute them
127127
invalid_block_ids: set[int] = field(default_factory=set)

0 commit comments

Comments
 (0)