Skip to content

Commit 27aee4f

Browse files
committed
enable pp on tpu jax platform
Signed-off-by: Chenyaaang <chenyangli@google.com>
1 parent b43ea79 commit 27aee4f

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tpu_inference/platforms/tpu_jax.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
189189

190190
multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
191191
if not multihost_backend: # Single host
192-
logger.info("Force using UniProcExecutor for JAX on single host.")
193-
parallel_config.distributed_executor_backend = "uni"
192+
if parallel_config.pipeline_parallel_size == 1:
193+
logger.info("Force using UniProcExecutor for JAX on \
194+
single host without pipeline parallelism.")
195+
parallel_config.distributed_executor_backend = "uni"
196+
else:
197+
logger.info("Force using MultiprocExecutor for JAX on \
198+
single host with pipeline parallelism.")
199+
parallel_config.distributed_executor_backend = "mp"
194200
elif multihost_backend == "ray":
195201
from tpu_inference.executors.ray_distributed_executor import \
196202
RayDistributedExecutor

0 commit comments

Comments
 (0)