diff --git a/tests/v1/executor/test_executor.py b/tests/v1/executor/test_executor.py
index 91bfba6826e0..e9f635378e57 100644
--- a/tests/v1/executor/test_executor.py
+++ b/tests/v1/executor/test_executor.py
@@ -9,6 +9,7 @@
 
 import pytest
 
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.sampling_params import SamplingParams
 from vllm.v1.engine.async_llm import AsyncLLM
@@ -28,12 +29,19 @@ def collective_rpc(
         kwargs: dict | None = None,
         non_block: bool = False,
         unique_reply_rank: int | None = None,
+        kv_output_aggregator: KVOutputAggregator | None = None,
     ) -> Any | list[Any] | Future[Any | list[Any]]:
         # Drop marker to show that this was run
         with open(".marker", "w"):
             ...
         return super().collective_rpc(
-            method, timeout, args, kwargs, non_block, unique_reply_rank
+            method,
+            timeout,
+            args,
+            kwargs,
+            non_block,
+            unique_reply_rank,
+            kv_output_aggregator,
         )
 
 
diff --git a/tests/v1/kv_connector/unit/test_output_aggregator.py b/tests/v1/kv_connector/unit/test_output_aggregator.py
index d186f677c02f..b083ccef9819 100644
--- a/tests/v1/kv_connector/unit/test_output_aggregator.py
+++ b/tests/v1/kv_connector/unit/test_output_aggregator.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from concurrent.futures import Future
 
 import pytest
 
@@ -86,74 +85,6 @@ def test_aggregate_workers_output():
     assert aggregated.invalid_block_ids == {3, 4, 5}
 
 
-def test_async_aggregate_workers_output():
-    aggregator = KVOutputAggregator(expected_finished_count=2)
-
-    future: Future[list[DummyModelRunnerOutput]] = Future()
-    result_future = aggregator.async_aggregate(future)
-
-    output1 = DummyModelRunnerOutput()
-    output2 = DummyModelRunnerOutput()
-    future.set_result([output1, output2])
-
-    assert result_future.done()
-    aggregated = result_future.result()
-    assert aggregated is output1
-    aggregated = aggregated.kv_connector_output
-    assert aggregated.finished_sending is None
-    assert aggregated.finished_recving is None
-    assert not aggregated.invalid_block_ids
-
-    future = Future()
-    result_future = aggregator.async_aggregate(future)
-
-    output1 = DummyModelRunnerOutput(
-        finished_sending={"req1"}, finished_recving={"req2"}
-    )
-    output2 = DummyModelRunnerOutput(invalid_block_ids={1})
-    future.set_result([output1, output2])
-
-    assert result_future.done()
-    aggregated = result_future.result()
-    assert aggregated is output1
-    aggregated = aggregated.kv_connector_output
-    assert aggregated.finished_sending is None
-    assert aggregated.finished_recving is None
-    assert aggregated.invalid_block_ids == {1}
-
-    future = Future()
-    result_future = aggregator.async_aggregate(future)
-
-    output1 = DummyModelRunnerOutput(invalid_block_ids={2})
-    output2 = DummyModelRunnerOutput(finished_sending={"req1"})
-    future.set_result([output1, output2])
-
-    assert result_future.done()
-    aggregated = result_future.result()
-    assert aggregated is output1
-    aggregated = aggregated.kv_connector_output
-    assert aggregated.finished_sending == {"req1"}
-    assert aggregated.finished_recving is None
-    assert aggregated.invalid_block_ids == {2}
-
-    future = Future()
-    result_future = aggregator.async_aggregate(future)
-
-    output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
-    output2 = DummyModelRunnerOutput(
-        finished_recving={"req2"}, invalid_block_ids={4, 5}
-    )
-    future.set_result([output1, output2])
-
-    assert result_future.done()
-    aggregated = result_future.result()
-    assert aggregated is output1
-    aggregated = aggregated.kv_connector_output
-    assert aggregated.finished_sending is None
-    assert aggregated.finished_recving == {"req2"}
-    assert aggregated.invalid_block_ids == {3, 4, 5}
-
-
 def test_aggregate_workers_output_with_expected_finished_count():
     # We create the aggregator expecting to collect from 4 workers
     aggregator = KVOutputAggregator(expected_finished_count=4)
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index 33a801e135d4..b8eb5ea3b493 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -4,9 +4,6 @@
 KV cache helper for store.
 """
 
-import contextlib
-from collections.abc import Sequence
-from concurrent.futures import CancelledError, Future
 from typing import TYPE_CHECKING, Literal
 
 import torch
@@ -220,43 +217,6 @@ def update_finished_set(
 
         return output
 
-    def async_aggregate(
-        self,
-        output_future: Future[Sequence[ModelRunnerOutput | None]],
-        output_rank: int = 0,
-    ) -> Future[ModelRunnerOutput | None]:
-        """Takes a future that resolves to a list of outputs and returns a future
-        which resolves to a single aggregated output."""
-        result_future: Future[ModelRunnerOutput | None] = Future()
-
-        def callback(fut):
-            if result_future.done():
-                return
-            try:
-                result_future.set_result(self.aggregate(fut.result(), output_rank))
-            except CancelledError:
-                result_future.cancel()
-            except Exception as e:
-                result_future.set_exception(e)
-
-        output_future.add_done_callback(callback)
-
-        from vllm.v1.executor.multiproc_executor import FutureWrapper
-
-        if isinstance(output_future, FutureWrapper):
-            # Due to the threadless implementation of multiproc FutureWrapper,
-            # we must block on the delegate future's result() method.
-            delegate_result = result_future.result
-
-            def result(timeout=None):
-                with contextlib.suppress(Exception):
-                    output_future.result(timeout=timeout)
-                return delegate_result()
-
-            result_future.result = result  # type: ignore[method-assign]
-
-        return result_future
-
 
 def _make_src_and_dst_indices(
     src_block_ids: list[int],
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index c9a50ecaa1de..1e249161c688 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -29,6 +29,7 @@
 from vllm.config import VllmConfig
 from vllm.distributed import destroy_distributed_environment, destroy_model_parallel
 from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.distributed.parallel_state import (
     get_dp_group,
     get_ep_group,
@@ -57,8 +58,13 @@
 
 
 class FutureWrapper(Future):
-    def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]):
+    def __init__(
+        self,
+        futures_queue: deque[tuple["FutureWrapper", Callable]],
+        aggregate: Callable = lambda x: x,
+    ):
         self.futures_queue = futures_queue
+        self.aggregate = aggregate
         super().__init__()
 
     def result(self, timeout=None):
@@ -72,7 +78,7 @@ def result(self, timeout=None):
 
     def wait_for_response(self, get_response: Callable):
         try:
-            response = get_response()
+            response = self.aggregate(get_response())
             with suppress(InvalidStateError):
                 self.set_result(response)
         except Exception as e:
@@ -160,7 +166,6 @@ def _init_executor(self) -> None:
         self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
 
         self.output_rank = self._get_output_rank()
-        self.has_connector = self.vllm_config.kv_transfer_config is not None
 
     def start_worker_monitor(self):
         workers = self.workers
@@ -199,44 +204,27 @@ def register_failure_callback(self, callback: FailureCallback):
 
     def execute_model(  # type: ignore[override]
         self, scheduler_output: SchedulerOutput, non_block: bool = False
     ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
-        return self._execute_with_aggregation(
-            "execute_model", scheduler_output, non_block=non_block
+        return self.collective_rpc(
+            "execute_model",
+            args=(scheduler_output,),
+            unique_reply_rank=self.output_rank,
+            non_block=non_block,
+            timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
+            kv_output_aggregator=self.kv_output_aggregator,
         )
 
     def sample_tokens(  # type: ignore[override]
         self, grammar_output: GrammarOutput | None, non_block: bool = False
     ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
-        return self._execute_with_aggregation(  # type: ignore[return-value]
-            "sample_tokens", grammar_output, non_block=non_block
-        )
-
-    def _execute_with_aggregation(
-        self, method: str, *args, non_block: bool = False
-    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
-        if not self.has_connector:
-            # get output only from a single worker (output_rank)
-            return self.collective_rpc(
-                method,
-                args=args,
-                unique_reply_rank=self.output_rank,
-                non_block=non_block,
-                timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
-            )
-
-        # get output from all workers
-        outputs = self.collective_rpc(
-            method,
-            args=args,
+        return self.collective_rpc(
+            "sample_tokens",
+            args=(grammar_output,),
+            unique_reply_rank=self.output_rank,
             non_block=non_block,
             timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
+            kv_output_aggregator=self.kv_output_aggregator,
         )
 
-        # aggregate all workers output to a single output
-        assert self.kv_output_aggregator is not None
-        if non_block:
-            return self.kv_output_aggregator.async_aggregate(outputs, self.output_rank)
-        return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
-
     def execute_dummy_batch(self) -> None:
         self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank)
 
@@ -254,8 +242,10 @@ def collective_rpc(  # type: ignore[override]
         kwargs: dict | None = None,
         non_block: bool = False,
         unique_reply_rank: int | None = None,
+        kv_output_aggregator: KVOutputAggregator | None = None,
     ) -> Any | list[Any] | Future[Any | list[Any]]:
-        """Returns single result if unique_reply_rank is provided, otherwise list."""
+        """Returns single result if unique_reply_rank and/or kv_output_aggregator
+        is provided, otherwise list."""
         if self.is_failed:
             raise RuntimeError("Executor failed.")
 
@@ -263,20 +253,23 @@ def collective_rpc(  # type: ignore[override]
         deadline = None if timeout is None else time.monotonic() + timeout
         kwargs = kwargs or {}
 
-        # NOTE: If the args are heterogeneous, then we pack them into a list,
-        # and unpack them in the method of every worker, because every worker
-        # knows their own rank.
+        if kv_output_aggregator is not None:
+            output_rank = None
+            aggregate: Callable[[Any], Any] = partial(
+                kv_output_aggregator.aggregate, output_rank=unique_reply_rank or 0
+            )
+        else:
+            output_rank = unique_reply_rank
+            aggregate = lambda x: x
 
         if isinstance(method, str):
             send_method = method
         else:
             send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
-        self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, unique_reply_rank))
+        self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank))
 
         workers = (
-            (self.workers[unique_reply_rank],)
-            if unique_reply_rank is not None
-            else self.workers
+            (self.workers[output_rank],) if output_rank is not None else self.workers
         )
 
         shutdown_event = self.shutdown_event
@@ -299,10 +292,10 @@ def get_response():
                     " stack trace above for the root cause"
                 )
                 responses.append(result)
-            return responses[0] if unique_reply_rank is not None else responses
+            return responses[0] if output_rank is not None else responses
 
         if non_block:
-            future = FutureWrapper(self.futures_queue)
+            future = FutureWrapper(self.futures_queue, aggregate=aggregate)
             self.futures_queue.appendleft((future, get_response))
             return future
 
@@ -311,7 +304,7 @@
         while self.futures_queue:
             future, get_fut_response = self.futures_queue.pop()
             future.wait_for_response(get_fut_response)
-        return get_response()
+        return aggregate(get_response())
 
     @staticmethod
     def _ensure_worker_termination(worker_procs: list[BaseProcess]):
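
Reviewer note: the core of this change is that KV-connector output aggregation moves into `collective_rpc` itself. `FutureWrapper` now applies an `aggregate` callable at the moment the worker response arrives, which is what makes the removed `async_aggregate` helper, and its delegate-future workaround for the threadless `FutureWrapper`, unnecessary. The standalone sketch below illustrates that pattern under stated assumptions: `AggregatingFuture`, `fulfill`, and the toy `aggregate_outputs` are hypothetical stand-ins for illustration only, not vLLM APIs.

```python
from concurrent.futures import Future
from functools import partial


def aggregate_outputs(outputs: list[dict], output_rank: int = 0) -> dict:
    # Toy stand-in for KVOutputAggregator.aggregate: union the per-worker
    # "finished" sets, then return the output of the designated rank.
    merged = set().union(*(o["finished"] for o in outputs))
    result = dict(outputs[output_rank])
    result["finished"] = merged
    return result


class AggregatingFuture(Future):
    """Future that runs an aggregation callable on the raw RPC response.

    Mirrors the updated FutureWrapper: aggregation happens inline when the
    response is collected, so no chained second future is needed.
    """

    def __init__(self, aggregate=lambda x: x):
        super().__init__()
        self.aggregate = aggregate

    def fulfill(self, get_response):
        # Analogue of FutureWrapper.wait_for_response: fetch the response,
        # aggregate it, and resolve this future with the single result.
        try:
            self.set_result(self.aggregate(get_response()))
        except Exception as e:
            self.set_exception(e)


# Two simulated worker responses for one collective RPC.
responses = [{"finished": {"req1"}}, {"finished": {"req2"}}]

fut = AggregatingFuture(aggregate=partial(aggregate_outputs, output_rank=0))
fut.fulfill(lambda: responses)
assert fut.result()["finished"] == {"req1", "req2"}
```

The blocking path in the new `collective_rpc` is the same idea without the future: `return aggregate(get_response())`, with `aggregate` defaulting to identity when no `kv_output_aggregator` is passed.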