diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 7b5cb94cf13e..c0a3f434d37f 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -5,7 +5,7 @@ import threading import time from abc import ABC, abstractmethod -from collections import deque +from collections import Counter, deque from collections.abc import Callable from dataclasses import asdict from itertools import count @@ -54,11 +54,26 @@ class BlockStored(KVCacheEvent): lora_id: int | None medium: str | None + def __hash__(self) -> int: + return hash( + ( + tuple(self.block_hashes), + self.parent_block_hash, + tuple(self.token_ids), + self.block_size, + self.lora_id, + self.medium, + ) + ) + class BlockRemoved(KVCacheEvent): block_hashes: list[ExternalBlockHash] medium: str | None + def __hash__(self) -> int: + return hash((tuple(self.block_hashes), self.medium)) + class AllBlocksCleared(KVCacheEvent): pass @@ -68,6 +83,103 @@ class KVEventBatch(EventBatch): events: list[BlockStored | BlockRemoved | AllBlocksCleared] +class KVEventAggregator: + """ + Aggregates KV events across multiple workers. + Tracks how many times each event appears and returns only those + that were emitted by all workers. + """ + + __slots__ = ("_event_counter", "_num_workers") + + def __init__(self, num_workers: int) -> None: + if num_workers <= 0: + raise ValueError("num_workers must be greater than zero.") + self._event_counter: Counter[KVCacheEvent] = Counter() + self._num_workers: int = num_workers + + def add_events(self, events: list[KVCacheEvent]) -> None: + """ + Add events from a worker batch. + + :param events: List of KVCacheEvent objects. + """ + if not isinstance(events, list): + raise TypeError("events must be a list of KVCacheEvent.") + self._event_counter.update(events) + + def get_common_events(self) -> list[KVCacheEvent]: + """ + Return events that appeared in all workers. + + :return: List of events present in all workers. + """ + return [ + event + for event, count in self._event_counter.items() + if count == self._num_workers + ] + + def get_all_events(self) -> list[KVCacheEvent]: + """ + Return all events for all workers. + + :return: List of events for all workers. + """ + return list(self._event_counter.elements()) + + def clear_events(self) -> None: + """ + Clear all tracked events. + """ + self._event_counter.clear() + + def increment_workers(self, count: int = 1) -> None: + """ + Increment the number of workers contributing events. + + :param count: Number of workers to add. + """ + if count <= 0: + raise ValueError("count must be positive.") + self._num_workers += count + + def reset_workers(self) -> None: + """ + Reset the number of workers to 1. + """ + self._num_workers = 1 + + def __repr__(self) -> str: + return ( + f"" + ) + + +class KVConnectorKVEvents(ABC): + """ + Abstract base class for KV events. + Acts as a container for KV events from the connector. + """ + + @abstractmethod + def add_events(self, events: list[KVCacheEvent]) -> None: + raise NotImplementedError + + @abstractmethod + def aggregate(self) -> "KVConnectorKVEvents": + raise NotImplementedError + + @abstractmethod + def increment_workers(self, count: int = 1) -> None: + raise NotImplementedError + + @abstractmethod + def get_all_events(self) -> list[KVCacheEvent]: + raise NotImplementedError + + class EventPublisher(ABC): """Lightweight publisher for EventBatch batches with data parallelism support. diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index b8eb5ea3b493..2b93f658fdea 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -160,6 +160,7 @@ def update_finished_set( finished_sending = set[str]() finished_recving = set[str]() aggregated_kv_connector_stats = None + combined_kv_cache_events = None invalid_block_ids = set[int]() for model_runner_output in outputs: assert model_runner_output is not None @@ -201,16 +202,36 @@ def update_finished_set( aggregated_kv_connector_stats.aggregate(kv_connector_stats) ) + # Combine kv_cache_events from all workers. + if combined_kv_cache_events is None: + # Use the first worker's kv_cache events as start event list. + combined_kv_cache_events = kv_output.kv_cache_events + elif kv_cache_events := kv_output.kv_cache_events: + assert isinstance( + combined_kv_cache_events, + type(kv_cache_events), + ) + worker_kv_cache_events = kv_cache_events.get_all_events() + combined_kv_cache_events.add_events(worker_kv_cache_events) + combined_kv_cache_events.increment_workers() + invalid_block_ids |= kv_output.invalid_block_ids # select output of the worker specified by output_rank output = outputs[output_rank] + # Aggregate the events across workers. + # This operation needs to be done post worker processing so that we have all + # events for all workers. + if combined_kv_cache_events is not None: + combined_kv_cache_events = combined_kv_cache_events.aggregate() + assert output is not None output.kv_connector_output = KVConnectorOutput( finished_sending=finished_sending or None, finished_recving=finished_recving or None, kv_connector_stats=aggregated_kv_connector_stats or None, + kv_cache_events=combined_kv_cache_events or None, invalid_block_ids=invalid_block_ids, expected_finished_count=self._expected_finished_count, ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 74f09278b7bb..64c8680d87bb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -49,7 +49,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.config import VllmConfig - from vllm.distributed.kv_events import KVCacheEvent + from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorPromMetrics, KVConnectorStats, @@ -379,6 +379,14 @@ def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]: """ return None + def get_kv_connector_kv_cache_events(self) -> Optional["KVConnectorKVEvents"]: + """ + Get the KV connector kv cache events collected during the last interval. + This function should be called by the model runner every time after the + model execution and before cleanup. + """ + return None + def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None: """ Get the KVConnector handshake metadata for this connector. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 0c24a53fb754..99d74b9a36c5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable from typing import TYPE_CHECKING, Any import torch @@ -8,6 +9,12 @@ ) from vllm.config import VllmConfig +from vllm.distributed.kv_events import ( + BlockStored, + KVCacheEvent, + KVConnectorKVEvents, + KVEventAggregator, +) from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, @@ -15,6 +22,7 @@ ) from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.outputs import KVConnectorOutput if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -26,6 +34,37 @@ logger = init_logger(__name__) +class LMCacheKVEvents(KVConnectorKVEvents): + """ + Concrete implementation of KVConnectorKVEvents using KVEventAggregator. + """ + + def __init__(self, num_workers: int) -> None: + self._aggregator = KVEventAggregator(num_workers) + + def add_events(self, events: list[KVCacheEvent]) -> None: + self._aggregator.add_events(events) + + def aggregate(self) -> "LMCacheKVEvents": + """ + Aggregate KV events and retain only common events. + """ + common_events = self._aggregator.get_common_events() + self._aggregator.clear_events() + self._aggregator.add_events(common_events) + self._aggregator.reset_workers() + return self + + def increment_workers(self, count: int = 1) -> None: + self._aggregator.increment_workers(count) + + def get_all_events(self) -> list[KVCacheEvent]: + return self._aggregator.get_all_events() + + def __repr__(self) -> str: + return f"" + + class LMCacheConnectorV1(KVConnectorBase_V1): def __init__( self, @@ -54,6 +93,8 @@ def __init__( self._lmcache_engine = cls(vllm_config, role, self) + self._kv_cache_events: list[KVCacheEvent] = [] + # ============================== # Worker-side methods # ============================== @@ -151,6 +192,31 @@ def get_block_ids_with_load_errors(self) -> set[int]: # Fallback for older versions that don't support this method return set() + def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None: + """ + Get the KV connector kv cache events collected during the last interval. + """ + + events = self._lmcache_engine.get_kv_events() # type: ignore [attr-defined] + if not events: + return None + + blocks: list[BlockStored] = [ + BlockStored( + block_hashes=e.block_hashes, + parent_block_hash=e.parent_block_hash, + token_ids=e.token_ids, + lora_id=e.lora_id, + block_size=e.block_size, + medium=e.medium, + ) + for e in events + ] + + lmcache_kv_events = LMCacheKVEvents(num_workers=1) + lmcache_kv_events.add_events(blocks) + return lmcache_kv_events + # ============================== # Scheduler-side methods # ============================== @@ -198,6 +264,21 @@ def build_connector_meta( """ return self._lmcache_engine.build_connector_meta(scheduler_output) + def update_connector_output(self, connector_output: KVConnectorOutput): + """ + Update KVConnector state from worker-side connectors output. + + Args: + connector_output (KVConnectorOutput): the worker-side + connectors output. + """ + # Get the KV events + kv_cache_events = connector_output.kv_cache_events + if not kv_cache_events or not isinstance(kv_cache_events, LMCacheKVEvents): + return + self._kv_cache_events.extend(kv_cache_events.get_all_events()) + return + def request_finished( self, request: "Request", @@ -214,3 +295,14 @@ def request_finished( returned by the engine. """ return self._lmcache_engine.request_finished(request, block_ids) + + def take_events(self) -> Iterable["KVCacheEvent"]: + """ + Take the KV cache events from the connector. + + Yields: + New KV cache events since the last call. + """ + if self._kv_cache_events is not None: + yield from self._kv_cache_events + self._kv_cache_events.clear() diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index e32d5bb608b1..24a1e2341e4d 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -11,9 +11,11 @@ from vllm.v1.core.sched.output import SchedulerOutput if TYPE_CHECKING: + from vllm.distributed.kv_events import KVConnectorKVEvents from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats else: KVConnectorStats = object + KVConnectorKVEvents = object class LogprobsLists(NamedTuple): @@ -119,6 +121,7 @@ class KVConnectorOutput: finished_sending: set[str] | None = None finished_recving: set[str] | None = None kv_connector_stats: KVConnectorStats | None = None + kv_cache_events: KVConnectorKVEvents | None = None # IDs of externally computed KV blocks that failed to load. # Requests referencing these blocks should be rescheduled to recompute them invalid_block_ids: set[int] = field(default_factory=set) @@ -134,6 +137,7 @@ def is_empty(self): not self.finished_sending and not self.finished_recving and not self.kv_connector_stats + and not self.kv_cache_events and not self.invalid_block_ids ) diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index ff047d8d03f0..df4679e5d9c2 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -22,7 +22,6 @@ has_kv_transfer_group, ) from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig @@ -138,17 +137,17 @@ def _get_kv_connector_output( ) output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors() - output.kv_connector_stats = ( - KVConnectorModelRunnerMixin.get_kv_connector_stats() - ) + if has_kv_transfer_group(): + output.kv_connector_stats = ( + get_kv_transfer_group().get_kv_connector_stats() + ) + output.kv_cache_events = ( + get_kv_transfer_group().get_kv_connector_kv_cache_events() + ) + else: + output.kv_connector_stats = output.kv_cache_events = None kv_connector.clear_connector_metadata() - @staticmethod - def get_kv_connector_stats() -> KVConnectorStats | None: - if has_kv_transfer_group(): - return get_kv_transfer_group().get_kv_connector_stats() - return None - @staticmethod def use_uniform_kv_cache( attn_groups: list[list[AttentionGroup]],