Skip to content

Commit 6487bda

Browse files
committed
Combine events from different worker connectors
It is part of the aggregation of kv_connector_output from all workers. For KV cache events, this means combining events from all workers, removing any duplicates. Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
1 parent d0854fa commit 6487bda

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

vllm/distributed/kv_events.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,30 @@ class BlockStored(KVCacheEvent):
5454
lora_id: int | None
5555
medium: str | None
5656

57+
def __eq__(self, other):
    """Return whether ``other`` is a ``BlockStored`` with identical fields.

    The comparison covers ``block_hashes``, ``parent_block_hash``,
    ``token_ids``, ``block_size``, ``lora_id`` and ``medium``.

    For any other type ``NotImplemented`` is returned rather than ``False``
    so that Python can try the reflected comparison on ``other``; when that
    is unsupported too, ``==`` falls back to an identity check.
    """
    if not isinstance(other, BlockStored):
        return NotImplemented
    return (
        self.block_hashes == other.block_hashes
        and self.parent_block_hash == other.parent_block_hash
        and self.token_ids == other.token_ids
        and self.block_size == other.block_size
        and self.lora_id == other.lora_id
        and self.medium == other.medium
    )
68+
5769

5870
class BlockRemoved(KVCacheEvent):
5971
block_hashes: list[ExternalBlockHash]
6072
medium: str | None
6173

74+
def __eq__(self, other):
    """Return whether ``other`` is a ``BlockRemoved`` with identical fields.

    The comparison covers ``block_hashes`` and ``medium``.

    For any other type ``NotImplemented`` is returned rather than ``False``
    so that Python can try the reflected comparison on ``other``; when that
    is unsupported too, ``==`` falls back to an identity check.
    """
    if not isinstance(other, BlockRemoved):
        return NotImplemented
    return self.block_hashes == other.block_hashes and self.medium == other.medium
80+
6281

6382
# Field-less event type: the body is only `pass`, so it adds nothing beyond
# what KVCacheEvent provides. Presumably emitted when the whole cache is
# cleared -- confirm against the producers of this event.
class AllBlocksCleared(KVCacheEvent):
6483
pass
@@ -67,6 +86,17 @@ class AllBlocksCleared(KVCacheEvent):
6786
class KVEventBatch(EventBatch):
6887
events: list[BlockStored | BlockRemoved | AllBlocksCleared]
6988

89+
def combine_unique_ordered_events(self, other: "KVEventBatch") -> "KVEventBatch":
    """
    Append the events of `other` that this batch does not already hold.

    Events already in this batch keep their order. Each event of `other`
    is then appended, in order, unless an equal event is already present.
    `self.events` is rebound to the merged list (the previous list object
    is not modified) and `self` is returned; `other` is left untouched.
    """
    merged = list(self.events)
    for event in other.events:
        # Equality-based membership test: events are compared with `==`.
        if event in merged:
            continue
        merged.append(event)
    self.events = merged
    return self
99+
70100

71101
class EventPublisher(ABC):
72102
"""Lightweight publisher for EventBatch batches with data parallelism

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def update_finished_set(
160160
finished_sending = set[str]()
161161
finished_recving = set[str]()
162162
aggregated_kv_connector_stats = None
163+
combined_kv_cache_events = None
163164
invalid_block_ids = set[int]()
164165
for model_runner_output in outputs:
165166
assert model_runner_output is not None
@@ -201,6 +202,21 @@ def update_finished_set(
201202
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
202203
)
203204

205+
# Combine kv_cache_events from all workers.
206+
if combined_kv_cache_events is None:
207+
# Use the first worker's kv_cache events as start event list.
208+
combined_kv_cache_events = kv_output.kv_cache_events
209+
elif kv_cache_events := kv_output.kv_cache_events:
210+
assert isinstance(
211+
combined_kv_cache_events,
212+
type(kv_cache_events),
213+
)
214+
combined_kv_cache_events = (
215+
combined_kv_cache_events.combine_unique_ordered_events(
216+
kv_cache_events
217+
)
218+
)
219+
204220
invalid_block_ids |= kv_output.invalid_block_ids
205221

206222
# select output of the worker specified by output_rank
@@ -211,6 +227,7 @@ def update_finished_set(
211227
finished_sending=finished_sending or None,
212228
finished_recving=finished_recving or None,
213229
kv_connector_stats=aggregated_kv_connector_stats or None,
230+
kv_connector_kv_cache_events=combined_kv_cache_events or None,
214231
invalid_block_ids=invalid_block_ids,
215232
expected_finished_count=self._expected_finished_count,
216233
)

0 commit comments

Comments
 (0)