 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
 from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.vllm.vllm_model_wrapper_context import (
     get_vllm_model_wrapper_context, set_vllm_model_wrapper_context)
 from tpu_inference.runner.lora_utils import replace_lora_metadata
@@ -149,7 +151,8 @@ def jit_step_func(self):
                 "xla_tpu_reduce_scatter_collective_matmul_mode":
                 "post_spmd_conservative"
             },
-            static_argnames=("layer_name_to_kvcache_index", ),
+            static_argnames=("layer_name_to_kvcache_index", "is_first_rank",
+                             "is_last_rank"),
         )
         def step_fun(
             params_and_buffers,  # This has been wrapped into torchax TorchValue
@@ -159,6 +162,9 @@ def step_fun(
             input_embeds: jax.Array,
             layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
             lora_metadata,
+            intermediate_tensors: JaxIntermediateTensors = None,
+            is_first_rank: bool = True,
+            is_last_rank: bool = True,
             *args,
         ) -> Tuple[List[jax.Array], jax.Array]:
             layer_name_to_kvcache_index = dict(layer_name_to_kvcache_index)
@@ -173,13 +179,15 @@ def step_fun(
             # torch_view in order to call the Torch function.
             original_lora_metadata = replace_lora_metadata(
                 self.model, lora_metadata, self.vllm_config.lora_config)
-            hidden_states = torch.func.functional_call(
+            if not is_first_rank:
+                intermediate_tensors = intermediate_tensors.to_torch()
+            output_from_torch = torch.func.functional_call(
                 self.model,
                 torch_view(params_and_buffers),
                 kwargs={
                     "input_ids": torch_view(input_ids),
                     "positions": torch_view(attn_metadata.input_positions),
-                    "intermediate_tensors": None,
+                    "intermediate_tensors": intermediate_tensors,
                     "inputs_embeds": None,
                 },
                 tie_weights=False,
@@ -188,11 +196,13 @@ def step_fun(
                 self.vllm_config.lora_config)
             vllm_model_wrapper_context = get_vllm_model_wrapper_context()
             new_kv_caches = vllm_model_wrapper_context.kv_caches
-            # Wrap the hidden_states from torch land into a JaxValue for the jax
-            # code to consume.
-            hidden_states = jax_view(hidden_states)
-
-            return new_kv_caches, hidden_states, []
+            # Wrap the output (hidden states or intermediate tensors) from
+            # torch land into a JaxValue for the jax code to consume.
+            if not is_last_rank:
+                output = JaxIntermediateTensors.from_torch(output_from_torch)
+            else:
+                output = jax_view(output_from_torch)
+            return new_kv_caches, output, []

         return step_fun
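Note: the only API surface this pipeline-parallel path relies on is the to_torch()/from_torch() round trip on JaxIntermediateTensors. The sketch below is a rough illustration of that contract only, not the actual tpu_inference implementation; the `tensors` field name, the plain-dict interchange format, and the use of torchax's interop helpers are assumptions.

```python
# Hypothetical sketch of the to_torch()/from_torch() contract used above.
# Assumptions: a `tensors` dict field, plain dicts as the Torch-side format,
# and torchax.interop's torch_view/jax_view helpers; the real class in
# tpu_inference.models.jax.jax_intermediate_tensor may differ.
from dataclasses import dataclass
from typing import Dict

import jax
from torchax.interop import jax_view, torch_view


@dataclass
class IntermediateTensorsSketch:
    # Activations handed from one pipeline-parallel rank to the next,
    # keyed by name (e.g. "hidden_states", "residual").
    tensors: Dict[str, jax.Array]

    def to_torch(self):
        # Wrap each jax.Array so the Torch-land forward pass can consume it.
        return {name: torch_view(arr) for name, arr in self.tensors.items()}

    @classmethod
    def from_torch(cls, torch_tensors):
        # Unwrap Torch-land outputs back into jax.Arrays before returning
        # them from the jitted step function.
        return cls({name: jax_view(t) for name, t in torch_tensors.items()})
```

This mirrors how step_fun uses the class in the diff: non-first ranks call to_torch() before torch.func.functional_call, and non-last ranks wrap the forward output with from_torch() before handing it back to JAX.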