
Commit 9acb1fc

enable pp on jax models
Signed-off-by: Chenyaaang <chenyangli@google.com>
1 parent 292d310 commit 9acb1fc

File tree

4 files changed: +138 −37 lines changed

  tpu_inference/layers/jax/pp_utils.py  (new file)
  tpu_inference/models/common/model_loader.py
  tpu_inference/models/jax/llama3.py
  tpu_inference/models/jax/utils/weight_utils.py
tpu_inference/layers/jax/pp_utils.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+from typing import List, Protocol
+
+from flax import nnx
+
+from vllm.distributed import get_pp_group
+from vllm.distributed.utils import get_pp_indices
+
+
+class PPMissingLayer(nnx.Module):
+    """
+    A placeholder layer for missing layers in a pipeline parallel model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def __call__(self, *args, **kwargs):
+        """Return the first arg from args or the first value from kwargs."""
+        return args[0] if args else next(iter(kwargs.values()))
+
+
+class LayerFn(Protocol):
+
+    def __call__(self) -> nnx.Module:
+        ...
+
+
+def make_layers(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+) -> tuple[int, int, List[nnx.Module]]:
+    start_layer, end_layer = get_pp_indices(num_hidden_layers,
+                                            get_pp_group().rank_in_group,
+                                            get_pp_group().world_size)
+
+    layers = [PPMissingLayer() for _ in range(start_layer)] \
+        + [layer_fn() for _ in range(start_layer, end_layer)] \
+        + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]
+
+    return start_layer, end_layer, layers
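For context on the new helper: every pipeline rank keeps a full-length layer list, but only materializes its own contiguous [start, end) slice; the other slots are PPMissingLayer placeholders so indices and attribute paths stay stable across ranks. The self-contained sketch below illustrates that padding logic. The even split is an assumption standing in for vllm's get_pp_indices (which is the source of truth and may partition layers differently), and FakeLayer/Placeholder are hypothetical stand-ins for LlamaDecoderLayer/PPMissingLayer.

# Illustrative sketch only, not the real helper.
def even_split(num_layers: int, rank: int, world_size: int) -> tuple[int, int]:
    # Contiguous [start, end) slice per rank; earlier ranks absorb any remainder.
    base, rem = divmod(num_layers, world_size)
    start = rank * base + min(rank, rem)
    end = start + base + (1 if rank < rem else 0)
    return start, end

class FakeLayer:          # stand-in for a real decoder layer
    def __repr__(self):
        return "L"

class Placeholder:        # stand-in for PPMissingLayer
    def __repr__(self):
        return "_"

def make_layers_demo(num_layers: int, rank: int, world_size: int):
    start, end = even_split(num_layers, rank, world_size)
    layers = ([Placeholder() for _ in range(start)]
              + [FakeLayer() for _ in range(start, end)]
              + [Placeholder() for _ in range(end, num_layers)])
    return start, end, layers

for rank in range(2):
    start, end, layers = make_layers_demo(8, rank, world_size=2)
    print(f"rank {rank}: [{start}, {end}) -> {layers}")
# rank 0: [0, 4) -> [L, L, L, L, _, _, _, _]
# rank 1: [4, 8) -> [_, _, _, _, L, L, L, L]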

tpu_inference/models/common/model_loader.py

Lines changed: 3 additions & 1 deletion
@@ -217,7 +217,9 @@ def get_flax_model(
             hidden_states_sharding,  # aux hidden states
         ),
         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=6,  #6 is layer_name_to_kvcache_index
+        static_argnums=(
+            6, 9, 10
+        ),  #6 is layer_name_to_kvcache_index, 9 is is_first_rank, 10 is is_last_rank
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
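The new is_first_rank/is_last_rank arguments are plain Python booleans that select code paths at trace time, so they have to be declared static for the jitted function; otherwise the Python `if` would see a tracer. A minimal sketch of that pattern, assuming illustrative argument positions and names (this is not the real run_model signature):

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1, 2))
def run_stage(x, is_first_rank: bool, is_last_rank: bool):
    # Static booleans are baked into the trace, so JAX compiles one variant
    # per (is_first_rank, is_last_rank) combination instead of failing on
    # Python control flow over traced values.
    if is_first_rank:
        x = x * 2.0       # stands in for the embedding lookup
    if is_last_rank:
        x = x - x.mean()  # stands in for the final norm / lm_head
    return x

print(run_stage(jnp.ones(4), True, False))   # [2. 2. 2. 2.]
print(run_stage(jnp.ones(4), False, True))   # [0. 0. 0. 0.]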

tpu_inference/models/jax/llama3.py

Lines changed: 78 additions & 33 deletions
@@ -1,3 +1,4 @@
+from itertools import islice
 from typing import List, Optional, Tuple

 import jax
@@ -6,13 +7,17 @@
 from jax.sharding import Mesh
 from transformers import LlamaConfig, modeling_flax_utils
 from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group

 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.pp_utils import PPMissingLayer, make_layers
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)

@@ -235,38 +240,52 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
         rms_norm_eps = hf_config.rms_norm_eps
         hidden_size = hf_config.hidden_size

-        self.embed = nnx.Embed(
-            num_embeddings=vocab_size,
-            features=hidden_size,
-            param_dtype=dtype,
-            embedding_init=nnx.with_partitioning(
-                init_fn, (ShardingAxisName.VOCAB, None)),
-            rngs=rng,
-        )
-        self.layers = [
-            LlamaDecoderLayer(
+        self.is_first_rank = get_pp_group().is_first_rank
+        self.is_last_rank = get_pp_group().is_last_rank
+
+        if self.is_first_rank or (hf_config.tie_word_embeddings
+                                  and self.is_last_rank):
+            self.embed = nnx.Embed(
+                num_embeddings=vocab_size,
+                features=hidden_size,
+                param_dtype=dtype,
+                embedding_init=nnx.with_partitioning(
+                    init_fn, (ShardingAxisName.VOCAB, None)),
+                rngs=rng,
+            )
+        else:
+            self.embed = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            hf_config.num_hidden_layers,
+            lambda: LlamaDecoderLayer(
                 config=hf_config,
                 dtype=dtype,
                 rng=rng,
                 mesh=mesh,
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-                kv_cache_dtype=vllm_config.cache_config.cache_dtype)
-            for _ in range(hf_config.num_hidden_layers)
-        ]
-        self.norm = nnx.RMSNorm(
-            hidden_size,
-            epsilon=rms_norm_eps,
-            param_dtype=dtype,
-            scale_init=nnx.with_partitioning(init_fn, (None, )),
-            rngs=rng,
-        )
-        if model_config.hf_config.tie_word_embeddings:
-            self.lm_head = self.embed.embedding
-        else:
-            self.lm_head = nnx.Param(
-                init_fn(rng.params(), (hidden_size, vocab_size), dtype),
-                sharding=(None, ShardingAxisName.VOCAB),
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype))
+        if self.is_last_rank:
+            self.norm = nnx.RMSNorm(
+                hidden_size,
+                epsilon=rms_norm_eps,
+                param_dtype=dtype,
+                scale_init=nnx.with_partitioning(init_fn, (None, )),
+                rngs=rng,
             )
+        else:
+            self.norm = PPMissingLayer()
+
+        if self.is_last_rank:
+            if model_config.hf_config.tie_word_embeddings:
+                self.lm_head = self.embed.embedding
+            else:
+                self.lm_head = nnx.Param(
+                    init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                    sharding=(None, ShardingAxisName.VOCAB),
+                )
+        else:
+            self.lm_head = PPMissingLayer()

         self.aux_hidden_state_layers = []
         if vllm_config.speculative_config and vllm_config.speculative_config.method == "eagle3":
@@ -282,10 +301,18 @@ def __call__(
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
-    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
-        x = self.embed(input_ids)
+        intermediate_tensors: JaxIntermediateTensors | None,
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        if self.is_first_rank:
+            x = self.embed(input_ids)
+        else:
+            assert intermediate_tensors is not None
+            x = intermediate_tensors["hidden_states"]
+
         aux_hidden_states = []
-        for i, layer in enumerate(self.layers):
+        for i, layer in enumerate(
+                islice(self.layers, self.start_layer, self.end_layer)):
             if i in self.aux_hidden_state_layers:
                 aux_hidden_states.append(x)
             kv_cache = kv_caches[i]
@@ -295,6 +322,10 @@ def __call__(
                 attention_metadata,
             )
             kv_caches[i] = kv_cache
+        if not self.is_last_rank:
+            # Note: add aux_hidden_states to make the output spec consistent.
+            return kv_caches, JaxIntermediateTensors({"hidden_states":
+                                                      x}), aux_hidden_states
         x = self.norm(x)
         return kv_caches, x, aux_hidden_states

@@ -313,19 +344,32 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
             mesh=mesh,
         )

+        self.pp_missing_layers = []
+        for path, module in nnx.iter_graph(self.model):
+            if isinstance(module, PPMissingLayer):
+                # the path should be sth like ('layers', '0')
+                self.pp_missing_layers.append('.'.join([str(s) for s in path]))
+
     def __call__(
         self,
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
+        _input_embeds,
+        _layer_name_to_kv_cache,
+        _lora_metadata,
+        intermediate_tensors: JaxIntermediateTensors,
+        _is_first_rank: bool,
+        _is_last_rank: bool,
         *args,
-    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
-        kv_caches, x, aux_hidden_states = self.model(
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        return self.model(
             kv_caches,
             input_ids,
             attention_metadata,
+            intermediate_tensors,
         )
-        return kv_caches, x, aux_hidden_states

     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
         if self.vllm_config.model_config.hf_config.tie_word_embeddings:
@@ -372,4 +416,5 @@ def load_weights(self, rng_key: jax.Array):
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
-                        mesh=self.mesh)
+                        mesh=self.mesh,
+                        pp_missing_layers=self.pp_missing_layers)
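The wrapper records which attribute paths hold placeholders so that weight loading can skip them. A rough sketch of that bookkeeping, under the assumption (suggested by the comment above) that the graph walk yields (path, module) pairs such as (('layers', 0), PPMissingLayer()); the walk output and stub classes below are made up for illustration:

class PPMissingLayerStub:  # hypothetical stand-in for PPMissingLayer
    pass

class RealLayerStub:
    pass

# Pretend output of a module-graph traversal like nnx.iter_graph(self.model).
walk = [
    (("embed",), PPMissingLayerStub()),
    (("layers", 0), PPMissingLayerStub()),
    (("layers", 1), RealLayerStub()),
    (("lm_head",), PPMissingLayerStub()),
]

pp_missing_layers = [
    ".".join(str(s) for s in path)
    for path, module in walk
    if isinstance(module, PPMissingLayerStub)
]
print(pp_missing_layers)  # ['embed', 'layers.0', 'lm_head']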

tpu_inference/models/jax/utils/weight_utils.py

Lines changed: 18 additions & 3 deletions
@@ -273,7 +273,8 @@ def _load_hf_weights_on_thread(vllm_config,
                                weights_file: str,
                                filter_regex: str | None = None,
                                keep_original_dtype_keys_regex: list[str]
-                               | None = None):
+                               | None = None,
+                               pp_missing_layers: list[str] | None = None):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map
@@ -338,6 +339,17 @@ def _load_hf_weights_on_thread(vllm_config,
             )
             continue
         model_key = name_map.get(hf_key, hf_key)
+        # add skip pp missing layers.
+        def is_pp_missing_layer(hf_key):
+            has_digit = any(char.isdigit() for char in hf_key)
+            # add the suffix after digits to avoid it matches "layers.10" with "layers.1"
+            suffix = "." if has_digit else ""
+            return any(f'{pp_missing_layer}{suffix}' in hf_key
+                       for pp_missing_layer in pp_missing_layers)
+
+        if pp_missing_layers and is_pp_missing_layer(hf_key):
+            continue
+
         model_weight, model_sharding = get_param_and_sharding(
             params, shardings, model_key)

@@ -408,14 +420,16 @@ def load_hf_weights(vllm_config,
                     mesh: Mesh,
                     filter_regex: str | None = None,
                     is_draft_model: bool = False,
-                    keep_original_dtype_keys_regex: list[str] | None = None):
+                    keep_original_dtype_keys_regex: list[str] | None = None,
+                    pp_missing_layers: list[str] | None = None):
     """Load weights from all model weights files to the model, run in multi threads."""
     if is_draft_model:
         model_path = vllm_config.speculative_config.draft_model_config.model
     else:
         model_path = vllm_config.model_config.model
     weights_files = get_model_weights_files(
         model_path, vllm_config.load_config.download_dir)
+    # For PP, params are partial.
     params = nnx.state(model)
     max_workers = min(64, len(weights_files))
     # NOTE(xiang): Disable multi-threading mode if running on multi-host.
@@ -433,7 +447,8 @@ def load_hf_weights(vllm_config,
                 mesh,
                 weights_file,
                 filter_regex=filter_regex,
-                keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
+                keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
+                pp_missing_layers=pp_missing_layers)
             for weights_file in weights_files
         ]
         for future in futures:
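For reference, the skip check added in _load_hf_weights_on_thread can be exercised on its own; the trailing dot is what prevents a prefix like layers.1 from also matching layers.10. The weight keys below are made up for illustration, and the standalone signature differs slightly from the nested helper above:

def is_pp_missing_layer(hf_key: str, pp_missing_layers: list[str]) -> bool:
    # Mirrors the helper above: append "." when the key contains digits so
    # "layers.1" does not accidentally match "layers.10".
    has_digit = any(char.isdigit() for char in hf_key)
    suffix = "." if has_digit else ""
    return any(f"{p}{suffix}" in hf_key for p in pp_missing_layers)

missing = ["layers.1", "lm_head"]
print(is_pp_missing_layer("model.layers.1.mlp.gate_proj.weight", missing))   # True
print(is_pp_missing_layer("model.layers.10.mlp.gate_proj.weight", missing))  # False
print(is_pp_missing_layer("lm_head", missing))                               # True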
