Skip to content

Commit 68a72a5

Browse files
authored
Revert "[PerfFix] Avoid separate thread for MP executor shm spin (#28012)" (#28289)
Signed-off-by: NickLucche <nlucches@redhat.com>
1 parent 0f872b7 commit 68a72a5

File tree

9 files changed

+131
-143
lines changed

9 files changed

+131
-143
lines changed

tests/v1/executor/test_executor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import asyncio
55
import os
66
from collections.abc import Callable
7-
from concurrent.futures import Future
87
from typing import Any
98

109
import pytest
@@ -28,7 +27,7 @@ def collective_rpc(
2827
kwargs: dict | None = None,
2928
non_block: bool = False,
3029
unique_reply_rank: int | None = None,
31-
) -> Any | list[Any] | Future[Any | list[Any]]:
30+
) -> list[Any]:
3231
# Drop marker to show that this was run
3332
with open(".marker", "w"):
3433
...

tests/v1/kv_connector/unit/test_output_aggregator.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,14 @@ def test_aggregate_workers_output():
8989
def test_async_aggregate_workers_output():
9090
aggregator = KVOutputAggregator(expected_finished_count=2)
9191

92-
future: Future[list[DummyModelRunnerOutput]] = Future()
93-
result_future = aggregator.async_aggregate(future)
92+
future1: Future[DummyModelRunnerOutput] = Future()
93+
future2: Future[DummyModelRunnerOutput] = Future()
94+
result_future = aggregator.async_aggregate([future1, future2])
9495

9596
output1 = DummyModelRunnerOutput()
9697
output2 = DummyModelRunnerOutput()
97-
future.set_result([output1, output2])
98+
future1.set_result(output1)
99+
future2.set_result(output2)
98100

99101
assert result_future.done()
100102
aggregated = result_future.result()
@@ -104,14 +106,16 @@ def test_async_aggregate_workers_output():
104106
assert aggregated.finished_recving is None
105107
assert not aggregated.invalid_block_ids
106108

107-
future = Future()
108-
result_future = aggregator.async_aggregate(future)
109+
future1 = Future()
110+
future2 = Future()
111+
result_future = aggregator.async_aggregate([future1, future2])
109112

110113
output1 = DummyModelRunnerOutput(
111114
finished_sending={"req1"}, finished_recving={"req2"}
112115
)
113116
output2 = DummyModelRunnerOutput(invalid_block_ids={1})
114-
future.set_result([output1, output2])
117+
future1.set_result(output1)
118+
future2.set_result(output2)
115119

116120
assert result_future.done()
117121
aggregated = result_future.result()
@@ -121,12 +125,14 @@ def test_async_aggregate_workers_output():
121125
assert aggregated.finished_recving is None
122126
assert aggregated.invalid_block_ids == {1}
123127

124-
future = Future()
125-
result_future = aggregator.async_aggregate(future)
128+
future1 = Future()
129+
future2 = Future()
130+
result_future = aggregator.async_aggregate([future1, future2])
126131

127132
output1 = DummyModelRunnerOutput(invalid_block_ids={2})
128133
output2 = DummyModelRunnerOutput(finished_sending={"req1"})
129-
future.set_result([output1, output2])
134+
future1.set_result(output1)
135+
future2.set_result(output2)
130136

131137
assert result_future.done()
132138
aggregated = result_future.result()
@@ -136,14 +142,16 @@ def test_async_aggregate_workers_output():
136142
assert aggregated.finished_recving is None
137143
assert aggregated.invalid_block_ids == {2}
138144

139-
future = Future()
140-
result_future = aggregator.async_aggregate(future)
145+
future1 = Future()
146+
future2 = Future()
147+
result_future = aggregator.async_aggregate([future1, future2])
141148

142149
output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
143150
output2 = DummyModelRunnerOutput(
144151
finished_recving={"req2"}, invalid_block_ids={4, 5}
145152
)
146-
future.set_result([output1, output2])
153+
future1.set_result(output1)
154+
future2.set_result(output2)
147155

