+from itertools import islice
 from typing import List, Optional, Tuple

 import jax
 from jax.sharding import Mesh
 from transformers import LlamaConfig, modeling_flax_utils
 from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group

 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.pp_utils import PPMissingLayer, make_layers
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)
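For context, PPMissingLayer is the stand-in used below for modules that live on another pipeline-parallel rank. Its real definition is in tpu_inference.layers.jax.pp_utils and is not part of this diff; assuming it mirrors vLLM's torch placeholder of the same name, a minimal sketch (names here are illustrative only) could look like:

from flax import nnx

class PPMissingLayerSketch(nnx.Module):
    """Hypothetical sketch of a placeholder for a layer owned by another PP rank.

    It holds no parameters, so state traversal and weight loading can skip it,
    and it passes its input through unchanged if it is ever called.
    """

    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, x, *args, **kwargs):
        return x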
@@ -235,38 +240,52 @@ def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
         rms_norm_eps = hf_config.rms_norm_eps
         hidden_size = hf_config.hidden_size

-        self.embed = nnx.Embed(
-            num_embeddings=vocab_size,
-            features=hidden_size,
-            param_dtype=dtype,
-            embedding_init=nnx.with_partitioning(
-                init_fn, (ShardingAxisName.VOCAB, None)),
-            rngs=rng,
-        )
-        self.layers = [
-            LlamaDecoderLayer(
+        self.is_first_rank = get_pp_group().is_first_rank
+        self.is_last_rank = get_pp_group().is_last_rank
+
+        if self.is_first_rank or (hf_config.tie_word_embeddings
+                                  and self.is_last_rank):
+            self.embed = nnx.Embed(
+                num_embeddings=vocab_size,
+                features=hidden_size,
+                param_dtype=dtype,
+                embedding_init=nnx.with_partitioning(
+                    init_fn, (ShardingAxisName.VOCAB, None)),
+                rngs=rng,
+            )
+        else:
+            self.embed = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            hf_config.num_hidden_layers,
+            lambda: LlamaDecoderLayer(
                 config=hf_config,
                 dtype=dtype,
                 rng=rng,
                 mesh=mesh,
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-                kv_cache_dtype=vllm_config.cache_config.cache_dtype)
-            for _ in range(hf_config.num_hidden_layers)
-        ]
-        self.norm = nnx.RMSNorm(
-            hidden_size,
-            epsilon=rms_norm_eps,
-            param_dtype=dtype,
-            scale_init=nnx.with_partitioning(init_fn, (None, )),
-            rngs=rng,
-        )
-        if model_config.hf_config.tie_word_embeddings:
-            self.lm_head = self.embed.embedding
-        else:
-            self.lm_head = nnx.Param(
-                init_fn(rng.params(), (hidden_size, vocab_size), dtype),
-                sharding=(None, ShardingAxisName.VOCAB),
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype))
+        if self.is_last_rank:
+            self.norm = nnx.RMSNorm(
+                hidden_size,
+                epsilon=rms_norm_eps,
+                param_dtype=dtype,
+                scale_init=nnx.with_partitioning(init_fn, (None, )),
+                rngs=rng,
             )
+        else:
+            self.norm = PPMissingLayer()
+
+        if self.is_last_rank:
+            if model_config.hf_config.tie_word_embeddings:
+                self.lm_head = self.embed.embedding
+            else:
+                self.lm_head = nnx.Param(
+                    init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                    sharding=(None, ShardingAxisName.VOCAB),
+                )
+        else:
+            self.lm_head = PPMissingLayer()

         self.aux_hidden_state_layers = []
         if vllm_config.speculative_config and vllm_config.speculative_config.method == "eagle3":
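make_layers also comes from tpu_inference.layers.jax.pp_utils and is not shown in this diff. Assuming it behaves like vLLM's torch helper of the same name, it splits the decoder stack across pipeline ranks and fills the slots this rank does not own with PPMissingLayer, which is why start_layer and end_layer are kept for the islice in __call__ below. A simplified sketch under that assumption (even split, hypothetical names):

def make_layers_sketch(num_hidden_layers, layer_fn):
    # Hypothetical even split across pipeline-parallel ranks.
    pp_group = get_pp_group()
    per_rank = num_hidden_layers // pp_group.world_size
    start = pp_group.rank_in_group * per_rank
    end = start + per_rank
    # Only this rank's slice is materialized; the rest are placeholders,
    # so global layer indices stay valid on every rank.
    layers = [
        layer_fn() if start <= i < end else PPMissingLayer()
        for i in range(num_hidden_layers)
    ]
    return start, end, layers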
@@ -282,10 +301,18 @@ def __call__(
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
-    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
-        x = self.embed(input_ids)
+        intermediate_tensors: JaxIntermediateTensors | None,
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        if self.is_first_rank:
+            x = self.embed(input_ids)
+        else:
+            assert intermediate_tensors is not None
+            x = intermediate_tensors["hidden_states"]
+
         aux_hidden_states = []
-        for i, layer in enumerate(self.layers):
+        for i, layer in enumerate(
+                islice(self.layers, self.start_layer, self.end_layer)):
             if i in self.aux_hidden_state_layers:
                 aux_hidden_states.append(x)
             kv_cache = kv_caches[i]
@@ -295,6 +322,10 @@ def __call__(
                 attention_metadata,
             )
             kv_caches[i] = kv_cache
+        if not self.is_last_rank:
+            # Note: add aux_hidden_states to make the output spec consistent.
+            return kv_caches, JaxIntermediateTensors({"hidden_states":
+                                                      x}), aux_hidden_states
         x = self.norm(x)
         return kv_caches, x, aux_hidden_states

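The return type of LlamaModel.__call__ now depends on the rank: the last rank still yields (kv_caches, hidden_states, aux_hidden_states), while every earlier rank hands off a JaxIntermediateTensors for the next stage. A hedged usage sketch for a two-stage pipeline (the stage0_model/stage1_model handles and the hand-off between ranks are illustrative, not the actual runner code):

# Stage 0 (first rank): embeds input_ids and returns intermediate tensors.
kv_caches_0, intermediates, _ = stage0_model(kv_caches_0, input_ids,
                                             attn_metadata, None)

# The hidden states are shipped to the next rank, which skips the embedding,
# runs its own slice of decoder layers, and applies the final RMSNorm.
kv_caches_1, hidden_states, aux = stage1_model(kv_caches_1, input_ids,
                                               attn_metadata, intermediates)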
@@ -313,19 +344,32 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
             mesh=mesh,
         )

+        self.pp_missing_layers = []
+        for path, module in nnx.iter_graph(self.model):
+            if isinstance(module, PPMissingLayer):
+                # the path should be sth like ('layers', '0')
+                self.pp_missing_layers.append('.'.join([str(s) for s in path]))
+
     def __call__(
         self,
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
+        _input_embeds,
+        _layer_name_to_kv_cache,
+        _lora_metadata,
+        intermediate_tensors: JaxIntermediateTensors,
+        _is_first_rank: bool,
+        _is_last_rank: bool,
         *args,
-    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
-        kv_caches, x, aux_hidden_states = self.model(
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        return self.model(
             kv_caches,
             input_ids,
             attention_metadata,
+            intermediate_tensors,
         )
-        return kv_caches, x, aux_hidden_states

     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
         if self.vllm_config.model_config.hf_config.tie_word_embeddings:
@@ -372,4 +416,5 @@ def load_weights(self, rng_key: jax.Array):
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
-                        mesh=self.mesh)
+                        mesh=self.mesh,
+                        pp_missing_layers=self.pp_missing_layers)
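pp_missing_layers carries dot-joined module paths (e.g. "layers.0") so that load_hf_weights can skip checkpoint tensors belonging to other pipeline ranks. The real filtering lives in tpu_inference.models.jax.utils.weight_utils and is not shown here; a minimal sketch of the kind of predicate it presumably applies (the helper name is hypothetical):

def _owned_by_this_rank(param_path: str, pp_missing_layers: list) -> bool:
    # Hypothetical helper: drop any weight whose module path falls under a
    # placeholder, e.g. "layers.0.self_attn.q_proj" when "layers.0" is
    # missing on this pipeline stage.
    return not any(
        param_path == missing or param_path.startswith(missing + ".")
        for missing in pp_missing_layers)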