 from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
 from tpu_inference.layers.jax.transformer_block import TransformerBlock
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
+    MXFP4_QUANT_METHOD, dequant_mxfp4_to_bf16, unpack_mxfp4_to_fp32)
 from tpu_inference.models.jax.utils.weight_utils import (
     get_param, model_weights_generator, print_param_info)
 
@@ -185,13 +187,23 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
185187 """Loads and transforms all weights from a checkpoint"""
186188 self .rng = nnx .Rngs (rng )
187189
+        # Determine the quantization method from the HF config (config.json)
+        quant_method = (self.hf_config.quantization_config["quant_method"]
+                        if hasattr(self.hf_config, "quantization_config") else
+                        None)
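+        # For MXFP4 checkpoints this is typically "mxfp4"; None means an
+        # unquantized (bf16) checkpoint.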
+
         # Format: 'hf_key': ('jax_model_path', transform_function, target_shape)
         transforms = {
             "transpose_reshape": lambda w, shape: w.T.reshape(shape),
             "reshape": lambda b, shape: b.reshape(shape),
             "transpose": lambda w, _: w.T,
+            "swap_last2": lambda w, _: w.swapaxes(-1, -2),
         }
 
+        # MXFP4 checkpoints swap the last two dims of the MoE weights so the
+        # packed dimension lands on the most-minor axis
+        swap_mlp_transform = transforms[
+            "swap_last2"] if quant_method == MXFP4_QUANT_METHOD else None
+
         mappings = {
             # Embeddings, Norms, and LM Head
             "model.embed_tokens.weight": ("embedder.input_embedding_table_VD",
@@ -247,11 +259,13 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
247259 "model.layers.*.mlp.router.bias" :
248260 ("layers.*.custom_module.router.bias_E" , None , None ),
249261 "model.layers.*.mlp.experts.gate_up_proj" :
250- ("layers.*.custom_module.mlp1_weight_EDF2" , None , None ),
262+ ("layers.*.custom_module.mlp1_weight_EDF2" , swap_mlp_transform ,
263+ None ),
251264 "model.layers.*.mlp.experts.gate_up_proj_bias" :
252265 ("layers.*.custom_module.mlp1_bias_EF2" , None , None ),
253266 "model.layers.*.mlp.experts.down_proj" :
254- ("layers.*.custom_module.mlp2_weight_EFD" , None , None ),
267+ ("layers.*.custom_module.mlp2_weight_EFD" , swap_mlp_transform ,
268+ None ),
255269 "model.layers.*.mlp.experts.down_proj_bias" :
256270 ("layers.*.custom_module.mlp2_bias_ED" , None , None ),
257271 }
@@ -265,8 +279,16 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
             framework="pt",
             download_dir=self.vllm_config.load_config.download_dir)
 
+        # Build a pool of weights, with MXFP4 expert tensors combined if needed
+        pool: dict[str, torch.Tensor | tuple] = (self._build_mxfp4_pool(
+            names_and_weights_generator,
+            mappings) if quant_method == MXFP4_QUANT_METHOD else {
+                loaded_name: loaded_weight
+                for loaded_name, loaded_weight in names_and_weights_generator
+            })
+
         with jax.default_device(jax.devices("cpu")[0]):
-            for loaded_name, loaded_weight in names_and_weights_generator:
+            for loaded_name, loaded_weight in pool.items():
                 hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", loaded_name)
                 if hf_pattern not in mappings:
                     logger.warning(
@@ -284,48 +306,162 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
284306 "*" , layer_num_match .group (1 ))
285307
286308 model_weight = get_param (model_params , jax_path )
287- cast_type = model_weight .value .dtype
288309
289- if jax_path_template == "layers.*.attn.sinks_N" :
290- # Checkpoint is bf16, but we have to upcast sinks to f32, as required by RPA_v3 kernel
291- weight_np = jnp .array (
292- loaded_weight .to (torch .float32 ).numpy ())
293- else :
294- torch_view_type = DTYPE_VIEW_MAP .get (jnp .dtype (cast_type ))
295- if torch_view_type :
296- # Avoid unnecessary upcasting and mem copy by viewing the tensor's
297- # raw data as integers before converting to a JAX array.
298- weight_np = jnp .array (
299- loaded_weight .view (torch_view_type ).numpy ()).view (
300- cast_type )
301- else :
302- raise ValueError (
303- f"Unsupported dtype for tensor conversion: { cast_type } "
+                prepared_weight = loaded_weight
+                if isinstance(loaded_weight, tuple):
+                    # Loaded weight is an MXFP4 (blocks, scales) tuple
+                    blocks_u8, scales_u8 = loaded_weight
+                    # Quantized param (QArray): set qvalue/scale directly and skip the regular path
+                    if hasattr(model_weight, "array"):  # QArray check
+                        codes_fp32_t, scales_fp32_t = unpack_mxfp4_to_fp32(
+                            blocks_u8, scales_u8)
+                        self._load_mxfp4(
+                            model_weight=model_weight,
+                            codes_fp32_t=codes_fp32_t,
+                            scales_fp32_t=scales_fp32_t,
+                            transform_fn=transform_fn,
                         )
+                        if is_verbose:
+                            print_param_info(model_weight, loaded_name)
+                        continue
+                    # Not a QArray: dequantize MXFP4 to BF16 full weights
+                    prepared_weight = dequant_mxfp4_to_bf16(
+                        blocks_u8, scales_u8)
+
+                # Single regular-tensor load call (BF16 or dequantized MXFP4)
+                cast_type = model_weight.value.dtype
+                self._load_regular_param(
+                    model_weight=model_weight,
+                    loaded_weight=prepared_weight,
+                    cast_type=cast_type,
+                    transform_fn=transform_fn,
+                    target_shape=target_shape,
+                    jax_path_template=jax_path_template,
+                )
 
-                if transform_fn:
-                    transformed_weight = transform_fn(weight_np, target_shape)
-                else:
-                    transformed_weight = weight_np
+                if is_verbose:
+                    print_param_info(model_weight, loaded_name)
 
-                if model_weight.value.shape != transformed_weight.shape:
-                    raise ValueError(
-                        f"Shape mismatch for '{jax_path}': Model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transformation."
-                    )
+        nnx.update(self, model_params)
 
-                def get_slice(index):
-                    return transformed_weight[index]
+    def _build_mxfp4_pool(self, names_and_weights_generator, mappings):
+        """Collect weights into a pool, keeping MXFP4 experts as (blocks_u8, scales_u8) tuples.
+
+        Combines *_blocks and *_scales pairs and stores the uint8 tensors together.
+        Non-expert tensors are kept as-is. Raises if any expert bundle is incomplete.
+        """
+        pool: dict[str, torch.Tensor | tuple] = {}
+        pending_experts: dict[str, dict[str, torch.Tensor]] = {}
+        for loaded_name, loaded_weight in names_and_weights_generator:
+            if loaded_name.endswith("_blocks") or loaded_name.endswith(
+                    "_scales"):
+                base = loaded_name[:-7]
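+                # "_blocks" and "_scales" are both 7 chars, so this strips
+                # either suffix to recover the shared tensor name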
+                entry = pending_experts.setdefault(base, {})
+                if loaded_name.endswith("_blocks"):
+                    entry["blocks"] = loaded_weight
+                else:
+                    entry["scales"] = loaded_weight
 
-                sharded_array = jax.make_array_from_callback(
-                    transformed_weight.shape,
-                    NamedSharding(self.mesh, P(*model_weight.sharding)),
-                    get_slice)
-                model_weight.value = sharded_array
+                # If we have both parts, place the raw pair into the main pool
+                if "blocks" in entry and "scales" in entry:
+                    hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", base)
+                    if hf_pattern not in mappings:
+                        raise ValueError(
+                            f"No mapping found for expert tensor: {base}")
+                    pool[base] = (entry["blocks"], entry["scales"])
+                    # Remove from pending to free memory
+                    pending_experts.pop(base, None)
+            else:
+                pool[loaded_name] = loaded_weight
+
+        # Enforce completeness of expert bundles
+        if pending_experts:
+            details = []
+            for base, entry in pending_experts.items():
+                missing = [k for k in ("blocks", "scales") if k not in entry]
+                details.append(
+                    f"{base} (missing: {', '.join(missing) if missing else 'unknown'})"
+                )
+            raise RuntimeError(
+                "Incomplete MXFP4 expert bundle(s) encountered: " +
+                ", ".join(details))
+        return pool
+
+    def _load_mxfp4(self,
+                    model_weight,
+                    codes_fp32_t,
+                    scales_fp32_t,
+                    transform_fn=None):
+        """Assign decoded MXFP4 codes/scales into a QArray (qvalue/scale)."""
+
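+        # The quantized param wraps a QArray: integer codes live in
+        # array.qvalue and the per-block scales in array.scale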
+        qv = model_weight.array.qvalue
+        sv = model_weight.array.scale
+        q_dtype = qv.value.dtype
+        s_dtype = sv.value.dtype
+
+        exp_q_shape = tuple(qv.value.shape)
+        exp_s_shape = tuple(sv.value.shape)
+
+        # Apply the optional transform (e.g. swap last two dims) before conversion
+        if transform_fn is not None:
+            codes_fp32_t = transform_fn(codes_fp32_t, None)
+            scales_fp32_t = transform_fn(scales_fp32_t, None)
+
+        # Convert from torch.Tensor to numpy before creating JAX arrays
+        codes_fp32_t = codes_fp32_t.detach().cpu().numpy()
+        scales_fp32_t = scales_fp32_t.detach().cpu().numpy()
+
+        codes_jnp = jnp.asarray(codes_fp32_t).astype(q_dtype)
+        scales_jnp = jnp.asarray(scales_fp32_t).astype(s_dtype)
+
+        def get_q_slice(index):
+            return codes_jnp[index]
+
+        def get_s_slice(index):
+            return scales_jnp[index]
+
+        q_sharded = jax.make_array_from_callback(
+            exp_q_shape, NamedSharding(self.mesh, P(*qv.sharding)),
+            get_q_slice)
+        s_sharded = jax.make_array_from_callback(
+            exp_s_shape, NamedSharding(self.mesh, P(*sv.sharding)),
+            get_s_slice)
+
+        model_weight.array.qvalue.value = q_sharded
+        model_weight.array.scale.value = s_sharded
+
+    def _load_regular_param(self, model_weight, loaded_weight: torch.Tensor,
+                            cast_type, transform_fn, target_shape,
+                            jax_path_template: str):
+        """Assign a regular (non-MXFP4) tensor into the model param, applying any transform."""
+        if jax_path_template == "layers.*.attn.sinks_N":
+            # Checkpoint is bf16, but we have to upcast sinks to f32, as required by the RPA_v3 kernel
+            weight_np = jnp.array(loaded_weight.to(torch.float32).numpy())
+        else:
+            torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
+            if torch_view_type:
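+                # Avoid unnecessary upcasting and mem copy by viewing the
+                # tensor's raw data as integers before converting to a JAX array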
+                weight_np = jnp.array(
+                    loaded_weight.view(torch_view_type).numpy()).view(
+                        cast_type)
+            else:
+                raise ValueError(
+                    f"Unsupported dtype for tensor conversion: {cast_type}")
+
+        transformed_weight = transform_fn(
+            weight_np, target_shape) if transform_fn else weight_np
+
+        if model_weight.value.shape != transformed_weight.shape:
+            raise ValueError(
+                f"Shape mismatch: model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transform."
+            )
 
-                if is_verbose:
-                    print_param_info(model_weight, loaded_name)
+        def get_slice(index):
+            return transformed_weight[index]
 
-        nnx.update(self, model_params)
+        sharded_array = jax.make_array_from_callback(
+            transformed_weight.shape,
+            NamedSharding(self.mesh, P(*model_weight.sharding)), get_slice)
+        model_weight.value = sharded_array
 
     def __call__(
         self,