Skip to content

Commit 67a2da8

Browse files
authored
[PerfFix] Avoid separate thread for MP executor shm spin (take 2) (#28319)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent da786e3 commit 67a2da8

File tree

9 files changed: 153 additions and 128 deletions.

tests/v1/executor/test_executor.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import asyncio
55
import os
66
from collections.abc import Callable
7+
from concurrent.futures import Future
78
from typing import Any
89

910
import pytest
@@ -27,7 +28,7 @@ def collective_rpc(
2728
kwargs: dict | None = None,
2829
non_block: bool = False,
2930
unique_reply_rank: int | None = None,
30-
) -> list[Any]:
31+
) -> Any | list[Any] | Future[Any | list[Any]]:
3132
# Drop marker to show that this was run
3233
with open(".marker", "w"):
3334
...

tests/v1/kv_connector/unit/test_output_aggregator.py

Lines changed: 12 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -89,14 +89,12 @@ def test_aggregate_workers_output():
8989
def test_async_aggregate_workers_output():
9090
aggregator = KVOutputAggregator(expected_finished_count=2)
9191

92-
future1: Future[DummyModelRunnerOutput] = Future()
93-
future2: Future[DummyModelRunnerOutput] = Future()
94-
result_future = aggregator.async_aggregate([future1, future2])
92+
future: Future[list[DummyModelRunnerOutput]] = Future()
93+
result_future = aggregator.async_aggregate(future)
9594

9695
output1 = DummyModelRunnerOutput()
9796
output2 = DummyModelRunnerOutput()
98-
future1.set_result(output1)
99-
future2.set_result(output2)
97+
future.set_result([output1, output2])
10098

10199
assert result_future.done()
102100
aggregated = result_future.result()
@@ -106,16 +104,14 @@ def test_async_aggregate_workers_output():
106104
assert aggregated.finished_recving is None
107105
assert not aggregated.invalid_block_ids
108106

109-
future1 = Future()
110-
future2 = Future()
111-
result_future = aggregator.async_aggregate([future1, future2])
107+
future = Future()
108+
result_future = aggregator.async_aggregate(future)
112109

113110
output1 = DummyModelRunnerOutput(
114111
finished_sending={"req1"}, finished_recving={"req2"}
115112
)
116113
output2 = DummyModelRunnerOutput(invalid_block_ids={1})
117-
future1.set_result(output1)
118-
future2.set_result(output2)
114+
future.set_result([output1, output2])
119115

120116
assert result_future.done()
121117
aggregated = result_future.result()
@@ -125,14 +121,12 @@ def test_async_aggregate_workers_output():
125121
assert aggregated.finished_recving is None
126122
assert aggregated.invalid_block_ids == {1}
127123

128-
future1 = Future()
129-
future2 = Future()
130-
result_future = aggregator.async_aggregate([future1, future2])
124+
future = Future()
125+
result_future = aggregator.async_aggregate(future)
131126

132127
output1 = DummyModelRunnerOutput(invalid_block_ids={2})
133128
output2 = DummyModelRunnerOutput(finished_sending={"req1"})
134-
future1.set_result(output1)
135-
future2.set_result(output2)
129+
future.set_result([output1, output2])
136130

137131
assert result_future.done()
138132
aggregated = result_future.result()
@@ -142,16 +136,14 @@ def test_async_aggregate_workers_output():
142136
assert aggregated.finished_recving is None
143137
assert aggregated.invalid_block_ids == {2}
144138

145-
future1 = Future()
146-
future2 = Future()
147-
result_future = aggregator.async_aggregate([future1, future2])
139+
future = Future()
140+
result_future = aggregator.async_aggregate(future)
148141

149142
output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
150143
output2 = DummyModelRunnerOutput(
151144
finished_recving={"req2"}, invalid_block_ids={4, 5}
152145
)
153-
future1.set_result(output1)
154-
future2.set_result(output2)
146+
future.set_result([output1, output2])
155147