148156
assert result_future.done()
149157
aggregated = result_future.result()

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -221,24 +221,39 @@ def update_finished_set(
221221

222222
def async_aggregate(
223223
self,
224-
output_future: Future[Sequence[ModelRunnerOutput | None]],
224+
output_futures: Sequence[Future[ModelRunnerOutput | None]],
225225
output_rank: int = 0,
226226
) -> Future[ModelRunnerOutput | None]:
227-
"""Takes a future that resolves to a list of outputs and returns a future
228-
which resolves to a single aggregated output."""
227+
"""Takes a list of futures and returns a single future which resolves
228+
to the respective list of outputs."""
229229
result_future: Future[ModelRunnerOutput | None] = Future()
230230

231-
def callback(fut):
232-
if result_future.done():
233-
return
234-
try:
235-
result_future.set_result(self.aggregate(fut.result(), output_rank))
236-
except CancelledError:
237-
result_future.cancel()
238-
except Exception as e:
239-
result_future.set_exception(e)
240-
241-
output_future.add_done_callback(callback)
231+
outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
232+
remaining = len(output_futures)
233+
234+
def make_callback(idx):
235+
def callback(fut):
236+
if result_future.done():
237+
return
238+
239+
try:
240+
outputs[idx] = fut.result()
241+
except CancelledError:
242+
result_future.cancel()
243+
except Exception as e:
244+
result_future.set_exception(e)
245+
246+
# this check assumes io_thread_pool uses a single thread
247+
nonlocal remaining
248+
remaining -= 1
249+
if not remaining:
250+
result_future.set_result(self.aggregate(outputs, output_rank))
251+
252+
return callback
253+
254+
for i, output_future in enumerate(output_futures):
255+
output_future.add_done_callback(make_callback(i))
256+
242257
return result_future
243258

244259

vllm/v1/executor/abstract.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def collective_rpc(
171171
args: tuple = (),
172172
kwargs: dict | None = None,
173173
non_block: Literal[True] = True,
174-
) -> Future[list[_R]]:
174+
) -> list[Future[_R]]:
175175
pass
176176

177177
@abstractmethod
@@ -219,7 +219,7 @@ def sample_tokens(
219219

220220
def sample_tokens(
221221
self, grammar_output: GrammarOutput | None, non_block: bool = False
222-
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
222+
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
223223
output = self.collective_rpc( # type: ignore[call-overload]
224224
"sample_tokens", args=(grammar_output,), non_block=non_block
225225
)

vllm/v1/executor/multiproc_executor.py

Lines changed: 61 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,8 @@
99
import time
1010
import traceback
1111
import weakref
12-
from collections import deque
1312
from collections.abc import Callable
14-
from concurrent.futures import Future, InvalidStateError
15-
from contextlib import suppress
13+
from concurrent.futures import Future, ThreadPoolExecutor
1614
from dataclasses import dataclass
1715
from enum import Enum, auto
1816
from functools import cached_property, partial
@@ -56,30 +54,6 @@
5654
logger = init_logger(__name__)
5755

5856

59-
class FutureWrapper(Future):
60-
def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]):
61-
self.futures_queue = futures_queue
62-
super().__init__()
63-
64-
def result(self, timeout=None):
65-
if timeout is not None:
66-
raise RuntimeError("timeout not implemented")
67-
# Drain any futures ahead of us in the queue.
68-
while not self.done():
69-
future, get_response = self.futures_queue.pop()
70-
future.wait_for_response(get_response)
71-
return super().result()
72-
73-
def wait_for_response(self, get_response: Callable):
74-
try:
75-
response = get_response()
76-
with suppress(InvalidStateError):
77-
self.set_result(response)
78-
except Exception as e:
79-
with suppress(InvalidStateError):
80-
self.set_exception(e)
81-
82-
8357
class MultiprocExecutor(Executor):
8458
supports_pp: bool = True
8559

@@ -90,6 +64,7 @@ def _init_executor(self) -> None:
9064
self.is_failed = False
9165
self.shutdown_event = threading.Event()
9266
self.failure_callback: FailureCallback | None = None
67+
self.io_thread_pool: ThreadPoolExecutor | None = None
9368

9469
self.world_size = self.parallel_config.world_size
9570
tensor_parallel_size = self.parallel_config.tensor_parallel_size
@@ -157,7 +132,12 @@ def _init_executor(self) -> None:
157132
uw.death_writer.close()
158133
self._ensure_worker_termination([uw.proc for uw in unready_workers])
159134

160-
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
135+
# Note: must use only 1 IO thread to keep dequeue sequence
136+
# from the response queue.
137+
# _async_aggregate_workers_output also assumes a single IO thread.
138+
self.io_thread_pool = ThreadPoolExecutor(
139+
max_workers=1, thread_name_prefix="mp_exec_io"
140+
)
161141

