Skip to content

Commit 8e19d70

Browse files
authored
Fix typo in VLLM_TPU_USING_PATHWAYS flag (#1074)
Signed-off-by: Richard Liu <ricliu@google.com>
1 parent 70ebcc9 commit 8e19d70

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tpu_inference/layers/vllm/sharding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _tensor_is_in_cpu(tensor: torch.tensor) -> bool:
8484

8585
def _convert_to_torchax_and_shard(tensor: torch.Tensor,
8686
sharding: NamedSharding) -> torch.Tensor:
87-
if os.getenv("VLLM_TPU_USE_PATHWAYS", False) and tensor is torch.Tensor:
87+
if os.getenv("VLLM_TPU_USING_PATHWAYS", False) and tensor is torch.Tensor:
8888
np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
8989
dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
9090
return torch_view(jax.device_put(np_tensor, sharding).astype(dtype))

0 commit comments

Comments
 (0)