diff --git a/tpu_inference/models/vllm/vllm_model_wrapper.py b/tpu_inference/models/vllm/vllm_model_wrapper.py index 2a2513689..bbdd49374 100644 --- a/tpu_inference/models/vllm/vllm_model_wrapper.py +++ b/tpu_inference/models/vllm/vllm_model_wrapper.py @@ -26,6 +26,8 @@ from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config from tpu_inference.layers.vllm.sharding import shard_model_to_tpu from tpu_inference.logger import init_logger +from tpu_inference.models.jax.jax_intermediate_tensor import \ + JaxIntermediateTensors from tpu_inference.models.vllm.vllm_model_wrapper_context import ( get_vllm_model_wrapper_context, set_vllm_model_wrapper_context) from tpu_inference.runner.lora_utils import replace_lora_metadata @@ -137,7 +139,8 @@ def jit_step_func(self): "xla_tpu_reduce_scatter_collective_matmul_mode": "post_spmd_conservative" }, - static_argnames=("layer_name_to_kvcache_index", ), + static_argnames=("layer_name_to_kvcache_index", "is_first_rank", + "is_last_rank"), ) def step_fun( params_and_buffers, # This has been wrapped into torchax TorchValue @@ -147,6 +150,9 @@ def step_fun( input_embeds: jax.Array, layer_name_to_kvcache_index: Sequence[Tuple[str, int]], lora_metadata, + intermediate_tensors: JaxIntermediateTensors = None, + is_first_rank: bool = True, + is_last_rank: bool = True, *args, ) -> Tuple[List[jax.Array], jax.Array]: layer_name_to_kvcache_index = dict(layer_name_to_kvcache_index) @@ -161,13 +167,15 @@ def step_fun( # torch_view in order to call the Torch function. original_lora_metadata = replace_lora_metadata( self.model, lora_metadata, self.vllm_config.lora_config) - hidden_states = torch.func.functional_call( + if not is_first_rank: + intermediate_tensors = intermediate_tensors.to_torch() + output_from_torch = torch.func.functional_call( self.model, torch_view(params_and_buffers), kwargs={ "input_ids": torch_view(input_ids), "positions": torch_view(attn_metadata.input_positions), - "intermediate_tensors": None, + "intermediate_tensors": intermediate_tensors, "inputs_embeds": None, }, tie_weights=False, @@ -176,11 +184,13 @@ def step_fun( self.vllm_config.lora_config) vllm_model_wrapper_context = get_vllm_model_wrapper_context() new_kv_caches = vllm_model_wrapper_context.kv_caches - # Wrap the hidden_states from torch land into a JaxValue for the jax - # code to consume. - hidden_states = jax_view(hidden_states) - - return new_kv_caches, hidden_states, [] + # Wrap the output(hidden states or intermediate tensor) + # from torch land into a JaxValue for the jax code to consume. + if not is_last_rank: + output = JaxIntermediateTensors.from_torch(output_from_torch) + else: + output = jax_view(output_from_torch) + return new_kv_caches, output, [] return step_fun