Skip to content

Commit 1fa5e73

Browse files
committed
enable pp on tpu jax platform
Signed-off-by: Chenyaaang <chenyangli@google.com>
1 parent 6c4d129 commit 1fa5e73

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tpu_inference/platforms/tpu_platform.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
184184

185185
multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
186186
if not multihost_backend: # Single host
187-
logger.info("Force using UniProcExecutor for JAX on single host.")
188-
parallel_config.distributed_executor_backend = "uni"
187+
if parallel_config.pipeline_parallel_size == 1:
188+
logger.info("Force using UniProcExecutor for JAX on \
189+
single host without pipeline parallelism.")
190+
parallel_config.distributed_executor_backend = "uni"
191+
else:
192+
logger.info("Force using MultiprocExecutor for JAX on \
193+
single host with pipeline parallelism.")
194+
parallel_config.distributed_executor_backend = "mp"
189195
elif multihost_backend == "ray":
190196
from tpu_inference.executors.ray_distributed_executor import \
191197
RayDistributedExecutor

0 commit comments

Comments
 (0)