3 files changed: +12 -7 lines changed
kernels/ragged_paged_attention/v3
@@ -43,7 +43,7 @@ def get_tpu_version() -> int:
4343 return - 1
4444 if kind .endswith (' lite' ):
4545 kind = kind [:- len (' lite' )]
46- if kind .endswith ('p' ):
46+ if kind .endswith ('p' ) or kind . endswith ( 'e' ) :
4747 kind = kind [:- 1 ]
4848 if kind == 'TPU7x' :
4949 return 7
11import jax
2+ import jax .numpy as jnp
23import torch
34import torchax
45from jax .sharding import Mesh , NamedSharding , PartitionSpec
1920
2021logger = init_logger (__name__ )
2122
23+ TORCH_TO_JAX_DTYPE_MAP = {
24+ torch .float32 : jnp .float32 ,
25+ torch .float16 : jnp .float16 ,
26+ torch .bfloat16 : jnp .bfloat16 ,
27+ }
28+
2229
2330def shard_model_to_tpu (model : torch .nn .Module ,
2431 mesh : Mesh ) -> dict [str , torchax .torch .Tensor ]:
@@ -75,11 +82,9 @@ def _tensor_is_in_cpu(tensor: torch.tensor) -> bool:
7582
7683def _convert_to_torchax_and_shard (tensor : torch .Tensor ,
7784 sharding : NamedSharding ) -> torch .Tensor :
78- if isinstance (tensor , torchax .tensor .Tensor ):
79- tensor = jax_view (tensor )
80- else :
81- tensor = t2j (tensor )
82- return torch_view (_sharded_device_put (tensor , sharding ))
85+ np_tensor = tensor .detach ().cpu ().to (torch .float32 ).numpy ()
86+ dtype = TORCH_TO_JAX_DTYPE_MAP .get (tensor .dtype , jnp .float32 )
87+ return torch_view (jax .device_put (np_tensor , sharding ).astype (dtype ))
8388
8489
8590def _shard_tensor_to_tpu_replicated (tensor : torch .Tensor ,
@@ -105,7 +105,7 @@ def load_weights(self):
105105
106106 # Load the vLLM model and wrap it into a new model whose forward
107107 # function can calculate the hidden_state and logits.
108- with load_context , jax . default_device ( jax . devices ( 'cpu' )[ 0 ]) :
108+ with load_context :
109109 vllm_model = vllm_get_model (vllm_config = vllm_config_for_load )
110110 lora_manager = None
111111 if vllm_config_for_load .lora_config is not None :
You can’t perform that action at this time.
0 commit comments