156148
assert result_future.done()
157149
aggregated = result_future.result()

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 24 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
KV cache helper for store.
55
"""
66

7+
import contextlib
78
from collections.abc import Sequence
89
from concurrent.futures import CancelledError, Future
910
from typing import TYPE_CHECKING, Literal
@@ -221,38 +222,38 @@ def update_finished_set(
221222

222223
def async_aggregate(
223224
self,
224-
output_futures: Sequence[Future[ModelRunnerOutput | None]],
225+
output_future: Future[Sequence[ModelRunnerOutput | None]],
225226
output_rank: int = 0,
226227
) -> Future[ModelRunnerOutput | None]:
227-
"""Takes a list of futures and returns a single future which resolves
228-
to the respective list of outputs."""
228+
"""Takes a future that resolves to a list of outputs and returns a future
229+
which resolves to a single aggregated output."""
229230
result_future: Future[ModelRunnerOutput | None] = Future()
230231

231-
outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
232-
remaining = len(output_futures)
232+
def callback(fut):
233+
if result_future.done():
234+
return
235+
try:
236+
result_future.set_result(self.aggregate(fut.result(), output_rank))
237+
except CancelledError:
238+
result_future.cancel()
239+
except Exception as e:
240+
result_future.set_exception(e)
233241

234-
def make_callback(idx):
235-
def callback(fut):
236-
if result_future.done():
237-
return
242+
output_future.add_done_callback(callback)
238243

239-
try:
240-
outputs[idx] = fut.result()
241-
except CancelledError:
242-
result_future.cancel()
243-
except Exception as e:
244-
result_future.set_exception(e)
244+
from vllm.v1.executor.multiproc_executor import FutureWrapper
245245

246-
# this check assumes io_thread_pool uses a single thread
247-
nonlocal remaining
248-
remaining -= 1
249-
if not remaining:
250-
result_future.set_result(self.aggregate(outputs, output_rank))
246+
if isinstance(output_future, FutureWrapper):
247+
# Due to the threadless implementation of multiproc FutureWrapper,
248+
# we must block on the delegate future's result() method.
249+
delegate_result = result_future.result
251250

252-
return callback
251+
def result(timeout=None):
252+
with contextlib.suppress(Exception):
253+
output_future.result(timeout=timeout)
254+
return delegate_result()
253255

254-
for i, output_future in enumerate(output_futures):
255-
output_future.add_done_callback(make_callback(i))
256+
result_future.result = result # type: ignore[method-assign]
256257

257258
return result_future
258259

vllm/v1/executor/abstract.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -171,7 +171,7 @@ def collective_rpc(
171171
args: tuple = (),
172172
kwargs: dict | None = None,
173173
non_block: Literal[True] = True,
174-
) -> list[Future[_R]]:
174+
) -> Future[list[_R]]:
175175
pass
176176

177177
@abstractmethod
@@ -219,7 +219,7 @@ def sample_tokens(
219219

220220
def sample_tokens(
221221
self, grammar_output: GrammarOutput | None, non_block: bool = False
222-
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
222+
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
223223
output = self.collective_rpc( # type: ignore[call-overload]
224224
"sample_tokens", args=(grammar_output,), non_block=non_block
225225
)

vllm/v1/executor/multiproc_executor.py

Lines changed: 67 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,10 @@
99
import time
1010
import traceback
1111
import weakref
12+
from collections import deque
1213
from collections.abc import Callable
13-
from concurrent.futures import Future, ThreadPoolExecutor
14+
from concurrent.futures import Future, InvalidStateError
15+
from contextlib import suppress
1416
from dataclasses import dataclass
1517
from enum import Enum, auto
1618
from functools import cached_property, partial
@@ -54,6 +56,30 @@
5456
logger = init_logger(__name__)
5557

5658

59+
class FutureWrapper(Future):
    """Future whose response is fetched by whichever caller waits on it.

    Each wrapper shares a ``futures_queue`` of ``(future, get_response)``
    pairs. No background thread completes these futures: calling
    :meth:`result` pops pending entries from the right end of the queue and
    resolves them one by one until this future is done. Producers are
    therefore expected to add entries with ``appendleft`` so that entries are
    resolved in submission order.
    """

    def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]):
        """Store the shared queue of pending ``(future, get_response)`` pairs."""
        self.futures_queue = futures_queue
        super().__init__()

    def result(self, timeout=None):
        """Resolve queued futures until this one is done, then return its result.

        Args:
            timeout: Must be ``None``; timeouts are not supported.

        Returns:
            The value stored by :meth:`wait_for_response` for this future.

        Raises:
            RuntimeError: If ``timeout`` is given, or if this future is not
                done and the queue holds no pending entry that could
                complete it.
            Exception: Whatever exception was stored on this future.
        """
        if timeout is not None:
            raise RuntimeError("timeout not implemented")
        # Drain any futures ahead of us in the queue.
        while not self.done():
            try:
                future, get_response = self.futures_queue.pop()
            except IndexError:
                # Nothing left that could ever complete this future. Fail
                # with a clear message instead of the bare
                # "pop from an empty deque" IndexError.
                raise RuntimeError(
                    "FutureWrapper is not done and is not pending in its "
                    "futures queue"
                ) from None
            future.wait_for_response(get_response)
        return super().result()

    def wait_for_response(self, get_response: Callable):
        """Call ``get_response`` and store its outcome on this future.

        A returned value becomes the future's result and a raised
        ``Exception`` becomes its exception. ``InvalidStateError`` from
        setting either is suppressed, so a future that can no longer accept
        an outcome is left untouched.
        """
        try:
            response = get_response()
            with suppress(InvalidStateError):
                self.set_result(response)
        except Exception as e:
            with suppress(InvalidStateError):
                self.set_exception(e)
81+
82+
5783
class MultiprocExecutor(Executor):
5884
supports_pp: bool = True
5985

@@ -64,7 +90,6 @@ def _init_executor(self) -> None:
6490
self.is_failed = False
6591
self.shutdown_event = threading.Event()
6692
self.failure_callback: FailureCallback | None = None
67-
self.io_thread_pool: ThreadPoolExecutor | None = None
6893

6994
self.world_size = self.parallel_config.world_size
7095
tensor_parallel_size = self.parallel_config.tensor_parallel_size
@@ -132,12 +157,7 @@ def _init_executor(self) -> None:
132157
uw.death_writer.close()
133158
self._ensure_worker_termination([uw.proc for uw in unready_workers])
134159

135-
# Note: must use only 1 IO thread to keep dequeue sequence
136-
# from the response queue.
137-
# _async_aggregate_workers_output also assumes a single IO thread.
138-
self.io_thread_pool = ThreadPoolExecutor(
139-
max_workers=1, thread_name_prefix="mp_exec_io"
140-
)
160+
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
141161

142162
self.output_rank = self._get_output_rank()
143163
self.has_connector = self.vllm_config.kv_transfer_config is not None
@@ -195,14 +215,13 @@ def _execute_with_aggregation(
195215
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
196216
if not self.has_connector:
197217
# get output only from a single worker (output_rank)
198-
(output,) = self.collective_rpc(
218+
return self.collective_rpc(
199219
method,
200220
args=args,
201221
unique_reply_rank=self.output_rank,
202222
non_block=non_block,
203223
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
204224
)
205-
return output
206225

207226
# get output from all workers
208227
outputs = self.collective_rpc(
@@ -223,20 +242,21 @@ def execute_dummy_batch(self) -> None:
223242

224243
def take_draft_token_ids(self) -> DraftTokenIds | None:
225244
# OPTIMIZATION: Get output only from a single worker (output_rank)
226-
outputs = self.collective_rpc(
245+
return self.collective_rpc(
227246
"take_draft_token_ids", unique_reply_rank=self.output_rank
228247
)
229-
return outputs[0]
230248

231-
def collective_rpc(
249+
def collective_rpc( # type: ignore[override]
232250
self,
233251
method: str | Callable,
234252
timeout: float | None = None,
235253
args: tuple = (),
236254
kwargs: dict | None = None,
237255
non_block: bool = False,
238256
unique_reply_rank: int | None = None,
239-
) -> list[Any]:
257+
) -> Any | list[Any] | Future[Any | list[Any]]:
258+
"""Returns single result if unique_reply_rank is provided, otherwise list."""
259+
240260
if self.is_failed:
241261
raise RuntimeError("Executor failed.")
242262

@@ -246,63 +266,52 @@ def collective_rpc(
246266
# NOTE: If the args are heterogeneous, then we pack them into a list,
247267
# and unpack them in the method of every worker, because every worker
248268
# knows their own rank.
249-
try:
250-
if isinstance(method, str):
251-
send_method = method
252-
else:
253-
send_method = cloudpickle.dumps(
254-
method, protocol=pickle.HIGHEST_PROTOCOL
255-
)
256-
self.rpc_broadcast_mq.enqueue(
257-
(send_method, args, kwargs, unique_reply_rank)
258-
)
259269

260-
workers = (
261-
(self.workers[unique_reply_rank],)
262-
if unique_reply_rank is not None
263-
else self.workers
264-
)
265-
responses = []
270+
if isinstance(method, str):
271+
send_method = method
272+
else:
273+
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
274+
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, unique_reply_rank))
266275

267-
def get_response(
268-
w: WorkerProcHandle,
269-
dequeue_timeout: float | None = None,
270-
cancel_event: threading.Event | None = None,
271-
):
272-
status, result = w.worker_response_mq.dequeue(
273-
timeout=dequeue_timeout, cancel=cancel_event
274-
)
276+
workers = (
277+
(self.workers[unique_reply_rank],)
278+
if unique_reply_rank is not None
279+
else self.workers
280+
)
275281

276-
if status != WorkerProc.ResponseStatus.SUCCESS:
277-
raise RuntimeError(
278-
f"Worker failed with error '{result}', please check the"
279-
" stack trace above for the root cause"
280-
)
281-
return result
282+
shutdown_event = self.shutdown_event
282283

284+
def get_response():
285+
responses = []
283286
for w in workers:
284287
dequeue_timeout = (
285288
None if deadline is None else (deadline - time.monotonic())
286289
)
287-
288-
if self.io_thread_pool is not None:
289-
# We must consume worker_response_mq from a single thread.
290-
result = self.io_thread_pool.submit( # type: ignore
291-
get_response, w, dequeue_timeout, self.shutdown_event
290+
try:
291+
status, result = w.worker_response_mq.dequeue(
292+
timeout=dequeue_timeout, cancel=shutdown_event
292293
)
293-
if not non_block:
294-
result = result.result()
295-
elif not non_block:
296-
result = get_response(w, dequeue_timeout, self.shutdown_event)
297-
else:
294+
except TimeoutError as e:
295+
raise TimeoutError(f"RPC call to {method} timed out.") from e
296+
if status != WorkerProc.ResponseStatus.SUCCESS:
298297
raise RuntimeError(
299-
"non_block can only be used when max_concurrent_batches > 1"
298+
f"Worker failed with error '{result}', please check the"
299+
" stack trace above for the root cause"
300300
)
301301
responses.append(result)
302+
return responses[0] if unique_reply_rank is not None else responses
303+
304+
if non_block:
305+
future = FutureWrapper(self.futures_queue)
306+
self.futures_queue.appendleft((future, get_response))
307+
return future
308+
309+
# First drain any pending futures in the queue.
310+
while self.futures_queue:
311+
future, get_fut_response = self.futures_queue.pop()
312+
future.wait_for_response(get_fut_response)
302313

303-
return responses
304-
except TimeoutError as e:
305-
raise TimeoutError(f"RPC call to {method} timed out.") from e
314+
return get_response()
306315

307316
@staticmethod
308317
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
@@ -348,9 +357,6 @@ def shutdown(self):
348357
self._ensure_worker_termination([w.proc for w in workers])
349358

350359
self.shutdown_event.set()
351-
if self.io_thread_pool is not None:
352-
self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
353-
del self.io_thread_pool
354360

355361
self.rpc_broadcast_mq = None
356362

0 commit comments

Comments
 (0)