We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 70ebcc9 commit 8e19d70 — Copy full SHA for 8e19d70
tpu_inference/layers/vllm/sharding.py
@@ -84,7 +84,7 @@ def _tensor_is_in_cpu(tensor: torch.tensor) -> bool:
84
85
def _convert_to_torchax_and_shard(tensor: torch.Tensor,
86
sharding: NamedSharding) -> torch.Tensor:
87
- if os.getenv("VLLM_TPU_USE_PATHWAYS", False) and tensor is torch.Tensor:
+ if os.getenv("VLLM_TPU_USING_PATHWAYS", False) and tensor is torch.Tensor:
88
np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
89
dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
90
return torch_view(jax.device_put(np_tensor, sharding).astype(dtype))
0 commit comments