tpu_inference/executors/ray_distributed_executor.py (35 changes: 24 additions & 11 deletions)
@@ -132,10 +132,17 @@ def _initialize_ray_cluster(self) -> None:
f"current platform {current_platform.device_name} does not "
"support ray.")

placement_group_specs: List[Dict[str, float]] = [{
device_str:
node['Resources'][device_str]
} for node in ray.nodes()]
pp_size = self.parallel_config.pipeline_parallel_size
placement_group_specs: List[Dict[str, float]] = []
if pp_size == 1:
placement_group_specs = [{
device_str: node['Resources'][device_str]
} for node in ray.nodes()]
else:
num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
placement_group_specs = [{
device_str: num_devices_per_pp_rank
} for _ in range(pp_size)]

# vLLM engine is also a worker to execute model with an accelerator,
# so it requires to have the device in a current node. Check if
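
The hunk above changes how Ray placement-group bundles are requested: with pipeline_parallel_size == 1 the executor keeps the old behavior of one bundle per Ray node, sized to that node's device count, while with PP > 1 it requests one bundle per pipeline stage, each sized to sharding_config.total_devices. Below is a minimal standalone sketch of that branch; the build_specs helper, the "TPU" resource name, and the node/stage sizes are illustrative assumptions, not part of the PR.

from typing import Dict, List

def build_specs(pp_size: int, num_devices_per_pp_rank: float,
                node_resources: List[Dict[str, float]],
                device_str: str = "TPU") -> List[Dict[str, float]]:
    if pp_size == 1:
        # PP disabled: one bundle per Ray node, sized to that node's chips.
        return [{device_str: node[device_str]} for node in node_resources]
    # PP enabled: one equally sized bundle per pipeline stage.
    return [{device_str: num_devices_per_pp_rank} for _ in range(pp_size)]

# One 8-chip node: PP=1 asks for a single 8-chip bundle, while PP=2 asks
# for two 4-chip bundles that Ray may place independently.
print(build_specs(1, 8.0, [{"TPU": 8.0}]))  # [{'TPU': 8.0}]
print(build_specs(2, 4.0, [{"TPU": 8.0}]))  # [{'TPU': 4.0}, {'TPU': 4.0}]
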
@@ -330,29 +337,35 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
         all_kwargs = []
         for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
             local_rank = node_workers[node_id].index(rank)
+            ip = sorted_worker_metadata[rank].ip
+            prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
             kwargs = dict(
                 vllm_config=self.vllm_config,
                 local_rank=local_rank,
                 rank=rank,
                 distributed_init_method=distributed_init_method,
                 is_driver_worker=(not self.parallel_config)
                 or (rank % self.parallel_config.tensor_parallel_size == 0),
+                ip=ip,
+                prev_worker_ip=prev_ip,
             )
             all_kwargs.append(kwargs)
         self.collective_rpc("init_worker", args=(all_kwargs, ))
         self.collective_rpc("init_device")
+        if self.parallel_config.pipeline_parallel_size > 1:
+            self._run_workers("initialize_pp_transfer_connect")
         self.collective_rpc("load_model")
 
         if self.use_ray_spmd_worker:
             for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                 self.pp_tp_workers.append([])
-                for tp_rank in range(
-                        int(self.parallel_config.tensor_parallel_size //
-                            num_tpu_per_worker)):
-                    # PP=2, TP=4
-                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
-                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
-                            ) + tp_rank
+                num_tp_workers = int(
+                    self.parallel_config.tensor_parallel_size //
+                    num_tpu_per_worker)
+                for tp_rank in range(num_tp_workers):
+                    # PP=2, TP=4, num_tpu_per_worker=2
+                    # pp_tp_workers = [[0, 1], [2, 3]]
+                    rank = (pp_rank * num_tp_workers) + tp_rank
                     assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                     assert pp_rank < len(self.pp_tp_workers)
                     self.pp_tp_workers[pp_rank].append(self.workers[rank])
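
This hunk also fixes the flat rank arithmetic used to build pp_tp_workers: the old code strode by tensor_parallel_size, which leaves gaps in the grid (e.g. [[0, 1], [4, 5]]) whenever one Ray worker drives more than one TPU chip, while the new code strides by the number of TP workers per stage. A small standalone sketch of the corrected layout follows; the layout helper and the example sizes are illustrative assumptions only.

def layout(pp_size: int, tp_size: int, num_tpu_per_worker: int):
    # Each Ray worker owns num_tpu_per_worker chips, so a PP stage has
    # tp_size // num_tpu_per_worker workers; flat ranks stride by that.
    num_tp_workers = tp_size // num_tpu_per_worker
    return [[pp_rank * num_tp_workers + tp_rank
             for tp_rank in range(num_tp_workers)]
            for pp_rank in range(pp_size)]

print(layout(2, 4, 2))  # [[0, 1], [2, 3]] -- matches the updated comment
print(layout(2, 4, 1))  # [[0, 1, 2, 3], [4, 5, 6, 7]] -- the old comment's case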