From 81cd2b66302a1fcbccd528a343dbb1e296ad24e3 Mon Sep 17 00:00:00 2001
From: Chenyaaang
Date: Wed, 12 Nov 2025 01:19:23 +0000
Subject: [PATCH 1/3] enable pp on ray

Signed-off-by: Chenyaaang
---
 .../executors/ray_distributed_executor.py    | 34 +++++++++++++------
 1 file changed, 23 insertions(+), 11 deletions(-)

diff --git a/tpu_inference/executors/ray_distributed_executor.py b/tpu_inference/executors/ray_distributed_executor.py
index f1f055eec..28d2912ef 100644
--- a/tpu_inference/executors/ray_distributed_executor.py
+++ b/tpu_inference/executors/ray_distributed_executor.py
@@ -132,10 +132,17 @@ def _initialize_ray_cluster(self) -> None:
                 f"current platform {current_platform.device_name} does not "
                 "support ray.")
 
-        placement_group_specs: List[Dict[str, float]] = [{
-            device_str:
-            node['Resources'][device_str]
-        } for node in ray.nodes()]
+        tp_size = self.parallel_config.tensor_parallel_size
+        pp_size = self.parallel_config.pipeline_parallel_size
+        placement_group_specs: List[Dict[str, float]] = []
+        if pp_size == 1:
+            placement_group_specs = [{
+                device_str: node['Resources'][device_str]
+            } for node in ray.nodes()]
+        else:
+            placement_group_specs = [{
+                device_str: tp_size
+            } for _ in range(pp_size)]
 
         # vLLM engine is also a worker to execute model with an accelerator,
         # so it requires to have the device in a current node. Check if
@@ -330,6 +337,8 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
         all_kwargs = []
         for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
             local_rank = node_workers[node_id].index(rank)
+            ip = sorted_worker_metadata[rank].ip
+            prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
             kwargs = dict(
                 vllm_config=self.vllm_config,
                 local_rank=local_rank,
@@ -337,22 +346,25 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
                 distributed_init_method=distributed_init_method,
                 is_driver_worker=(not self.parallel_config)
                 or (rank % self.parallel_config.tensor_parallel_size == 0),
+                ip=ip,
+                prev_worker_ip=prev_ip,
             )
             all_kwargs.append(kwargs)
         self.collective_rpc("init_worker", args=(all_kwargs, ))
         self.collective_rpc("init_device")
+        self._run_workers("initialize_pp_transfer_connect")
         self.collective_rpc("load_model")
 
         if self.use_ray_spmd_worker:
             for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                 self.pp_tp_workers.append([])
-                for tp_rank in range(
-                        int(self.parallel_config.tensor_parallel_size //
-                            num_tpu_per_worker)):
-                    # PP=2, TP=4
-                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
-                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
-                            ) + tp_rank
+                num_tp_workers = int(
+                    self.parallel_config.tensor_parallel_size //
+                    num_tpu_per_worker)
+                for tp_rank in range(num_tp_workers):
+                    # PP=2, TP=4, num_tpu_per_worker=2
+                    # pp_tp_workers = [[0, 1], [2, 3]]
+                    rank = (pp_rank * num_tp_workers) + tp_rank
                     assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                     assert pp_rank < len(self.pp_tp_workers)
                     self.pp_tp_workers[pp_rank].append(self.workers[rank])

From fb332df22d707a25cfc766cf27a1bcba9f6f92e3 Mon Sep 17 00:00:00 2001
From: Chenyaaang
Date: Wed, 12 Nov 2025 21:40:47 +0000
Subject: [PATCH 2/3] only init pp transfer connect when pp > 1

Signed-off-by: Chenyaaang
---
 tpu_inference/executors/ray_distributed_executor.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tpu_inference/executors/ray_distributed_executor.py b/tpu_inference/executors/ray_distributed_executor.py
index 28d2912ef..4d93b2698 100644
--- a/tpu_inference/executors/ray_distributed_executor.py
+++ b/tpu_inference/executors/ray_distributed_executor.py
@@ -352,7 +352,8 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
             all_kwargs.append(kwargs)
         self.collective_rpc("init_worker", args=(all_kwargs, ))
         self.collective_rpc("init_device")
-        self._run_workers("initialize_pp_transfer_connect")
+        if self.parallel_config.pipeline_parallel_size > 1:
+            self._run_workers("initialize_pp_transfer_connect")
         self.collective_rpc("load_model")
 
         if self.use_ray_spmd_worker:

From f48e4511d6437a3b84147e58d4612d0eff366e8a Mon Sep 17 00:00:00 2001
From: Chenyaaang
Date: Fri, 14 Nov 2025 23:35:44 +0000
Subject: [PATCH 3/3] use sharding config total devices instead of tp shape

Signed-off-by: Chenyaaang
---
 tpu_inference/executors/ray_distributed_executor.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tpu_inference/executors/ray_distributed_executor.py b/tpu_inference/executors/ray_distributed_executor.py
index 4d93b2698..181ace942 100644
--- a/tpu_inference/executors/ray_distributed_executor.py
+++ b/tpu_inference/executors/ray_distributed_executor.py
@@ -132,7 +132,6 @@ def _initialize_ray_cluster(self) -> None:
                 f"current platform {current_platform.device_name} does not "
                 "support ray.")
 
-        tp_size = self.parallel_config.tensor_parallel_size
         pp_size = self.parallel_config.pipeline_parallel_size
         placement_group_specs: List[Dict[str, float]] = []
         if pp_size == 1:
@@ -140,8 +139,9 @@ def _initialize_ray_cluster(self) -> None:
                 device_str: node['Resources'][device_str]
             } for node in ray.nodes()]
         else:
+            num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
            placement_group_specs = [{
-                device_str: tp_size
+                device_str: num_devices_per_pp_rank
            } for _ in range(pp_size)]
 
        # vLLM engine is also a worker to execute model with an accelerator,
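
As a standalone illustration of the placement-group behavior these patches converge on: with PP=1 one bundle is requested per Ray node, sized to that node's device count; with PP>1 one bundle is requested per pipeline stage, sized to the devices of a full stage. The sketch below is not part of the patches; build_placement_group_specs, the fake node dicts, and the numeric values are illustrative stand-ins for ray.nodes() output and sharding_config.total_devices.

# Minimal sketch of the placement-group spec logic after PATCH 3.
# Names and values here are illustrative, not taken from the repository.
from typing import Dict, List


def build_placement_group_specs(
        nodes: List[Dict],  # stand-in for ray.nodes()
        device_str: str,  # e.g. "TPU"
        pp_size: int,  # parallel_config.pipeline_parallel_size
        num_devices_per_pp_rank: int,  # sharding_config.total_devices
) -> List[Dict[str, float]]:
    if pp_size == 1:
        # One bundle per Ray node, sized to all devices reported on that node.
        return [{device_str: node['Resources'][device_str]} for node in nodes]
    # One bundle per pipeline stage, each holding the devices of one stage.
    return [{device_str: num_devices_per_pp_rank} for _ in range(pp_size)]


if __name__ == "__main__":
    fake_nodes = [{'Resources': {'TPU': 4.0}}, {'Resources': {'TPU': 4.0}}]
    # PP=1: bundles follow the per-node device counts.
    print(build_placement_group_specs(fake_nodes, 'TPU', 1, 4))
    # -> [{'TPU': 4.0}, {'TPU': 4.0}]
    # PP=2: one fixed-size bundle per pipeline stage.
    print(build_placement_group_specs(fake_nodes, 'TPU', 2, 4))
    # -> [{'TPU': 4}, {'TPU': 4}]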