 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from vllm.config import VllmConfig
-from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
-    dequant_mxfp4_to_bf16,
-)

 from tpu_inference.layers.jax.attention.gpt_oss_attention import (
     AttentionMetadata, GptOssAttention)
 from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
 from tpu_inference.layers.jax.transformer_block import TransformerBlock
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
+    MXFP4_QUANT_METHOD, dequant_mxfp4_to_bf16, unpack_mxfp4_to_fp32)
 from tpu_inference.models.jax.utils.weight_utils import (
     get_param, model_weights_generator, print_param_info)

@@ -188,6 +187,11 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
         """Loads and transforms all weights from a checkpoint"""
         self.rng = nnx.Rngs(rng)

+        # Determine quantization method from HF config (config.json)
+        quant_method = (self.hf_config.quantization_config["quant_method"]
+                        if hasattr(self.hf_config, "quantization_config") else
+                        None)
+
         # Format: 'hf_key': ('jax_model_path', transform_function, target_shape)
         transforms = {
             "transpose_reshape": lambda w, shape: w.T.reshape(shape),
@@ -196,6 +200,10 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
             "swap_last2": lambda w, _: w.swapaxes(-1, -2),
         }

+        # MXFP4 checkpoints swap the last two dims of MoE weights so the packed dim is most minor
+        swap_mlp_transform = transforms[
+            "swap_last2"] if quant_method == MXFP4_QUANT_METHOD else None
+
         mappings = {
             # Embeddings, Norms, and LM Head
             "model.embed_tokens.weight": ("embedder.input_embedding_table_VD",
@@ -251,11 +259,13 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
             "model.layers.*.mlp.router.bias":
             ("layers.*.custom_module.router.bias_E", None, None),
             "model.layers.*.mlp.experts.gate_up_proj":
-            ("layers.*.custom_module.mlp1_weight_EDF2", transforms["swap_last2"], None),
+            ("layers.*.custom_module.mlp1_weight_EDF2", swap_mlp_transform,
+             None),
             "model.layers.*.mlp.experts.gate_up_proj_bias":
             ("layers.*.custom_module.mlp1_bias_EF2", None, None),
             "model.layers.*.mlp.experts.down_proj":
-            ("layers.*.custom_module.mlp2_weight_EFD", transforms["swap_last2"], None),
+            ("layers.*.custom_module.mlp2_weight_EFD", swap_mlp_transform,
+             None),
             "model.layers.*.mlp.experts.down_proj_bias":
             ("layers.*.custom_module.mlp2_bias_ED", None, None),
         }
@@ -269,41 +279,13 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
             framework="pt",
             download_dir=self.vllm_config.load_config.download_dir)

-        # Single pass: build a unified pool. Combine MXFP4 expert blocks/scales
-        # into a dequantized bf16 tensor as soon as both are seen.
-        pool: dict[str, torch.Tensor] = {}
-        pending_experts: dict[str, dict[str, torch.Tensor]] = {}
-        for loaded_name, loaded_weight in names_and_weights_generator:
-            if loaded_name.endswith("_blocks") or loaded_name.endswith("_scales"):
-                base = loaded_name[:-7]
-                entry = pending_experts.setdefault(base, {})
-                if loaded_name.endswith("_blocks"):
-                    entry["blocks"] = loaded_weight
-                else:
-                    entry["scales"] = loaded_weight
-
-                # If we have both parts, dequantize now and place into the main pool
-                if "blocks" in entry and "scales" in entry:
-                    hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", base)
-                    if hf_pattern not in mappings:
-                        logger.warning(f"No mapping found for expert tensor: {base}. Skipping.")
-                    else:
-                        deq = dequant_mxfp4_to_bf16(entry["blocks"], entry["scales"])  # torch.bfloat16
-                        pool[base] = deq
-                        # Remove from pending to free memory
-                        pending_experts.pop(base, None)
-            else:
-                pool[loaded_name] = loaded_weight
-
-        # Enforce completeness of expert bundles
-        if pending_experts:
-            details = []
-            for base, entry in pending_experts.items():
-                missing = [k for k in ("blocks", "scales") if k not in entry]
-                details.append(f"{base} (missing: {', '.join(missing) if missing else 'unknown'})")
-            raise RuntimeError(
-                "Incomplete MXFP4 expert bundle(s) encountered: " + ", ".join(details)
-            )
+        # Build a pool of weights, with MXFP4 experts combined if needed
+        pool: dict[str, torch.Tensor | tuple] = (self._build_mxfp4_pool(
+            names_and_weights_generator,
+            mappings) if quant_method == MXFP4_QUANT_METHOD else {
+                loaded_name: loaded_weight
+                for loaded_name, loaded_weight in names_and_weights_generator
+            })

         with jax.default_device(jax.devices("cpu")[0]):
             for loaded_name, loaded_weight in pool.items():
@@ -324,48 +306,162 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
                     "*", layer_num_match.group(1))

                 model_weight = get_param(model_params, jax_path)
-                cast_type = model_weight.value.dtype

-                if jax_path_template == "layers.*.attn.sinks_N":
-                    # Checkpoint is bf16, but we have to upcast sinks to f32, as required by RPA_v3 kernel
-                    weight_np = jnp.array(
-                        loaded_weight.to(torch.float32).numpy())
-                else:
-                    torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
-                    if torch_view_type:
-                        # Avoid unnecessary upcasting and mem copy by viewing the tensor's
-                        # raw data as integers before converting to a JAX array.
-                        weight_np = jnp.array(
-                            loaded_weight.view(torch_view_type).numpy()).view(
-                                cast_type)
-                    else:
-                        raise ValueError(
-                            f"Unsupported dtype for tensor conversion: {cast_type}"
+                prepared_weight = loaded_weight
+                if isinstance(loaded_weight, tuple):
+                    # Loaded weight is an MXFP4 tuple
+                    blocks_u8, scales_u8 = loaded_weight
+                    # Quantized param (QArray): set qvalue/scale directly and skip regular path
+                    if hasattr(model_weight, "array"):  # QArray check
+                        codes_fp32_t, scales_fp32_t = unpack_mxfp4_to_fp32(
+                            blocks_u8, scales_u8)
+                        self._load_mxfp4(
+                            model_weight=model_weight,
+                            codes_fp32_t=codes_fp32_t,
+                            scales_fp32_t=scales_fp32_t,
+                            transform_fn=transform_fn,
                         )
+                        if is_verbose:
+                            print_param_info(model_weight, loaded_name)
+                        continue
+                    # Not a QArray: dequantize MXFP4 to BF16 full weights
+                    prepared_weight = dequant_mxfp4_to_bf16(
+                        blocks_u8, scales_u8)
+
+                # Single regular-tensor load call (BF16 or dequantized MXFP4)
+                cast_type = model_weight.value.dtype
+                self._load_regular_param(
+                    model_weight=model_weight,
+                    loaded_weight=prepared_weight,
+                    cast_type=cast_type,
+                    transform_fn=transform_fn,
+                    target_shape=target_shape,
+                    jax_path_template=jax_path_template,
+                )
+
+                if is_verbose:
+                    print_param_info(model_weight, loaded_name)
+
+        nnx.update(self, model_params)
+
+    def _build_mxfp4_pool(self, names_and_weights_generator, mappings):
+        """Collect MXFP4 weights into a pool keeping tuples (blocks_u8, scales_u8).

-                if transform_fn:
-                    transformed_weight = transform_fn(weight_np, target_shape)
+        Combines *_blocks and *_scales pairs and stores the uint8 tensors together.
+        Non-expert tensors are kept as-is. Raises if any expert bundle is incomplete.
+        """
+        pool: dict[str, torch.Tensor | tuple] = {}
+        pending_experts: dict[str, dict[str, torch.Tensor]] = {}
+        for loaded_name, loaded_weight in names_and_weights_generator:
+            if loaded_name.endswith("_blocks") or loaded_name.endswith(
+                    "_scales"):
+                base = loaded_name[:-7]
+                entry = pending_experts.setdefault(base, {})
+                if loaded_name.endswith("_blocks"):
+                    entry["blocks"] = loaded_weight
                 else:
-                    transformed_weight = weight_np
+                    entry["scales"] = loaded_weight

-                if model_weight.value.shape != transformed_weight.shape:
-                    raise ValueError(
-                        f"Shape mismatch for '{jax_path}': Model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transformation."
-                    )
+                # If we have both parts, place the raw pair into the main pool
+                if "blocks" in entry and "scales" in entry:
+                    hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", base)
+                    if hf_pattern not in mappings:
+                        raise ValueError(
+                            f"No mapping found for expert tensor: {base}")
+                    pool[base] = (entry["blocks"], entry["scales"])
+                    # Remove from pending to free memory
+                    pending_experts.pop(base, None)
+            else:
+                pool[loaded_name] = loaded_weight
+
+        # Enforce completeness of expert bundles
+        if pending_experts:
+            details = []
+            for base, entry in pending_experts.items():
+                missing = [k for k in ("blocks", "scales") if k not in entry]
+                details.append(
+                    f"{base} (missing: {', '.join(missing) if missing else 'unknown'})"
+                )
+            raise RuntimeError(
+                "Incomplete MXFP4 expert bundle(s) encountered: " +
+                ", ".join(details))
+        return pool
+
+    def _load_mxfp4(self,
+                    model_weight,
+                    codes_fp32_t,
+                    scales_fp32_t,
+                    transform_fn=None):
+        """Assign decoded MXFP4 codes/scales into a QArray (qvalue/scale)."""
+
+        qv = model_weight.array.qvalue
+        sv = model_weight.array.scale
+        q_dtype = qv.value.dtype
+        s_dtype = sv.value.dtype
+
+        exp_q_shape = tuple(qv.value.shape)
+        exp_s_shape = tuple(sv.value.shape)
+
+        # Apply optional transform (e.g., swap last two dims) before conversion
+        if transform_fn is not None:
+            codes_fp32_t = transform_fn(codes_fp32_t, None)
+            scales_fp32_t = transform_fn(scales_fp32_t, None)
+
+        # Convert from torch.Tensor to numpy before creating JAX arrays
+        codes_fp32_t = codes_fp32_t.detach().cpu().numpy()
+        scales_fp32_t = scales_fp32_t.detach().cpu().numpy()
+
+        codes_jnp = jnp.asarray(codes_fp32_t).astype(q_dtype)
+        scales_jnp = jnp.asarray(scales_fp32_t).astype(s_dtype)
+
+        def get_q_slice(index):
+            return codes_jnp[index]
+
+        def get_s_slice(index):
+            return scales_jnp[index]
+
+        q_sharded = jax.make_array_from_callback(
+            exp_q_shape, NamedSharding(self.mesh, P(*qv.sharding)),
+            get_q_slice)
+        s_sharded = jax.make_array_from_callback(
+            exp_s_shape, NamedSharding(self.mesh, P(*sv.sharding)),
+            get_s_slice)
+
+        model_weight.array.qvalue.value = q_sharded
+        model_weight.array.scale.value = s_sharded
+
+    def _load_regular_param(self, model_weight, loaded_weight: torch.Tensor,
+                            cast_type, transform_fn, target_shape,
+                            jax_path_template: str):
+        """Assign a regular (non-MXFP4) tensor into the model param with transform applied."""
+        if jax_path_template == "layers.*.attn.sinks_N":
+            # Checkpoint is bf16, but we have to upcast sinks to f32, as required by RPA_v3 kernel
+            weight_np = jnp.array(loaded_weight.to(torch.float32).numpy())
+        else:
+            torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
+            if torch_view_type:
+                # Avoid unnecessary upcasting and mem copy by viewing the tensor's
+                # raw data as integers before converting to a JAX array.
+                weight_np = jnp.array(
+                    loaded_weight.view(torch_view_type).numpy()).view(
+                        cast_type)
+            else:
+                raise ValueError(
+                    f"Unsupported dtype for tensor conversion: {cast_type}")

-                def get_slice(index):
-                    return transformed_weight[index]
+        transformed_weight = transform_fn(
+            weight_np, target_shape) if transform_fn else weight_np

-                sharded_array = jax.make_array_from_callback(
-                    transformed_weight.shape,
-                    NamedSharding(self.mesh, P(*model_weight.sharding)),
-                    get_slice)
-                model_weight.value = sharded_array
+        if model_weight.value.shape != transformed_weight.shape:
+            raise ValueError(
+                f"Shape mismatch: model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transform."
+            )

-                if is_verbose:
-                    print_param_info(model_weight, loaded_name)
+        def get_slice(index):
+            return transformed_weight[index]

-        nnx.update(self, model_params)
+        sharded_array = jax.make_array_from_callback(
+            transformed_weight.shape,
+            NamedSharding(self.mesh, P(*model_weight.sharding)), get_slice)
+        model_weight.value = sharded_array

     def __call__(
         self,
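For reference, the MXFP4 tensors these loaders consume pack two FP4 (E2M1) codes into each uint8 byte, with 32 values per block sharing one E8M0 (power-of-two, bias 127) uint8 scale. Below is a minimal NumPy sketch of the dequantization step only, assuming low-nibble-first packing; the helper name dequant_mxfp4 and the exact shapes are illustrative, not the repository's dequant_mxfp4_to_bf16 implementation.

import numpy as np

# FP4 E2M1 code -> value (the high bit of each 4-bit code is the sign)
FP4_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                       -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
                      dtype=np.float32)

def dequant_mxfp4(blocks_u8: np.ndarray, scales_u8: np.ndarray) -> np.ndarray:
    """blocks_u8: (..., n_blocks, 16) uint8, two FP4 codes per byte (low nibble first, assumed).
    scales_u8: (..., n_blocks) uint8 E8M0 biased exponents."""
    lo = blocks_u8 & 0x0F
    hi = blocks_u8 >> 4
    # Interleave the two codes of each byte -> 32 codes per block
    codes = np.stack([lo, hi], axis=-1).reshape(*blocks_u8.shape[:-1], 32)
    values = FP4_VALUES[codes]                              # (..., n_blocks, 32)
    scales = np.exp2(scales_u8.astype(np.float32) - 127.0)  # E8M0 -> float scale
    return values * scales[..., None]                       # dequantized float32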