tpu_inference/kernels/ragged_paged_attention/v3/util.py (2 changes: 1 addition & 1 deletion)

@@ -43,7 +43,7 @@ def get_tpu_version() -> int:
     return -1
   if kind.endswith(' lite'):
     kind = kind[:-len(' lite')]
-  if kind.endswith('p'):
+  if kind.endswith('p') or kind.endswith('e'):
     kind = kind[:-1]
   if kind == 'TPU7x':
     return 7
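
For context, get_tpu_version() matches a normalized device_kind string against known TPU generations. Below is a minimal sketch of that normalization step, assuming kind strings such as 'TPU v5p' and 'TPU v6e'; the helper name _normalize_kind and the example strings are illustrative, not taken from the file:

    # Sketch of the suffix normalization in get_tpu_version(); the kind
    # strings are illustrative assumptions, not an exhaustive list.
    def _normalize_kind(kind: str) -> str:
        if kind.endswith(' lite'):
            kind = kind[:-len(' lite')]
        # The change in this PR: a trailing 'e' (e.g. 'TPU v6e') is now
        # stripped the same way as a trailing 'p' (e.g. 'TPU v5p').
        if kind.endswith('p') or kind.endswith('e'):
            kind = kind[:-1]
        return kind

    assert _normalize_kind('TPU v5p') == 'TPU v5'
    assert _normalize_kind('TPU v6e') == 'TPU v6'  # newly handled variant

Presumably this lets 'e' variants resolve to the same numeric version as their 'p' counterparts instead of failing the version lookup.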
tpu_inference/layers/vllm/sharding.py (18 changes: 11 additions & 7 deletions)

@@ -1,11 +1,11 @@
 import jax
+import jax.numpy as jnp
 import torch
 import torchax
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torch.nn import Parameter
 from torch.utils import _pytree as pytree
-from torchax.interop import jax_view, torch_view
-from torchax.ops.mappings import t2j
+from torchax.interop import torch_view
 from vllm.lora.layers import (MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLoRA,
                               RowParallelLinearWithLoRA)
@@ -19,6 +19,12 @@

 logger = init_logger(__name__)

+TORCH_TO_JAX_DTYPE_MAP = {
+    torch.float32: jnp.float32,
+    torch.float16: jnp.float16,
+    torch.bfloat16: jnp.bfloat16,
+}
+

 def shard_model_to_tpu(model: torch.nn.Module,
                        mesh: Mesh) -> dict[str, torchax.torch.Tensor]:
@@ -75,11 +81,9 @@ def _tensor_is_in_cpu(tensor: torch.tensor) -> bool:

 def _convert_to_torchax_and_shard(tensor: torch.Tensor,
                                   sharding: NamedSharding) -> torch.Tensor:
-    if isinstance(tensor, torchax.tensor.Tensor):
-        tensor = jax_view(tensor)
-    else:
-        tensor = t2j(tensor)
-    return torch_view(_sharded_device_put(tensor, sharding))
+    np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
+    dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
+    return torch_view(jax.device_put(np_tensor, sharding).astype(dtype))


 def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
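
The rewritten _convert_to_torchax_and_shard routes every weight through a host numpy array: it detaches to CPU in float32 (numpy has no native bfloat16), places the array on the mesh with jax.device_put under the target sharding, and only then casts to the JAX dtype mapped from the original torch dtype, with float32 as the fallback. Below is a self-contained sketch of that path, minus the torchax torch_view wrapper; the helper name and the one-axis mesh are illustrative assumptions, not from the PR:

    import jax
    import jax.numpy as jnp
    import numpy as np
    import torch
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    TORCH_TO_JAX_DTYPE_MAP = {
        torch.float32: jnp.float32,
        torch.float16: jnp.float16,
        torch.bfloat16: jnp.bfloat16,
    }

    def to_sharded_jax_array(tensor: torch.Tensor,
                             sharding: NamedSharding) -> jax.Array:
        # Stage through host float32 numpy, since numpy lacks bfloat16...
        np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
        # ...then restore the mapped dtype on device after the sharded put.
        dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
        return jax.device_put(np_tensor, sharding).astype(dtype)

    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
    sharding = NamedSharding(mesh, PartitionSpec('x'))
    arr = to_sharded_jax_array(torch.randn(8, 4, dtype=torch.bfloat16), sharding)

The float32 round-trip costs a temporary precision widening on the host, but the original precision is restored on device by the final astype.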
tpu_inference/models/vllm/vllm_model_wrapper.py (2 changes: 1 addition & 1 deletion)

@@ -105,7 +105,7 @@ def load_weights(self):

         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
-        with load_context, jax.default_device(jax.devices('cpu')[0]):
+        with load_context:
             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
         lora_manager = None
         if vllm_config_for_load.lora_config is not None:
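
The dropped context manager had pinned JAX array creation during model load to the host CPU. A minimal sketch of what that JAX API does, for reference (the array shape is illustrative):

    import jax
    import jax.numpy as jnp

    # Inside the context, newly created arrays are committed to the given
    # device; this PR removes such pinning from the model-load path.
    with jax.default_device(jax.devices('cpu')[0]):
        x = jnp.ones((4, 4))
    print(x.devices())  # expected: {CpuDevice(id=0)}

With the sharding path above now staging weights through host numpy and jax.device_put explicitly, loading under a CPU default device is presumably no longer needed.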