 import time
 import traceback
 import weakref
+from collections import deque
 from collections.abc import Callable
-from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent.futures import Future, InvalidStateError
+from contextlib import suppress
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import cached_property, partial
 logger = init_logger(__name__)
 
 
+class FutureWrapper(Future):
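+    """A Future whose result is produced by draining the executor's futures queue.
+
+    Worker responses must be consumed in the order the RPCs were issued, so
+    result() first resolves any earlier futures still pending in the queue.
+    """
+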
+    def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]):
+        self.futures_queue = futures_queue
+        super().__init__()
+
+    def result(self, timeout=None):
+        if timeout is not None:
+            raise RuntimeError("timeout not implemented")
+        # Drain any futures ahead of us in the queue.
+        while not self.done():
+            future, get_response = self.futures_queue.pop()
+            future.wait_for_response(get_response)
+        return super().result()
+
+    def wait_for_response(self, get_response: Callable):
+        try:
+            response = get_response()
+            with suppress(InvalidStateError):
+                self.set_result(response)
+        except Exception as e:
+            with suppress(InvalidStateError):
+                self.set_exception(e)
+
+
 class MultiprocExecutor(Executor):
     supports_pp: bool = True
 
@@ -64,7 +90,6 @@ def _init_executor(self) -> None:
         self.is_failed = False
         self.shutdown_event = threading.Event()
         self.failure_callback: FailureCallback | None = None
-        self.io_thread_pool: ThreadPoolExecutor | None = None
 
         self.world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
@@ -132,12 +157,7 @@ def _init_executor(self) -> None:
             uw.death_writer.close()
         self._ensure_worker_termination([uw.proc for uw in unready_workers])
 
-        # Note: must use only 1 IO thread to keep dequeue sequence
-        # from the response queue.
-        # _async_aggregate_workers_output also assumes a single IO thread.
-        self.io_thread_pool = ThreadPoolExecutor(
-            max_workers=1, thread_name_prefix="mp_exec_io"
-        )
+        self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
 
         self.output_rank = self._get_output_rank()
         self.has_connector = self.vllm_config.kv_transfer_config is not None
@@ -195,14 +215,13 @@ def _execute_with_aggregation(
     ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
         if not self.has_connector:
             # get output only from a single worker (output_rank)
-            (output,) = self.collective_rpc(
+            return self.collective_rpc(
                 method,
                 args=args,
                 unique_reply_rank=self.output_rank,
                 non_block=non_block,
                 timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
             )
-            return output
 
         # get output from all workers
         outputs = self.collective_rpc(
@@ -223,20 +242,21 @@ def execute_dummy_batch(self) -> None:
 
     def take_draft_token_ids(self) -> DraftTokenIds | None:
         # OPTIMIZATION: Get output only from a single worker (output_rank)
-        outputs = self.collective_rpc(
+        return self.collective_rpc(
             "take_draft_token_ids", unique_reply_rank=self.output_rank
         )
-        return outputs[0]
 
-    def collective_rpc(
+    def collective_rpc(  # type: ignore[override]
         self,
         method: str | Callable,
         timeout: float | None = None,
         args: tuple = (),
         kwargs: dict | None = None,
         non_block: bool = False,
         unique_reply_rank: int | None = None,
-    ) -> list[Any]:
+    ) -> Any | list[Any] | Future[Any | list[Any]]:
+        """Returns a single result if unique_reply_rank is provided, otherwise a list."""
+
         if self.is_failed:
             raise RuntimeError("Executor failed.")
 
@@ -246,63 +266,52 @@ def collective_rpc(
         # NOTE: If the args are heterogeneous, then we pack them into a list,
         # and unpack them in the method of every worker, because every worker
         # knows their own rank.
-        try:
-            if isinstance(method, str):
-                send_method = method
-            else:
-                send_method = cloudpickle.dumps(
-                    method, protocol=pickle.HIGHEST_PROTOCOL
-                )
-            self.rpc_broadcast_mq.enqueue(
-                (send_method, args, kwargs, unique_reply_rank)
-            )
 
-            workers = (
-                (self.workers[unique_reply_rank],)
-                if unique_reply_rank is not None
-                else self.workers
-            )
-            responses = []
+        if isinstance(method, str):
+            send_method = method
+        else:
+            send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
+        self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, unique_reply_rank))
 
-            def get_response(
-                w: WorkerProcHandle,
-                dequeue_timeout: float | None = None,
-                cancel_event: threading.Event | None = None,
-            ):
-                status, result = w.worker_response_mq.dequeue(
-                    timeout=dequeue_timeout, cancel=cancel_event
-                )
+        workers = (
+            (self.workers[unique_reply_rank],)
+            if unique_reply_rank is not None
+            else self.workers
+        )
 
-                if status != WorkerProc.ResponseStatus.SUCCESS:
-                    raise RuntimeError(
-                        f"Worker failed with error '{result}', please check the"
-                        " stack trace above for the root cause"
-                    )
-                return result
+        shutdown_event = self.shutdown_event
 
+        def get_response():
+            responses = []
             for w in workers:
                 dequeue_timeout = (
                     None if deadline is None else (deadline - time.monotonic())
                 )
-
-                if self.io_thread_pool is not None:
-                    # We must consume worker_response_mq from a single thread.
-                    result = self.io_thread_pool.submit(  # type: ignore
-                        get_response, w, dequeue_timeout, self.shutdown_event
+                try:
+                    status, result = w.worker_response_mq.dequeue(
+                        timeout=dequeue_timeout, cancel=shutdown_event
                     )
-                    if not non_block:
-                        result = result.result()
-                elif not non_block:
-                    result = get_response(w, dequeue_timeout, self.shutdown_event)
-                else:
+                except TimeoutError as e:
+                    raise TimeoutError(f"RPC call to {method} timed out.") from e
+                if status != WorkerProc.ResponseStatus.SUCCESS:
                     raise RuntimeError(
-                        "non_block can only be used when max_concurrent_batches > 1"
+                        f"Worker failed with error '{result}', please check the"
+                        " stack trace above for the root cause"
                     )
                 responses.append(result)
+            return responses[0] if unique_reply_rank is not None else responses
+
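+        # Non-blocking path: return a FutureWrapper immediately; the actual
+        # dequeue is deferred until result() is called or a later blocking
+        # RPC drains the queue.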
+        if non_block:
+            future = FutureWrapper(self.futures_queue)
+            self.futures_queue.appendleft((future, get_response))
+            return future
+
+        # First drain any pending futures in the queue.
+        while self.futures_queue:
+            future, get_fut_response = self.futures_queue.pop()
+            future.wait_for_response(get_fut_response)
 
-            return responses
-        except TimeoutError as e:
-            raise TimeoutError(f"RPC call to {method} timed out.") from e
+        return get_response()
 
     @staticmethod
     def _ensure_worker_termination(worker_procs: list[BaseProcess]):
@@ -348,9 +357,6 @@ def shutdown(self):
         self._ensure_worker_termination([w.proc for w in workers])
 
         self.shutdown_event.set()
-        if self.io_thread_pool is not None:
-            self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
-            del self.io_thread_pool
 
         self.rpc_broadcast_mq = None
 
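For reference, a minimal caller-side sketch of the new non-blocking path. The executor and scheduler_output names and the "execute_model" method are illustrative assumptions, not part of this diff; only collective_rpc, output_rank, and non_block come from the code above. collective_rpc(..., non_block=True) returns a FutureWrapper immediately, and calling result() on it first resolves any older pending futures so worker responses are consumed in RPC order.

# Hypothetical usage sketch; names other than collective_rpc/output_rank/non_block are assumed.
future = executor.collective_rpc(
    "execute_model",
    args=(scheduler_output,),
    unique_reply_rank=executor.output_rank,
    non_block=True,
)
# ... issue further non-blocking RPCs ...
output = future.result()  # drains earlier queued futures first, then returns this response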