Skip to content

Commit 6c4d129

Browse files
authored
Enable Pipeline Parallelism on torchax path (#1055)
Signed-off-by: Chenyaaang <chenyangli@google.com>
1 parent 9314721 commit 6c4d129

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

tpu_inference/models/vllm/vllm_model_wrapper.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
2626
from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
2727
from tpu_inference.logger import init_logger
28+
from tpu_inference.models.jax.jax_intermediate_tensor import \
29+
JaxIntermediateTensors
2830
from tpu_inference.models.vllm.vllm_model_wrapper_context import (
2931
get_vllm_model_wrapper_context, set_vllm_model_wrapper_context)
3032
from tpu_inference.runner.lora_utils import replace_lora_metadata
@@ -149,7 +151,8 @@ def jit_step_func(self):
149151
"xla_tpu_reduce_scatter_collective_matmul_mode":
150152
"post_spmd_conservative"
151153
},
152-
static_argnames=("layer_name_to_kvcache_index", ),
154+
static_argnames=("layer_name_to_kvcache_index", "is_first_rank",
155+
"is_last_rank"),
153156
)
154157
def step_fun(
155158
params_and_buffers, # This has been wrapped into torchax TorchValue
@@ -159,6 +162,9 @@ def step_fun(
159162
input_embeds: jax.Array,
160163
layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
161164
lora_metadata,
165+
intermediate_tensors: JaxIntermediateTensors = None,
166+
is_first_rank: bool = True,
167+
is_last_rank: bool = True,
162168
*args,
163169
) -> Tuple[List[jax.Array], jax.Array]:
164170
layer_name_to_kvcache_index = dict(layer_name_to_kvcache_index)
@@ -173,13 +179,15 @@ def step_fun(
173179
# torch_view in order to call the Torch function.
174180
original_lora_metadata = replace_lora_metadata(
175181
self.model, lora_metadata, self.vllm_config.lora_config)
176-
hidden_states = torch.func.functional_call(
182+
if not is_first_rank:
183+
intermediate_tensors = intermediate_tensors.to_torch()
184+
output_from_torch = torch.func.functional_call(
177185
self.model,
178186
torch_view(params_and_buffers),
179187
kwargs={
180188
"input_ids": torch_view(input_ids),
181189
"positions": torch_view(attn_metadata.input_positions),
182-
"intermediate_tensors": None,
190+
"intermediate_tensors": intermediate_tensors,
183191
"inputs_embeds": None,
184192
},
185193
tie_weights=False,
@@ -188,11 +196,13 @@ def step_fun(
188196
self.vllm_config.lora_config)
189197
vllm_model_wrapper_context = get_vllm_model_wrapper_context()
190198
new_kv_caches = vllm_model_wrapper_context.kv_caches
191-
# Wrap the hidden_states from torch land into a JaxValue for the jax
192-
# code to consume.
193-
hidden_states = jax_view(hidden_states)
194-
195-
return new_kv_caches, hidden_states, []
199+
# Wrap the output (hidden states or intermediate tensors)
200+
# from torch land into a JaxValue for the jax code to consume.
201+
if not is_last_rank:
202+
output = JaxIntermediateTensors.from_torch(output_from_torch)
203+
else:
204+
output = jax_view(output_from_torch)
205+
return new_kv_caches, output, []
196206

197207
return step_fun
198208

0 commit comments

Comments
 (0)