diff --git a/tpu_inference/kernels/ragged_paged_attention/v3/util.py b/tpu_inference/kernels/ragged_paged_attention/v3/util.py
index f12830ce4..6e879058e 100644
--- a/tpu_inference/kernels/ragged_paged_attention/v3/util.py
+++ b/tpu_inference/kernels/ragged_paged_attention/v3/util.py
@@ -43,7 +43,7 @@ def get_tpu_version() -> int:
     return -1
   if kind.endswith(' lite'):
     kind = kind[:-len(' lite')]
-  if kind.endswith('p'):
+  if kind.endswith('p') or kind.endswith('e'):
     kind = kind[:-1]
   if kind == 'TPU7x':
     return 7
diff --git a/tpu_inference/layers/vllm/sharding.py b/tpu_inference/layers/vllm/sharding.py
index 38501a8b2..d1bdc325e 100644
--- a/tpu_inference/layers/vllm/sharding.py
+++ b/tpu_inference/layers/vllm/sharding.py
@@ -1,11 +1,11 @@
 import jax
+import jax.numpy as jnp
 import torch
 import torchax
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torch.nn import Parameter
 from torch.utils import _pytree as pytree
-from torchax.interop import jax_view, torch_view
-from torchax.ops.mappings import t2j
+from torchax.interop import torch_view
 from vllm.lora.layers import (MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLoRA,
                               RowParallelLinearWithLoRA)
@@ -19,6 +19,12 @@
 
 logger = init_logger(__name__)
 
+TORCH_TO_JAX_DTYPE_MAP = {
+    torch.float32: jnp.float32,
+    torch.float16: jnp.float16,
+    torch.bfloat16: jnp.bfloat16,
+}
+
 
 def shard_model_to_tpu(model: torch.nn.Module,
                        mesh: Mesh) -> dict[str, torchax.torch.Tensor]:
@@ -75,11 +81,9 @@ def _tensor_is_in_cpu(tensor: torch.tensor) -> bool:
 
 def _convert_to_torchax_and_shard(tensor: torch.Tensor,
                                   sharding: NamedSharding) -> torch.Tensor:
-    if isinstance(tensor, torchax.tensor.Tensor):
-        tensor = jax_view(tensor)
-    else:
-        tensor = t2j(tensor)
-    return torch_view(_sharded_device_put(tensor, sharding))
+    np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
+    dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
+    return torch_view(jax.device_put(np_tensor, sharding).astype(dtype))
 
 
 def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
diff --git a/tpu_inference/models/vllm/vllm_model_wrapper.py b/tpu_inference/models/vllm/vllm_model_wrapper.py
index 2a2513689..b1ea732b1 100644
--- a/tpu_inference/models/vllm/vllm_model_wrapper.py
+++ b/tpu_inference/models/vllm/vllm_model_wrapper.py
@@ -105,7 +105,7 @@ def load_weights(self):
 
         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
-        with load_context, jax.default_device(jax.devices('cpu')[0]):
+        with load_context:
             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
             lora_manager = None
             if vllm_config_for_load.lora_config is not None:
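
The rewritten _convert_to_torchax_and_shard above round-trips every weight through a float32 numpy array, places it on device with the requested sharding, and only then casts to the JAX dtype matching the original torch dtype — presumably because stock numpy has no bfloat16, while the float16/bfloat16 -> float32 -> original-dtype round trip is lossless. A minimal standalone sketch of that path, with the hypothetical name convert_and_put and the torch_view wrapper from the diff omitted:

import jax
import jax.numpy as jnp
import torch

# Mirrors TORCH_TO_JAX_DTYPE_MAP from the diff.
TORCH_TO_JAX_DTYPE_MAP = {
    torch.float32: jnp.float32,
    torch.float16: jnp.float16,
    torch.bfloat16: jnp.bfloat16,
}

def convert_and_put(tensor: torch.Tensor,
                    sharding: jax.sharding.Sharding) -> jax.Array:
    # Stock numpy cannot represent bfloat16, so materialize on the
    # host as float32 first ...
    np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
    # ... then place with the requested sharding and cast back to the
    # JAX dtype matching the original torch dtype on device.
    dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
    return jax.device_put(np_tensor, sharding).astype(dtype)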