
Commit 74f70aa

Enable Pipeline Parallelism on Ray (#1078)
Signed-off-by: Chenyaaang <chenyangli@google.com>
1 parent 292d310 commit 74f70aa
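For context, this change lets the Ray-based executor in tpu_inference place and wire up workers for pipeline parallelism (PP), where the model's layers are split across groups of TPU devices, in addition to the existing tensor parallelism (TP) within each group. A hypothetical usage sketch, assuming vLLM's standard engine arguments (pipeline_parallel_size, distributed_executor_backend); the exact invocation supported on TPU may differ:

# Hypothetical sketch, not taken from this repo's docs: request two pipeline
# stages of four devices each on the Ray executor backend.
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model name
    tensor_parallel_size=4,                    # devices per pipeline stage
    pipeline_parallel_size=2,                  # number of pipeline stages
    distributed_executor_backend="ray",        # run workers through Ray
)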

File tree

1 file changed (+24, -11 lines)


tpu_inference/executors/ray_distributed_executor.py

Lines changed: 24 additions & 11 deletions
@@ -131,10 +131,17 @@ def _initialize_ray_cluster(self) -> None:
                 f"current platform {current_platform.device_name} does not "
                 "support ray.")
 
-        placement_group_specs: List[Dict[str, float]] = [{
-            device_str:
-            node['Resources'][device_str]
-        } for node in ray.nodes()]
+        pp_size = self.parallel_config.pipeline_parallel_size
+        placement_group_specs: List[Dict[str, float]] = []
+        if pp_size == 1:
+            placement_group_specs = [{
+                device_str: node['Resources'][device_str]
+            } for node in ray.nodes()]
+        else:
+            num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
+            placement_group_specs = [{
+                device_str: num_devices_per_pp_rank
+            } for _ in range(pp_size)]
 
         # vLLM engine is also a worker to execute model with an accelerator,
         # so it requires to have the device in a current node. Check if
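In the hunk above, the single-stage case (pp_size == 1) keeps the old behavior of requesting one Ray bundle per node sized to that node's reported TPU resources, while the multi-stage case requests one fixed-size bundle per pipeline stage. A minimal sketch of the resulting bundle list, with illustrative values standing in for device_str and sharding_config.total_devices:

# Minimal sketch of the placement-group bundles built above; the concrete
# values ("TPU", 4) are assumptions for illustration only.
from typing import Dict, List

device_str = "TPU"            # Ray resource key for the accelerator (assumed)
pp_size = 2                   # parallel_config.pipeline_parallel_size
num_devices_per_pp_rank = 4   # sharding_config.total_devices (assumed)

placement_group_specs: List[Dict[str, float]] = [{
    device_str: num_devices_per_pp_rank
} for _ in range(pp_size)]

print(placement_group_specs)  # [{'TPU': 4}, {'TPU': 4}] -> one bundle per stage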
@@ -329,29 +336,35 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
         all_kwargs = []
         for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
             local_rank = node_workers[node_id].index(rank)
+            ip = sorted_worker_metadata[rank].ip
+            prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
             kwargs = dict(
                 vllm_config=self.vllm_config,
                 local_rank=local_rank,
                 rank=rank,
                 distributed_init_method=distributed_init_method,
                 is_driver_worker=(not self.parallel_config)
                 or (rank % self.parallel_config.tensor_parallel_size == 0),
+                ip=ip,
+                prev_worker_ip=prev_ip,
             )
             all_kwargs.append(kwargs)
         self.collective_rpc("init_worker", args=(all_kwargs, ))
         self.collective_rpc("init_device")
+        if self.parallel_config.pipeline_parallel_size > 1:
+            self._run_workers("initialize_pp_transfer_connect")
         self.collective_rpc("load_model")
 
         if self.use_ray_spmd_worker:
             for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                 self.pp_tp_workers.append([])
-                for tp_rank in range(
-                        int(self.parallel_config.tensor_parallel_size //
-                            num_tpu_per_worker)):
-                    # PP=2, TP=4
-                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
-                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
-                            ) + tp_rank
+                num_tp_workers = int(
+                    self.parallel_config.tensor_parallel_size //
+                    num_tpu_per_worker)
+                for tp_rank in range(num_tp_workers):
+                    # PP=2, TP=4, num_tpu_per_worker=2
+                    # pp_tp_workers = [[0, 1], [2, 3]]
+                    rank = (pp_rank * num_tp_workers) + tp_rank
                     assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                     assert pp_rank < len(self.pp_tp_workers)
                     self.pp_tp_workers[pp_rank].append(self.workers[rank])
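The last part of the hunk regroups workers per pipeline stage: each Ray worker may drive several TPU chips, so the inner loop now runs over tensor_parallel_size // num_tpu_per_worker workers rather than over the full TP size, and ranks are numbered contiguously per stage. A quick worked check of that indexing, using the numbers from the updated diff comment (illustrative only):

# Worked example of the rank layout above with PP=2, TP=4, num_tpu_per_worker=2
# (values taken from the diff comment, used here purely for illustration).
pipeline_parallel_size = 2
tensor_parallel_size = 4
num_tpu_per_worker = 2

num_tp_workers = tensor_parallel_size // num_tpu_per_worker  # 2 workers per stage
pp_tp_workers = []
for pp_rank in range(pipeline_parallel_size):
    pp_tp_workers.append([])
    for tp_rank in range(num_tp_workers):
        rank = pp_rank * num_tp_workers + tp_rank
        pp_tp_workers[pp_rank].append(rank)

print(pp_tp_workers)  # [[0, 1], [2, 3]] -- matches the updated comment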

0 commit comments