162142
self.output_rank = self._get_output_rank()
163143
self.has_connector = self.vllm_config.kv_transfer_config is not None
@@ -215,13 +195,14 @@ def _execute_with_aggregation(
215195
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
216196
if not self.has_connector:
217197
# get output only from a single worker (output_rank)
218-
return self.collective_rpc(
198+
(output,) = self.collective_rpc(
219199
method,
220200
args=args,
221201
unique_reply_rank=self.output_rank,
222202
non_block=non_block,
223203
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
224204
)
205+
return output
225206

226207
# get output from all workers
227208
outputs = self.collective_rpc(
@@ -242,21 +223,20 @@ def execute_dummy_batch(self) -> None:
242223

243224
def take_draft_token_ids(self) -> DraftTokenIds | None:
244225
# OPTIMIZATION: Get output only from a single worker (output_rank)
245-
return self.collective_rpc(
226+
outputs = self.collective_rpc(
246227
"take_draft_token_ids", unique_reply_rank=self.output_rank
247228
)
229+
return outputs[0]
248230

249-
def collective_rpc( # type: ignore[override]
231+
def collective_rpc(
250232
self,
251233
method: str | Callable,
252234
timeout: float | None = None,
253235
args: tuple = (),
254236
kwargs: dict | None = None,
255237
non_block: bool = False,
256238
unique_reply_rank: int | None = None,
257-
) -> Any | list[Any] | Future[Any | list[Any]]:
258-
"""Returns single result if unique_reply_rank is provided, otherwise list."""
259-
239+
) -> list[Any]:
260240
if self.is_failed:
261241
raise RuntimeError("Executor failed.")
262242

@@ -266,52 +246,63 @@ def collective_rpc( # type: ignore[override]
266246
# NOTE: If the args are heterogeneous, then we pack them into a list,
267247
# and unpack them in the method of every worker, because every worker
268248
# knows their own rank.
249+
try:
250+
if isinstance(method, str):
251+
send_method = method
252+
else:
253+
send_method = cloudpickle.dumps(
254+
method, protocol=pickle.HIGHEST_PROTOCOL
255+
)
256+
self.rpc_broadcast_mq.enqueue(
257+
(send_method, args, kwargs, unique_reply_rank)
258+
)
269259

270-
if isinstance(method, str):
271-
send_method = method
272-
else:
273-
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
274-
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, unique_reply_rank))
260+
workers = (
261+
(self.workers[unique_reply_rank],)
262+
if unique_reply_rank is not None
263+
else self.workers
264+
)
265+
responses = []
275266

276-
workers = (
277-
(self.workers[unique_reply_rank],)
278-
if unique_reply_rank is not None
279-
else self.workers
280-
)
267+
def get_response(
268+
w: WorkerProcHandle,
269+
dequeue_timeout: float | None = None,
270+
cancel_event: threading.Event | None = None,
271+
):
272+
status, result = w.worker_response_mq.dequeue(
273+
timeout=dequeue_timeout, cancel=cancel_event
274+
)
281275

282-
shutdown_event = self.shutdown_event
276+
if status != WorkerProc.ResponseStatus.SUCCESS:
277+
raise RuntimeError(
278+
f"Worker failed with error '{result}', please check the"
279+
" stack trace above for the root cause"
280+
)
281+
return result
283282

284-
def get_response():
285-
responses = []
286283
for w in workers:
287284
dequeue_timeout = (
288285
None if deadline is None else (deadline - time.monotonic())
289286
)
290-
try:
291-
status, result = w.worker_response_mq.dequeue(
292-
timeout=dequeue_timeout, cancel=shutdown_event
287+
288+
if self.io_thread_pool is not None:
289+
# We must consume worker_response_mq from a single thread.
290+
result = self.io_thread_pool.submit( # type: ignore
291+
get_response, w, dequeue_timeout, self.shutdown_event
293292
)
294-
except TimeoutError as e:
295-
raise TimeoutError(f"RPC call to {method} timed out.") from e
296-
if status != WorkerProc.ResponseStatus.SUCCESS:
293+
if not non_block:
294+
result = result.result()
295+
elif not non_block:
296+
result = get_response(w, dequeue_timeout, self.shutdown_event)
297+
else:
297298
raise RuntimeError(
298-
f"Worker failed with error '{result}', please check the"
299-
" stack trace above for the root cause"
299+
"non_block can only be used when max_concurrent_batches > 1"
300300
)
301301
responses.append(result)
302-
return responses[0] if unique_reply_rank is not None else responses
303-
304-
if non_block:
305-
future = FutureWrapper(self.futures_queue)
306-
self.futures_queue.appendleft((future, get_response))
307-
return future
308-
309-
# First drain any pending futures in the queue.
310-
while self.futures_queue:
311-
future, get_fut_response = self.futures_queue.pop()
312-
future.wait_for_response(get_fut_response)
313302

314-
return get_response()
303+
return responses
304+
except TimeoutError as e:
305+
raise TimeoutError(f"RPC call to {method} timed out.") from e
315306

316307
@staticmethod
317308
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
@@ -357,6 +348,9 @@ def shutdown(self):
357348
self._ensure_worker_termination([w.proc for w in workers])
358349

359350
self.shutdown_event.set()
351+
if self.io_thread_pool is not None:
352+
self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
353+
del self.io_thread_pool
360354

361355
self.rpc_broadcast_mq = None
362356

0 commit comments

Comments (0)