10 changes: 9 additions & 1 deletion tests/v1/executor/test_executor.py
@@ -9,6 +9,7 @@
 
 import pytest
 
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.sampling_params import SamplingParams
 from vllm.v1.engine.async_llm import AsyncLLM
@@ -28,12 +29,19 @@ def collective_rpc(
         kwargs: dict | None = None,
         non_block: bool = False,
         unique_reply_rank: int | None = None,
+        kv_output_aggregator: KVOutputAggregator = None,
     ) -> Any | list[Any] | Future[Any | list[Any]]:
         # Drop marker to show that this was run
         with open(".marker", "w"):
             ...
         return super().collective_rpc(
-            method, timeout, args, kwargs, non_block, unique_reply_rank
+            method,
+            timeout,
+            args,
+            kwargs,
+            non_block,
+            unique_reply_rank,
+            kv_output_aggregator,
         )
 
 
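Editorial note on the hunk above: because the base `collective_rpc` keeps gaining positional parameters, a signature-mirroring override like this one must be updated in lockstep with every such PR. A more change-tolerant sketch, assuming the test subclasses `MultiprocExecutor` (illustrative only, not part of the diff):

```python
from typing import Any

from vllm.v1.executor.multiproc_executor import MultiprocExecutor


class MarkerExecutor(MultiprocExecutor):
    def collective_rpc(self, *args: Any, **kwargs: Any) -> Any:
        # Drop a marker file to show that the override ran, then defer
        # to the base implementation regardless of its current signature.
        with open(".marker", "w"):
            ...
        return super().collective_rpc(*args, **kwargs)
```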
69 changes: 0 additions & 69 deletions tests/v1/kv_connector/unit/test_output_aggregator.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from concurrent.futures import Future
 
 import pytest
 
@@ -86,74 +85,6 @@ def test_aggregate_workers_output():
     assert aggregated.invalid_block_ids == {3, 4, 5}
 
 
-def test_async_aggregate_workers_output():
-    aggregator = KVOutputAggregator(expected_finished_count=2)
-
-    future: Future[list[DummyModelRunnerOutput]] = Future()
-    result_future = aggregator.async_aggregate(future)
-
-    output1 = DummyModelRunnerOutput()
-    output2 = DummyModelRunnerOutput()
-    future.set_result([output1, output2])
-
-    assert result_future.done()
-    aggregated = result_future.result()
-    assert aggregated is output1
-    aggregated = aggregated.kv_connector_output
-    assert aggregated.finished_sending is None
-    assert aggregated.finished_recving is None
-    assert not aggregated.invalid_block_ids
-
-    future = Future()
-    result_future = aggregator.async_aggregate(future)
-
-    output1 = DummyModelRunnerOutput(
-        finished_sending={"req1"}, finished_recving={"req2"}
-    )
-    output2 = DummyModelRunnerOutput(invalid_block_ids={1})
-    future.set_result([output1, output2])
-
-    assert result_future.done()
-    aggregated = result_future.result()
-    assert aggregated is output1
-    aggregated = aggregated.kv_connector_output
-    assert aggregated.finished_sending is None
-    assert aggregated.finished_recving is None
-    assert aggregated.invalid_block_ids == {1}
-
-    future = Future()
-    result_future = aggregator.async_aggregate(future)
-
-    output1 = DummyModelRunnerOutput(invalid_block_ids={2})
-    output2 = DummyModelRunnerOutput(finished_sending={"req1"})
-    future.set_result([output1, output2])
-
-    assert result_future.done()
-    aggregated = result_future.result()
-    assert aggregated is output1
-    aggregated = aggregated.kv_connector_output
-    assert aggregated.finished_sending == {"req1"}
-    assert aggregated.finished_recving is None
-    assert aggregated.invalid_block_ids == {2}
-
-    future = Future()
-    result_future = aggregator.async_aggregate(future)
-
-    output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
-    output2 = DummyModelRunnerOutput(
-        finished_recving={"req2"}, invalid_block_ids={4, 5}
-    )
-    future.set_result([output1, output2])
-
-    assert result_future.done()
-    aggregated = result_future.result()
-    assert aggregated is output1
-    aggregated = aggregated.kv_connector_output
-    assert aggregated.finished_sending is None
-    assert aggregated.finished_recving == {"req2"}
-    assert aggregated.invalid_block_ids == {3, 4, 5}
-
-
 def test_aggregate_workers_output_with_expected_finished_count():
     # We create the aggregator expecting to collect from 4 workers
     aggregator = KVOutputAggregator(expected_finished_count=4)
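The deleted test exercised `KVOutputAggregator.async_aggregate`, which this PR removes; the future plumbing now lives in the executor's `FutureWrapper`, which applies an aggregate callable when worker responses arrive. A self-contained sketch of that replacement behavior, using toy stand-ins rather than the vLLM classes:

```python
from concurrent.futures import Future
from typing import Any, Callable


class AggregatingFuture(Future):
    """Toy analogue of the new FutureWrapper: aggregate on arrival."""

    def __init__(self, aggregate: Callable[[Any], Any] = lambda x: x):
        self.aggregate = aggregate
        super().__init__()

    def wait_for_response(self, get_response: Callable[[], Any]) -> None:
        try:
            # Aggregation happens where the response is received.
            self.set_result(self.aggregate(get_response()))
        except Exception as e:
            self.set_exception(e)


def test_aggregate_applied_on_arrival():
    fut = AggregatingFuture(aggregate=lambda outputs: outputs[0])
    fut.wait_for_response(lambda: ["rank0-output", "rank1-output"])
    assert fut.result() == "rank0-output"
```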
40 changes: 0 additions & 40 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -4,9 +4,6 @@
 KV cache helper for store.
 """
 
-import contextlib
-from collections.abc import Sequence
-from concurrent.futures import CancelledError, Future
 from typing import TYPE_CHECKING, Literal
 
 import torch
@@ -220,43 +217,6 @@ def update_finished_set(
 
         return output
 
-    def async_aggregate(
-        self,
-        output_future: Future[Sequence[ModelRunnerOutput | None]],
-        output_rank: int = 0,
-    ) -> Future[ModelRunnerOutput | None]:
-        """Takes a future that resolves to a list of outputs and returns a future
-        which resolves to a single aggregated output."""
-        result_future: Future[ModelRunnerOutput | None] = Future()
-
-        def callback(fut):
-            if result_future.done():
-                return
-            try:
-                result_future.set_result(self.aggregate(fut.result(), output_rank))
-            except CancelledError:
-                result_future.cancel()
-            except Exception as e:
-                result_future.set_exception(e)
-
-        output_future.add_done_callback(callback)
-
-        from vllm.v1.executor.multiproc_executor import FutureWrapper
-
-        if isinstance(output_future, FutureWrapper):
-            # Due to the threadless implementation of multiproc FutureWrapper,
-            # we must block on the delegate future's result() method.
-            delegate_result = result_future.result
-
-            def result(timeout=None):
-                with contextlib.suppress(Exception):
-                    output_future.result(timeout=timeout)
-                return delegate_result()
-
-            result_future.result = result  # type: ignore[method-assign]
-
-        return result_future
-
 
 def _make_src_and_dst_indices(
     src_block_ids: list[int],
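For reference, the deleted `async_aggregate` was an instance of a general future-chaining idiom: attach a done-callback to the source future and propagate its result, cancellation, or exception onto a derived future, applying a transform along the way. A generic sketch of that idiom in plain `concurrent.futures` (not vLLM API):

```python
from concurrent.futures import CancelledError, Future
from typing import Any, Callable


def chain(source: Future, transform: Callable[[Any], Any]) -> Future:
    """Return a future that resolves to transform(source.result())."""
    derived: Future = Future()

    def callback(fut: Future) -> None:
        if derived.done():
            return
        try:
            derived.set_result(transform(fut.result()))
        except CancelledError:
            derived.cancel()
        except Exception as e:
            derived.set_exception(e)

    source.add_done_callback(callback)
    return derived
```

The extra `FutureWrapper` branch existed because that class is threadless and only settles inside its own `result()` call, so the done-callback would never fire on its own; moving aggregation into `FutureWrapper` itself (below) removes the need for that workaround.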
79 changes: 36 additions & 43 deletions vllm/v1/executor/multiproc_executor.py
@@ -29,6 +29,7 @@
 from vllm.config import VllmConfig
 from vllm.distributed import destroy_distributed_environment, destroy_model_parallel
 from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.distributed.parallel_state import (
     get_dp_group,
     get_ep_group,
@@ -57,8 +58,13 @@
 
 
 class FutureWrapper(Future):
-    def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]):
+    def __init__(
+        self,
+        futures_queue: deque[tuple["FutureWrapper", Callable]],
+        aggregate: Callable = lambda x: x,
+    ):
         self.futures_queue = futures_queue
+        self.aggregate = aggregate
         super().__init__()
 
     def result(self, timeout=None):
@@ -72,7 +78,7 @@ def result(self, timeout=None):
 
     def wait_for_response(self, get_response: Callable):
         try:
-            response = get_response()
+            response = self.aggregate(get_response())
             with suppress(InvalidStateError):
                 self.set_result(response)
         except Exception as e:
@@ -160,7 +166,6 @@ def _init_executor(self) -> None:
         self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
 
         self.output_rank = self._get_output_rank()
-        self.has_connector = self.vllm_config.kv_transfer_config is not None
 
     def start_worker_monitor(self):
         workers = self.workers
@@ -199,44 +204,27 @@ def register_failure_callback(self, callback: FailureCallback):
     def execute_model(  # type: ignore[override]
         self, scheduler_output: SchedulerOutput, non_block: bool = False
     ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
-        return self._execute_with_aggregation(
-            "execute_model", scheduler_output, non_block=non_block
+        return self.collective_rpc(
+            "execute_model",
+            args=(scheduler_output,),
+            unique_reply_rank=self.output_rank,
+            non_block=non_block,
+            timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
+            kv_output_aggregator=self.kv_output_aggregator,
         )
 
     def sample_tokens(  # type: ignore[override]
         self, grammar_output: GrammarOutput | None, non_block: bool = False
     ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
-        return self._execute_with_aggregation(  # type: ignore[return-value]
-            "sample_tokens", grammar_output, non_block=non_block
-        )
-
-    def _execute_with_aggregation(
-        self, method: str, *args, non_block: bool = False
-    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
-        if not self.has_connector:
-            # get output only from a single worker (output_rank)
-            return self.collective_rpc(
-                method,
-                args=args,
-                unique_reply_rank=self.output_rank,
-                non_block=non_block,
-                timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
-            )
-
-        # get output from all workers
-        outputs = self.collective_rpc(
-            method,
-            args=args,
+        return self.collective_rpc(
+            "sample_tokens",
+            args=(grammar_output,),
+            unique_reply_rank=self.output_rank,
             non_block=non_block,
             timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
+            kv_output_aggregator=self.kv_output_aggregator,
         )
 
-        # aggregate all workers output to a single output
-        assert self.kv_output_aggregator is not None
-        if non_block:
-            return self.kv_output_aggregator.async_aggregate(outputs, self.output_rank)
-        return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
-
     def execute_dummy_batch(self) -> None:
         self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank)
 
@@ -254,29 +242,34 @@ def collective_rpc(  # type: ignore[override]
         kwargs: dict | None = None,
         non_block: bool = False,
         unique_reply_rank: int | None = None,
+        kv_output_aggregator: KVOutputAggregator = None,
     ) -> Any | list[Any] | Future[Any | list[Any]]:
-        """Returns single result if unique_reply_rank is provided, otherwise list."""
+        """Returns single result if unique_reply_rank and/or kv_output_aggregator
+        is provided, otherwise list."""
 
         if self.is_failed:
             raise RuntimeError("Executor failed.")
 
         deadline = None if timeout is None else time.monotonic() + timeout
         kwargs = kwargs or {}
 
-        # NOTE: If the args are heterogeneous, then we pack them into a list,
-        # and unpack them in the method of every worker, because every worker
-        # knows their own rank.
+        if kv_output_aggregator is not None:
+            output_rank = None
+            aggregate: Callable[[Any], Any] = partial(
+                kv_output_aggregator.aggregate, output_rank=unique_reply_rank or 0
+            )
+        else:
+            output_rank = unique_reply_rank
+            aggregate = lambda x: x
 
         if isinstance(method, str):
             send_method = method
         else:
             send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
-        self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, unique_reply_rank))
+        self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank))
 
         workers = (
-            (self.workers[unique_reply_rank],)
-            if unique_reply_rank is not None
-            else self.workers
+            (self.workers[output_rank],) if output_rank is not None else self.workers
        )
 
         shutdown_event = self.shutdown_event
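The branch added above is the heart of the PR: passing `kv_output_aggregator` switches the RPC from a single-rank reply (`output_rank` set) to all-rank replies (`output_rank=None`) folded into one output via `functools.partial`. A toy model of the same dispatch, with a stand-in aggregate function rather than `KVOutputAggregator`:

```python
from functools import partial
from typing import Any, Callable


def aggregate_outputs(outputs: list[Any], output_rank: int = 0) -> Any:
    # Stand-in for KVOutputAggregator.aggregate: fold every worker's
    # output into the one produced by output_rank.
    return {"base": outputs[output_rank], "n_workers": len(outputs)}


def choose_dispatch(
    aggregator: Any, unique_reply_rank: int | None
) -> tuple[int | None, Callable[[Any], Any]]:
    if aggregator is not None:
        # Collect replies from every worker, then merge them.
        return None, partial(aggregate_outputs, output_rank=unique_reply_rank or 0)
    # No KV connector: reply from the requested rank only, pass through.
    return unique_reply_rank, lambda x: x


output_rank, aggregate = choose_dispatch("dummy-aggregator", unique_reply_rank=0)
assert output_rank is None
assert aggregate(["rank0-out", "rank1-out"])["n_workers"] == 2
```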
@@ -299,10 +292,10 @@ def get_response():
                        " stack trace above for the root cause"
                    )
                responses.append(result)
-            return responses[0] if unique_reply_rank is not None else responses
+            return responses[0] if output_rank is not None else responses
 
         if non_block:
-            future = FutureWrapper(self.futures_queue)
+            future = FutureWrapper(self.futures_queue, aggregate=aggregate)
             self.futures_queue.appendleft((future, get_response))
             return future
 
@@ -311,7 +304,7 @@
             future, get_fut_response = self.futures_queue.pop()
             future.wait_for_response(get_fut_response)
 
-        return get_response()
+        return aggregate(get_response())
 
     @staticmethod
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
Expand Down