Skip to content

Commit 2ce2199

Browse files
Chenyaaang authored and AahilA committed
Enable Pipeline Parallelism to use mp as distributed backend on Jax TPU platform (vllm-project#1054)
Signed-off-by: Chenyaaang <chenyangli@google.com>
1 parent a0fa725 commit 2ce2199

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tpu_inference/platforms/tpu_platform.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -184,8 +184,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
184184

185185
multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
186186
if not multihost_backend: # Single host
187-
logger.info("Force using UniProcExecutor for JAX on single host.")
188-
parallel_config.distributed_executor_backend = "uni"
187+
if parallel_config.pipeline_parallel_size == 1:
188+
logger.info("Force using UniProcExecutor for JAX on \
189+
single host without pipeline parallelism.")
190+
parallel_config.distributed_executor_backend = "uni"
191+
else:
192+
logger.info("Force using MultiprocExecutor for JAX on \
193+
single host with pipeline parallelism.")
194+
parallel_config.distributed_executor_backend = "mp"
189195
elif multihost_backend == "ray":
190196
from tpu_inference.executors.ray_distributed_executor import \
191197
RayDistributedExecutor

0 commit comments

Comments
 (0)