Commit c6f4a10

fix torchax on pathways
Signed-off-by: Richard Liu <ricliu@google.com>
1 parent 6a1da81 commit c6f4a10

File tree

3 files changed: +12 -7 lines changed
tpu_inference/kernels/ragged_paged_attention/v3/util.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def get_tpu_version() -> int:
     return -1
   if kind.endswith(' lite'):
     kind = kind[:-len(' lite')]
-  if kind.endswith('p'):
+  if kind.endswith('p') or kind.endswith('e'):
     kind = kind[:-1]
   if kind == 'TPU7x':
     return 7
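For context, a minimal standalone sketch of the suffix handling this hunk extends; the device_kind strings and the final digit parse are assumptions for illustration, not part of the diff:

# Hypothetical re-creation of the parsing logic around this hunk; only the
# 'p'/'e' suffix handling is taken from the diff, the rest is assumed.
def parse_tpu_version(kind: str) -> int:
    if not kind.startswith('TPU'):
        return -1
    if kind.endswith(' lite'):
        kind = kind[:-len(' lite')]
    # The change strips an 'e' suffix as well, so 'TPU v5e' and 'TPU v5p'
    # both reduce to 'TPU v5' before the version digit is read.
    if kind.endswith('p') or kind.endswith('e'):
        kind = kind[:-1]
    if kind == 'TPU7x':
        return 7
    return int(kind[-1])

assert parse_tpu_version('TPU v5p') == 5
assert parse_tpu_version('TPU v5e') == 5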

tpu_inference/layers/vllm/sharding.py

Lines changed: 10 additions & 5 deletions
@@ -1,4 +1,5 @@
 import jax
+import jax.numpy as jnp
 import torch
 import torchax
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
@@ -19,6 +20,12 @@
 
 logger = init_logger(__name__)
 
+TORCH_TO_JAX_DTYPE_MAP = {
+    torch.float32: jnp.float32,
+    torch.float16: jnp.float16,
+    torch.bfloat16: jnp.bfloat16,
+}
+
 
 def shard_model_to_tpu(model: torch.nn.Module,
                        mesh: Mesh) -> dict[str, torchax.torch.Tensor]:
@@ -75,11 +82,9 @@ def _tensor_is_in_cpu(tensor: torch.tensor) -> bool:
 
 def _convert_to_torchax_and_shard(tensor: torch.Tensor,
                                   sharding: NamedSharding) -> torch.Tensor:
-    if isinstance(tensor, torchax.tensor.Tensor):
-        tensor = jax_view(tensor)
-    else:
-        tensor = t2j(tensor)
-    return torch_view(_sharded_device_put(tensor, sharding))
+    np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
+    dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
+    return torch_view(jax.device_put(np_tensor, sharding).astype(dtype))
 
 
 def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
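In other words, the torch tensor now takes a host round trip: numpy has no bfloat16, so the values travel as float32 and are cast back to the mapped JAX dtype after jax.device_put places them under the target sharding. Below is a minimal sketch of that path without the torchax wrappers; the mesh and replicated sharding are assumptions made only so the example is self-contained.

import jax
import jax.numpy as jnp
import numpy as np
import torch
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Same mapping as in the diff; anything unmapped falls back to float32.
TORCH_TO_JAX_DTYPE_MAP = {
    torch.float32: jnp.float32,
    torch.float16: jnp.float16,
    torch.bfloat16: jnp.bfloat16,
}

def to_sharded_jax_array(tensor: torch.Tensor,
                         sharding: NamedSharding) -> jax.Array:
    # numpy cannot represent bfloat16, so detour through float32 on the host.
    np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
    # Place the array under the requested sharding, then restore the dtype.
    dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
    return jax.device_put(np_tensor, sharding).astype(dtype)

# Hypothetical single-axis mesh over whatever devices are available.
mesh = Mesh(np.array(jax.devices()), axis_names=('data',))
replicated = NamedSharding(mesh, PartitionSpec())  # no partitioning
weight = torch.randn(4, 4, dtype=torch.bfloat16)
print(to_sharded_jax_array(weight, replicated).dtype)  # bfloat16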

tpu_inference/models/vllm/vllm_model_wrapper.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def load_weights(self):
 
         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
-        with load_context, jax.default_device(jax.devices('cpu')[0]):
+        with load_context:
             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
         lora_manager = None
         if vllm_config_for_load.lora_config is not None:
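The removed context manager is plain JAX: jax.default_device pins arrays created inside the with block to the given device (here the first CPU), rather than the default backend. A small illustration of that behaviour on its own, independent of vLLM:

import jax
import jax.numpy as jnp

# Created outside the context manager: lands on the default backend
# (TPU/GPU if present, otherwise CPU).
x = jnp.zeros(8)
print(x.devices())

# Created inside it: committed to the first CPU device. The commit above
# drops this pinning from the model-loading path.
with jax.default_device(jax.devices('cpu')[0]):
    y = jnp.zeros(8)
print(y.devices())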
