@@ -131,10 +131,17 @@ def _initialize_ray_cluster(self) -> None:
                 f"current platform {current_platform.device_name} does not "
                 "support ray.")
 
-        placement_group_specs: List[Dict[str, float]] = [{
-            device_str:
-            node['Resources'][device_str]
-        } for node in ray.nodes()]
+        pp_size = self.parallel_config.pipeline_parallel_size
+        placement_group_specs: List[Dict[str, float]] = []
+        if pp_size == 1:
+            placement_group_specs = [{
+                device_str: node['Resources'][device_str]
+            } for node in ray.nodes()]
+        else:
+            num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
+            placement_group_specs = [{
+                device_str: num_devices_per_pp_rank
+            } for _ in range(pp_size)]
 
         # vLLM engine is also a worker to execute model with an accelerator,
         # so it requires to have the device in a current node. Check if
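The hunk above changes the placement group layout once pipeline parallelism is enabled: with PP == 1 it keeps one bundle per Ray node sized from that node's reported resources, while with PP > 1 it creates one bundle per pipeline stage sized by sharding_config.total_devices. A minimal sketch of the two resulting spec shapes; the device name "TPU", the two 4-device nodes, and the per-stage device count are assumed example values, not taken from the commit.

# Illustrative sketch (not part of the commit) of the two branches above.
device_str = "TPU"

# PP == 1: one bundle per Ray node, sized to whatever the node reports.
ray_nodes = [{"Resources": {device_str: 4.0}}, {"Resources": {device_str: 4.0}}]
specs_pp1 = [{device_str: node["Resources"][device_str]} for node in ray_nodes]
assert specs_pp1 == [{"TPU": 4.0}, {"TPU": 4.0}]

# PP > 1: one bundle per pipeline stage, each reserving the stage's full
# device count (sharding_config.total_devices in the diff).
pp_size = 2
num_devices_per_pp_rank = 4
specs_pp2 = [{device_str: num_devices_per_pp_rank} for _ in range(pp_size)]
assert specs_pp2 == [{"TPU": 4}, {"TPU": 4}]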
@@ -329,29 +336,35 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
         all_kwargs = []
         for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
             local_rank = node_workers[node_id].index(rank)
+            ip = sorted_worker_metadata[rank].ip
+            prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
             kwargs = dict(
                 vllm_config=self.vllm_config,
                 local_rank=local_rank,
                 rank=rank,
                 distributed_init_method=distributed_init_method,
                 is_driver_worker=(not self.parallel_config)
                 or (rank % self.parallel_config.tensor_parallel_size == 0),
+                ip=ip,
+                prev_worker_ip=prev_ip,
             )
             all_kwargs.append(kwargs)
         self.collective_rpc("init_worker", args=(all_kwargs, ))
         self.collective_rpc("init_device")
+        if self.parallel_config.pipeline_parallel_size > 1:
+            self._run_workers("initialize_pp_transfer_connect")
         self.collective_rpc("load_model")
 
         if self.use_ray_spmd_worker:
             for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                 self.pp_tp_workers.append([])
-                for tp_rank in range(
-                        int(self.parallel_config.tensor_parallel_size //
-                            num_tpu_per_worker)):
-                    # PP=2, TP=4
-                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
-                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
-                            ) + tp_rank
+                num_tp_workers = int(
+                    self.parallel_config.tensor_parallel_size //
+                    num_tpu_per_worker)
+                for tp_rank in range(num_tp_workers):
+                    # PP=2, TP=4, num_tpu_per_worker=2
+                    # pp_tp_workers = [[0, 1], [2, 3]]
+                    rank = (pp_rank * num_tp_workers) + tp_rank
                     assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                     assert pp_rank < len(self.pp_tp_workers)
                     self.pp_tp_workers[pp_rank].append(self.workers[rank])
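In the second hunk, each worker's init kwargs now carry its own IP and the IP of the preceding global rank, which the later initialize_pp_transfer_connect call can use to link adjacent pipeline stages. A small sketch of how prev_worker_ip falls out of the sorted metadata; the WorkerMeta holder and the IP addresses are assumptions made for the example, not part of the commit.

# Illustrative sketch (not part of the commit) of the ip / prev_ip chain.
from dataclasses import dataclass

@dataclass
class WorkerMeta:  # stand-in for the real RayWorkerMetaData entries
    ip: str

sorted_worker_metadata = [WorkerMeta("10.0.0.1"), WorkerMeta("10.0.0.1"),
                          WorkerMeta("10.0.0.2"), WorkerMeta("10.0.0.2")]

pairs = []
for rank in range(len(sorted_worker_metadata)):
    ip = sorted_worker_metadata[rank].ip
    prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
    pairs.append((rank, ip, prev_ip))

# Rank 0 has no predecessor; every other rank points at the rank before it.
assert pairs == [(0, "10.0.0.1", ""), (1, "10.0.0.1", "10.0.0.1"),
                 (2, "10.0.0.2", "10.0.0.1"), (3, "10.0.0.2", "10.0.0.2")]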
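The reworked pp_tp_workers loop groups worker ranks per pipeline stage by the number of multi-device Ray workers (tensor_parallel_size // num_tpu_per_worker) instead of by tensor_parallel_size itself. A standalone sketch of that grouping with the values from the updated comment (PP=2, TP=4, num_tpu_per_worker=2); plain integers stand in for the actual worker handles.

# Illustrative sketch (not part of the commit) of the new rank grouping.
pipeline_parallel_size = 2
tensor_parallel_size = 4
num_tpu_per_worker = 2

pp_tp_workers = []
for pp_rank in range(pipeline_parallel_size):
    pp_tp_workers.append([])
    num_tp_workers = int(tensor_parallel_size // num_tpu_per_worker)
    for tp_rank in range(num_tp_workers):
        rank = (pp_rank * num_tp_workers) + tp_rank
        pp_tp_workers[pp_rank].append(rank)

# Two workers per pipeline stage instead of four, matching the updated comment.
assert pp_tp_workers == [[0, 1], [2, 3]]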