99import time
1010import traceback
1111import weakref
12- from collections import deque
1312from collections .abc import Callable
14- from concurrent .futures import Future , InvalidStateError
15- from contextlib import suppress
13+ from concurrent .futures import Future , ThreadPoolExecutor
1614from dataclasses import dataclass
1715from enum import Enum , auto
1816from functools import cached_property , partial
5654logger = init_logger (__name__ )
5755
5856
59- class FutureWrapper (Future ):
60- def __init__ (self , futures_queue : deque [tuple ["FutureWrapper" , Callable ]]):
61- self .futures_queue = futures_queue
62- super ().__init__ ()
63-
64- def result (self , timeout = None ):
65- if timeout is not None :
66- raise RuntimeError ("timeout not implemented" )
67- # Drain any futures ahead of us in the queue.
68- while not self .done ():
69- future , get_response = self .futures_queue .pop ()
70- future .wait_for_response (get_response )
71- return super ().result ()
72-
73- def wait_for_response (self , get_response : Callable ):
74- try :
75- response = get_response ()
76- with suppress (InvalidStateError ):
77- self .set_result (response )
78- except Exception as e :
79- with suppress (InvalidStateError ):
80- self .set_exception (e )
81-
82-
8357class MultiprocExecutor (Executor ):
8458 supports_pp : bool = True
8559
@@ -90,6 +64,7 @@ def _init_executor(self) -> None:
9064 self .is_failed = False
9165 self .shutdown_event = threading .Event ()
9266 self .failure_callback : FailureCallback | None = None
67+ self .io_thread_pool : ThreadPoolExecutor | None = None
9368
9469 self .world_size = self .parallel_config .world_size
9570 tensor_parallel_size = self .parallel_config .tensor_parallel_size
@@ -157,7 +132,12 @@ def _init_executor(self) -> None:
157132 uw .death_writer .close ()
158133 self ._ensure_worker_termination ([uw .proc for uw in unready_workers ])
159134
160- self .futures_queue = deque [tuple [FutureWrapper , Callable ]]()
135+ # NOTE: Use exactly one IO thread so that responses are dequeued
136+ # from the response queue in order.
137+ # _async_aggregate_workers_output also assumes a single IO thread.
138+ self .io_thread_pool = ThreadPoolExecutor (
139+ max_workers = 1 , thread_name_prefix = "mp_exec_io"
140+ )
161141
162142 self .output_rank = self ._get_output_rank ()
163143 self .has_connector = self .vllm_config .kv_transfer_config is not None
@@ -215,13 +195,14 @@ def _execute_with_aggregation(
215195 ) -> ModelRunnerOutput | None | Future [ModelRunnerOutput | None ]:
216196 if not self .has_connector :
217197 # get output only from a single worker (output_rank)
218- return self .collective_rpc (
198+ ( output ,) = self .collective_rpc (
219199 method ,
220200 args = args ,
221201 unique_reply_rank = self .output_rank ,
222202 non_block = non_block ,
223203 timeout = envs .VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS ,
224204 )
205+ return output
225206
226207 # get output from all workers
227208 outputs = self .collective_rpc (
@@ -242,21 +223,20 @@ def execute_dummy_batch(self) -> None:
242223
243224 def take_draft_token_ids (self ) -> DraftTokenIds | None :
244225 # OPTIMIZATION: Get output only from a single worker (output_rank)
245- return self .collective_rpc (
226+ outputs = self .collective_rpc (
246227 "take_draft_token_ids" , unique_reply_rank = self .output_rank
247228 )
229+ return outputs [0 ]
248230
249- def collective_rpc ( # type: ignore[override]
231+ def collective_rpc (
250232 self ,
251233 method : str | Callable ,
252234 timeout : float | None = None ,
253235 args : tuple = (),
254236 kwargs : dict | None = None ,
255237 non_block : bool = False ,
256238 unique_reply_rank : int | None = None ,
257- ) -> Any | list [Any ] | Future [Any | list [Any ]]:
258- """Returns single result if unique_reply_rank is provided, otherwise list."""
259-
239+ ) -> list [Any ]:
260240 if self .is_failed :
261241 raise RuntimeError ("Executor failed." )
262242
@@ -266,52 +246,63 @@ def collective_rpc( # type: ignore[override]
266246 # NOTE: If the args are heterogeneous, then we pack them into a list,
267247 # and unpack them in the method of every worker, because every worker
268248 # knows their own rank.
249+ try :
250+ if isinstance (method , str ):
251+ send_method = method
252+ else :
253+ send_method = cloudpickle .dumps (
254+ method , protocol = pickle .HIGHEST_PROTOCOL
255+ )
256+ self .rpc_broadcast_mq .enqueue (
257+ (send_method , args , kwargs , unique_reply_rank )
258+ )
269259
270- if isinstance (method , str ):
271- send_method = method
272- else :
273- send_method = cloudpickle .dumps (method , protocol = pickle .HIGHEST_PROTOCOL )
274- self .rpc_broadcast_mq .enqueue ((send_method , args , kwargs , unique_reply_rank ))
260+ workers = (
261+ (self .workers [unique_reply_rank ],)
262+ if unique_reply_rank is not None
263+ else self .workers
264+ )
265+ responses = []
275266
276- workers = (
277- (self .workers [unique_reply_rank ],)
278- if unique_reply_rank is not None
279- else self .workers
280- )
267+ def get_response (
268+ w : WorkerProcHandle ,
269+ dequeue_timeout : float | None = None ,
270+ cancel_event : threading .Event | None = None ,
271+ ):
272+ status , result = w .worker_response_mq .dequeue (
273+ timeout = dequeue_timeout , cancel = cancel_event
274+ )
281275
282- shutdown_event = self .shutdown_event
276+ if status != WorkerProc .ResponseStatus .SUCCESS :
277+ raise RuntimeError (
278+ f"Worker failed with error '{ result } ', please check the"
279+ " stack trace above for the root cause"
280+ )
281+ return result
283282
284- def get_response ():
285- responses = []
286283 for w in workers :
287284 dequeue_timeout = (
288285 None if deadline is None else (deadline - time .monotonic ())
289286 )
290- try :
291- status , result = w .worker_response_mq .dequeue (
292- timeout = dequeue_timeout , cancel = shutdown_event
287+
288+ if self .io_thread_pool is not None :
289+ # We must consume worker_response_mq from a single thread.
290+ result = self .io_thread_pool .submit ( # type: ignore
291+ get_response , w , dequeue_timeout , self .shutdown_event
293292 )
294- except TimeoutError as e :
295- raise TimeoutError (f"RPC call to { method } timed out." ) from e
296- if status != WorkerProc .ResponseStatus .SUCCESS :
293+ if not non_block :
294+ result = result .result ()
295+ elif not non_block :
296+ result = get_response (w , dequeue_timeout , self .shutdown_event )
297+ else :
297298 raise RuntimeError (
298- f"Worker failed with error '{ result } ', please check the"
299- " stack trace above for the root cause"
299+ "non_block can only be used when max_concurrent_batches > 1"
300300 )
301301 responses .append (result )
302- return responses [0 ] if unique_reply_rank is not None else responses
303-
304- if non_block :
305- future = FutureWrapper (self .futures_queue )
306- self .futures_queue .appendleft ((future , get_response ))
307- return future
308-
309- # First drain any pending futures in the queue.
310- while self .futures_queue :
311- future , get_fut_response = self .futures_queue .pop ()
312- future .wait_for_response (get_fut_response )
313302
314- return get_response ()
303+ return responses
304+ except TimeoutError as e :
305+ raise TimeoutError (f"RPC call to { method } timed out." ) from e
315306
316307 @staticmethod
317308 def _ensure_worker_termination (worker_procs : list [BaseProcess ]):
@@ -357,6 +348,9 @@ def shutdown(self):
357348 self ._ensure_worker_termination ([w .proc for w in workers ])
358349
359350 self .shutdown_event .set ()
351+ if self .io_thread_pool is not None :
352+ self .io_thread_pool .shutdown (wait = False , cancel_futures = True )
353+ del self .io_thread_pool
360354
361355 self .rpc_broadcast_mq = None
362356
0 commit comments