from jax.sharding import PartitionSpec as P
from vllm.config import VllmConfig
from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
-    dequant_mxfp4_to_bf16,
+    MXFP4_QUANT_METHOD, dequant_mxfp4_to_bf16,
+    unpack_mxfp4_to_fp32,
)

from tpu_inference.layers.jax.attention.gpt_oss_attention import (
@@ -188,6 +189,13 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
188189 """Loads and transforms all weights from a checkpoint"""
189190 self .rng = nnx .Rngs (rng )
190191
192+ # Determine quantization method from HF config (config.json)
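+        # (For an MXFP4 gpt-oss checkpoint this is typically something like
+        #  "quantization_config": {"quant_method": "mxfp4"} in config.json.)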
+        quant_method = (
+            self.hf_config.quantization_config["quant_method"]
+            if hasattr(self.hf_config, "quantization_config")
+            else None
+        )
+
        # Format: 'hf_key': ('jax_model_path', transform_function, target_shape)
        transforms = {
            "transpose_reshape": lambda w, shape: w.T.reshape(shape),
@@ -196,6 +204,9 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
196204 "swap_last2" : lambda w , _ : w .swapaxes (- 1 , - 2 ),
197205 }
198206
207+ # MXFP4 checkpoints swap last two dims for MoE to place packed dim at most minor
208+ swap_mlp_transform = transforms ["swap_last2" ] if quant_method == MXFP4_QUANT_METHOD else None
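+        # (swapaxes(-1, -2) turns e.g. an (E, A, B) tensor into (E, B, A))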
+
        mappings = {
            # Embeddings, Norms, and LM Head
            "model.embed_tokens.weight": ("embedder.input_embedding_table_VD",
@@ -251,11 +262,11 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
251262 "model.layers.*.mlp.router.bias" :
252263 ("layers.*.custom_module.router.bias_E" , None , None ),
253264 "model.layers.*.mlp.experts.gate_up_proj" :
254- ("layers.*.custom_module.mlp1_weight_EDF2" , transforms [ "swap_last2" ] , None ),
265+ ("layers.*.custom_module.mlp1_weight_EDF2" , swap_mlp_transform , None ),
255266 "model.layers.*.mlp.experts.gate_up_proj_bias" :
256267 ("layers.*.custom_module.mlp1_bias_EF2" , None , None ),
257268 "model.layers.*.mlp.experts.down_proj" :
258- ("layers.*.custom_module.mlp2_weight_EFD" , transforms [ "swap_last2" ] , None ),
269+ ("layers.*.custom_module.mlp2_weight_EFD" , swap_mlp_transform , None ),
259270 "model.layers.*.mlp.experts.down_proj_bias" :
260271 ("layers.*.custom_module.mlp2_bias_ED" , None , None ),
261272 }
@@ -269,9 +280,76 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
            framework="pt",
            download_dir=self.vllm_config.load_config.download_dir)

-        # Single pass: build a unified pool. Combine MXFP4 expert blocks/scales
-        # into a dequantized bf16 tensor as soon as both are seen.
-        pool: dict[str, torch.Tensor] = {}
+        # Build a pool of weights, with MXFP4 experts combined if needed
+        pool: dict[str, torch.Tensor | tuple] = (
+            self._build_mxfp4_pool(names_and_weights_generator, mappings)
+            if quant_method == MXFP4_QUANT_METHOD
+            else {loaded_name: loaded_weight
+                  for loaded_name, loaded_weight in names_and_weights_generator}
+        )
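+        # For MXFP4 checkpoints the pool maps each expert's base name (e.g.
+        # "model.layers.0.mlp.experts.gate_up_proj") to a raw (blocks_u8, scales_u8)
+        # pair; all other entries remain plain torch tensors.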
+
+        with jax.default_device(jax.devices("cpu")[0]):
+            for loaded_name, loaded_weight in pool.items():
+                hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", loaded_name)
+                if hf_pattern not in mappings:
+                    logger.warning(
+                        f"No mapping found for checkpoint tensor: {loaded_name}. Skipping."
+                    )
+                    continue
+
+                jax_path_template, transform_fn, target_shape = mappings[
+                    hf_pattern]
+
+                layer_num_match = re.search(r"layers\.(\d+)", loaded_name)
+                jax_path = jax_path_template
+                if layer_num_match:
+                    jax_path = jax_path_template.replace(
+                        "*", layer_num_match.group(1))
+
+                model_weight = get_param(model_params, jax_path)
+
+                prepared_weight = loaded_weight
+                if isinstance(loaded_weight, tuple):
+                    # Loaded weight is an MXFP4 tuple
+                    blocks_u8, scales_u8 = loaded_weight
+                    # Quantized param (QArray): set qvalue/scale directly and skip regular path
+                    if hasattr(model_weight, "array"):  # QArray check
+                        codes_fp32_t, scales_fp32_t = unpack_mxfp4_to_fp32(blocks_u8, scales_u8)
+                        self._load_mxfp4(
+                            model_weight=model_weight,
+                            codes_fp32_t=codes_fp32_t,
+                            scales_fp32_t=scales_fp32_t,
+                            transform_fn=transform_fn,
+                        )
+                        if is_verbose:
+                            print_param_info(model_weight, loaded_name)
+                        continue
+                    # Not a QArray: dequantize MXFP4 to BF16 full weights
+                    prepared_weight = dequant_mxfp4_to_bf16(blocks_u8, scales_u8)
+
+                # Single regular-tensor load call (BF16 or dequantized MXFP4)
+                cast_type = model_weight.value.dtype
+                self._load_regular_param(
+                    model_weight=model_weight,
+                    loaded_weight=prepared_weight,
+                    cast_type=cast_type,
+                    transform_fn=transform_fn,
+                    target_shape=target_shape,
+                    jax_path_template=jax_path_template,
+                )
+
+                if is_verbose:
+                    print_param_info(model_weight, loaded_name)
+
+        nnx.update(self, model_params)
+
+    def _build_mxfp4_pool(self, names_and_weights_generator, mappings):
+        """Collect MXFP4 weights into a pool, keeping (blocks_u8, scales_u8) tuples.
+
+        Combines *_blocks and *_scales pairs and stores the uint8 tensors together.
+        Non-expert tensors are kept as-is. Raises if any expert bundle is incomplete.
+        """
+        pool: dict[str, torch.Tensor | tuple] = {}
        pending_experts: dict[str, dict[str, torch.Tensor]] = {}
        for loaded_name, loaded_weight in names_and_weights_generator:
            if loaded_name.endswith("_blocks") or loaded_name.endswith("_scales"):
@@ -282,14 +360,12 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
                else:
                    entry["scales"] = loaded_weight

-                # If we have both parts, dequantize now and place into the main pool
+                # If we have both parts, place raw pair into the main pool
                if "blocks" in entry and "scales" in entry:
                    hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", base)
                    if hf_pattern not in mappings:
-                        logger.warning(f"No mapping found for expert tensor: {base}. Skipping.")
-                    else:
-                        deq = dequant_mxfp4_to_bf16(entry["blocks"], entry["scales"])  # torch.bfloat16
-                        pool[base] = deq
+                        raise ValueError(f"No mapping found for expert tensor: {base}")
+                    pool[base] = (entry["blocks"], entry["scales"])
                    # Remove from pending to free memory
                    pending_experts.pop(base, None)
            else:
@@ -304,68 +380,82 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
            raise RuntimeError(
                "Incomplete MXFP4 expert bundle(s) encountered: " + ", ".join(details)
            )
+        return pool
+
+    def _load_mxfp4(self,
+                    model_weight,
+                    codes_fp32_t,
+                    scales_fp32_t,
+                    transform_fn=None):
+        """Assign decoded MXFP4 codes/scales into a QArray (qvalue/scale)."""
+
+        qv = model_weight.array.qvalue
+        sv = model_weight.array.scale
+        q_dtype = qv.value.dtype
+        s_dtype = sv.value.dtype
+
+        exp_q_shape = tuple(qv.value.shape)
+        exp_s_shape = tuple(sv.value.shape)
+
+        # Apply optional transform (e.g., swap last two dims) before conversion
+        if transform_fn is not None:
+            codes_fp32_t = transform_fn(codes_fp32_t, None)
+            scales_fp32_t = transform_fn(scales_fp32_t, None)
+
+        # Convert from torch.Tensor to numpy before creating JAX arrays
+        codes_fp32_t = codes_fp32_t.detach().cpu().numpy()
+        scales_fp32_t = scales_fp32_t.detach().cpu().numpy()
+
+        codes_jnp = jnp.asarray(codes_fp32_t).astype(q_dtype)
+        scales_jnp = jnp.asarray(scales_fp32_t).astype(s_dtype)
+
+        def get_q_slice(index):
+            return codes_jnp[index]
+
+        def get_s_slice(index):
+            return scales_jnp[index]
+
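+        # make_array_from_callback builds each sharded array shard-by-shard: the
+        # callbacks receive an addressable shard's index and return that shard's
+        # slice of the host-side array.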
+        q_sharded = jax.make_array_from_callback(
+            exp_q_shape, NamedSharding(self.mesh, P(*qv.sharding)), get_q_slice)
+        s_sharded = jax.make_array_from_callback(
+            exp_s_shape, NamedSharding(self.mesh, P(*sv.sharding)), get_s_slice)
+
+        model_weight.array.qvalue.value = q_sharded
+        model_weight.array.scale.value = s_sharded
+
+    def _load_regular_param(self,
+                            model_weight,
+                            loaded_weight: torch.Tensor,
+                            cast_type,
+                            transform_fn,
+                            target_shape,
+                            jax_path_template: str):
+        """Assign a regular tensor (non-MXFP4) into the model param with transform applied."""
+        if jax_path_template == "layers.*.attn.sinks_N":
+            # Checkpoint is bf16, but we have to upcast sinks to f32, as required by the RPA_v3 kernel
+            weight_np = jnp.array(loaded_weight.to(torch.float32).numpy())
+        else:
+            torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
+            if torch_view_type:
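+                # Avoid unnecessary upcasting and a memory copy by viewing the tensor's
+                # raw data as integers before converting to a JAX array.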
+                weight_np = jnp.array(loaded_weight.view(torch_view_type).numpy()).view(cast_type)
+            else:
+                raise ValueError(
+                    f"Unsupported dtype for tensor conversion: {cast_type}")

-        with jax.default_device(jax.devices("cpu")[0]):
-            for loaded_name, loaded_weight in pool.items():
-                hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", loaded_name)
-                if hf_pattern not in mappings:
-                    logger.warning(
-                        f"No mapping found for checkpoint tensor: {loaded_name}. Skipping."
-                    )
-                    continue
-
-                jax_path_template, transform_fn, target_shape = mappings[
-                    hf_pattern]
-
-                layer_num_match = re.search(r"layers\.(\d+)", loaded_name)
-                jax_path = jax_path_template
-                if layer_num_match:
-                    jax_path = jax_path_template.replace(
-                        "*", layer_num_match.group(1))
-
-                model_weight = get_param(model_params, jax_path)
-                cast_type = model_weight.value.dtype
-
-                if jax_path_template == "layers.*.attn.sinks_N":
-                    # Checkpoint is bf16, but we have to upcast sinks to f32, as required by RPA_v3 kernel
-                    weight_np = jnp.array(
-                        loaded_weight.to(torch.float32).numpy())
-                else:
-                    torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
-                    if torch_view_type:
-                        # Avoid unnecessary upcasting and mem copy by viewing the tensor's
-                        # raw data as integers before converting to a JAX array.
-                        weight_np = jnp.array(
-                            loaded_weight.view(torch_view_type).numpy()).view(
-                                cast_type)
-                    else:
-                        raise ValueError(
-                            f"Unsupported dtype for tensor conversion: {cast_type}"
-                        )
-
-                if transform_fn:
-                    transformed_weight = transform_fn(weight_np, target_shape)
-                else:
-                    transformed_weight = weight_np
-
-                if model_weight.value.shape != transformed_weight.shape:
-                    raise ValueError(
-                        f"Shape mismatch for '{jax_path}': Model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transformation."
-                    )
-
-                def get_slice(index):
-                    return transformed_weight[index]
+        transformed_weight = transform_fn(weight_np, target_shape) if transform_fn else weight_np

-                sharded_array = jax.make_array_from_callback(
-                    transformed_weight.shape,
-                    NamedSharding(self.mesh, P(*model_weight.sharding)),
-                    get_slice)
-                model_weight.value = sharded_array
+        if model_weight.value.shape != transformed_weight.shape:
+            raise ValueError(
+                f"Shape mismatch: model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transform.")

-                if is_verbose:
-                    print_param_info(model_weight, loaded_name)
+        def get_slice(index):
+            return transformed_weight[index]

-        nnx.update(self, model_params)
+        sharded_array = jax.make_array_from_callback(
+            transformed_weight.shape,
+            NamedSharding(self.mesh, P(*model_weight.sharding)),
+            get_slice)
+        model_weight.value = sharded_array

    def __call__(
        self,