Merged
26 changes: 18 additions & 8 deletions tpu_inference/models/vllm/vllm_model_wrapper.py
@@ -26,6 +26,8 @@
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
 from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.vllm.vllm_model_wrapper_context import (
     get_vllm_model_wrapper_context, set_vllm_model_wrapper_context)
 from tpu_inference.runner.lora_utils import replace_lora_metadata
@@ -137,7 +139,8 @@ def jit_step_func(self):
                 "xla_tpu_reduce_scatter_collective_matmul_mode":
                 "post_spmd_conservative"
             },
-            static_argnames=("layer_name_to_kvcache_index", ),
+            static_argnames=("layer_name_to_kvcache_index", "is_first_rank",
+                             "is_last_rank"),
         )
         def step_fun(
             params_and_buffers,  # This has been wrapped into torchax TorchValue
@@ -147,6 +150,9 @@ def step_fun(
             input_embeds: jax.Array,
             layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
             lora_metadata,
+            intermediate_tensors: JaxIntermediateTensors = None,
+            is_first_rank: bool = True,
+            is_last_rank: bool = True,
             *args,
         ) -> Tuple[List[jax.Array], jax.Array]:
             layer_name_to_kvcache_index = dict(layer_name_to_kvcache_index)
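Note on the `static_argnames` change above: `is_first_rank` and `is_last_rank` drive Python-level `if` branches inside the jitted `step_fun`, so they must be compile-time constants; `jax.jit` then specializes one program per (first, last) combination instead of failing on a traced boolean. A minimal, self-contained sketch of that behaviour (placeholder names, not this PR's code):

```python
# Minimal sketch, assuming only standard jax.jit semantics: booleans that
# select Python-level branches inside a jitted function must be declared
# static; each distinct value gets its own compiled program.
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("is_first_rank", "is_last_rank"))
def step(x, is_first_rank: bool = True, is_last_rank: bool = True):
    if not is_first_rank:
        x = x + 1          # stand-in for "consume intermediate tensors"
    y = x * 2              # stand-in for the model forward pass
    if not is_last_rank:
        return y           # stand-in for "emit intermediate tensors"
    return jnp.sum(y)      # stand-in for "emit final hidden states"


print(step(jnp.ones(4), is_first_rank=False, is_last_rank=True))
```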
@@ -161,13 +167,15 @@
             # torch_view in order to call the Torch function.
             original_lora_metadata = replace_lora_metadata(
                 self.model, lora_metadata, self.vllm_config.lora_config)
-            hidden_states = torch.func.functional_call(
+            if not is_first_rank:
+                intermediate_tensors = intermediate_tensors.to_torch()
+            output_from_torch = torch.func.functional_call(
                 self.model,
                 torch_view(params_and_buffers),
                 kwargs={
                     "input_ids": torch_view(input_ids),
                     "positions": torch_view(attn_metadata.input_positions),
-                    "intermediate_tensors": None,
+                    "intermediate_tensors": intermediate_tensors,
                     "inputs_embeds": None,
                 },
                 tie_weights=False,
@@ -176,11 +184,13 @@
                 self.vllm_config.lora_config)
             vllm_model_wrapper_context = get_vllm_model_wrapper_context()
             new_kv_caches = vllm_model_wrapper_context.kv_caches
-            # Wrap the hidden_states from torch land into a JaxValue for the jax
-            # code to consume.
-            hidden_states = jax_view(hidden_states)
-
-            return new_kv_caches, hidden_states, []
+            # Wrap the output (hidden states or intermediate tensors)
+            # from torch land into a JaxValue for the jax code to consume.
+            if not is_last_rank:
+                output = JaxIntermediateTensors.from_torch(output_from_torch)
+            else:
+                output = jax_view(output_from_torch)
+            return new_kv_caches, output, []
 
         return step_fun
 
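For context, the non-first and non-last pipeline ranks round-trip activations between jax and torch land via `JaxIntermediateTensors.to_torch()` and `JaxIntermediateTensors.from_torch()`. Below is a hypothetical, self-contained sketch of what such helpers could look like; the real class lives in `tpu_inference/models/jax/jax_intermediate_tensor.py` and may differ, and `torch_view`/`jax_view` are reduced to identity placeholders so the sketch runs on its own:

```python
# Hypothetical sketch only; not the class shipped in
# tpu_inference/models/jax/jax_intermediate_tensor.py.
from dataclasses import dataclass
from typing import Any, Dict

import jax.numpy as jnp


def torch_view(x: Any) -> Any:
    return x  # placeholder for exposing a jax value to torch code


def jax_view(x: Any) -> Any:
    return x  # placeholder for exposing a torch value to jax code


@dataclass
class IntermediateTensorsSketch:
    tensors: Dict[str, jnp.ndarray]

    def to_torch(self) -> Dict[str, Any]:
        # Non-first pipeline ranks convert the activations received from the
        # previous rank into torch-visible values before functional_call.
        return {k: torch_view(v) for k, v in self.tensors.items()}

    @classmethod
    def from_torch(cls, tensors: Dict[str, Any]) -> "IntermediateTensorsSketch":
        # Non-last ranks wrap the model's torch output back into jax values
        # so the jitted step function can hand them to the next rank.
        return cls({k: jax_view(v) for k, v in tensors.items()})


recv = IntermediateTensorsSketch({"hidden_states": jnp.zeros((2, 8))})
print(IntermediateTensorsSketch.from_torch(recv.to_torch()).tensors.keys())
```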