
Commit 81954b4

[GPT OSS] Add support for both BF16 and MXFP4
Signed-off-by: Jordan Dotzel <amishacorns@users.noreply.github.com>
1 parent 6e96676 commit 81954b4

File tree

3 files changed: +214 -77 lines changed


tpu_inference/models/jax/gpt_oss.py

Lines changed: 159 additions & 69 deletions
@@ -11,7 +11,8 @@
 from jax.sharding import PartitionSpec as P
 from vllm.config import VllmConfig
 from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
-    dequant_mxfp4_to_bf16,
+    MXFP4_QUANT_METHOD, dequant_mxfp4_to_bf16,
+    unpack_mxfp4_to_fp32,
 )

 from tpu_inference.layers.jax.attention.gpt_oss_attention import (
@@ -188,6 +189,13 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
         """Loads and transforms all weights from a checkpoint"""
         self.rng = nnx.Rngs(rng)

+        # Determine quantization method from HF config (config.json)
+        quant_method = (
+            self.hf_config.quantization_config["quant_method"]
+            if hasattr(self.hf_config, "quantization_config")
+            else None
+        )
+
         # Format: 'hf_key': ('jax_model_path', transform_function, target_shape)
         transforms = {
             "transpose_reshape": lambda w, shape: w.T.reshape(shape),
@@ -196,6 +204,9 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
             "swap_last2": lambda w, _: w.swapaxes(-1, -2),
         }

+        # MXFP4 checkpoints swap the last two MoE dims so the packed dim is the most minor
+        swap_mlp_transform = transforms["swap_last2"] if quant_method == MXFP4_QUANT_METHOD else None
+
         mappings = {
             # Embeddings, Norms, and LM Head
             "model.embed_tokens.weight": ("embedder.input_embedding_table_VD",
@@ -251,11 +262,11 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
             "model.layers.*.mlp.router.bias":
             ("layers.*.custom_module.router.bias_E", None, None),
             "model.layers.*.mlp.experts.gate_up_proj":
-            ("layers.*.custom_module.mlp1_weight_EDF2", transforms["swap_last2"], None),
+            ("layers.*.custom_module.mlp1_weight_EDF2", swap_mlp_transform, None),
             "model.layers.*.mlp.experts.gate_up_proj_bias":
             ("layers.*.custom_module.mlp1_bias_EF2", None, None),
             "model.layers.*.mlp.experts.down_proj":
-            ("layers.*.custom_module.mlp2_weight_EFD", transforms["swap_last2"], None),
+            ("layers.*.custom_module.mlp2_weight_EFD", swap_mlp_transform, None),
             "model.layers.*.mlp.experts.down_proj_bias":
             ("layers.*.custom_module.mlp2_bias_ED", None, None),
         }
@@ -269,9 +280,76 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
             framework="pt",
             download_dir=self.vllm_config.load_config.download_dir)

-        # Single pass: build a unified pool. Combine MXFP4 expert blocks/scales
-        # into a dequantized bf16 tensor as soon as both are seen.
-        pool: dict[str, torch.Tensor] = {}
+        # Build a pool of weights, with MXFP4 experts combined if needed
+        pool: dict[str, torch.Tensor | tuple] = (
+            self._build_mxfp4_pool(names_and_weights_generator, mappings)
+            if quant_method == MXFP4_QUANT_METHOD
+            else {loaded_name: loaded_weight
+                  for loaded_name, loaded_weight in names_and_weights_generator}
+        )
+
+        with jax.default_device(jax.devices("cpu")[0]):
+            for loaded_name, loaded_weight in pool.items():
+                hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", loaded_name)
+                if hf_pattern not in mappings:
+                    logger.warning(
+                        f"No mapping found for checkpoint tensor: {loaded_name}. Skipping."
+                    )
+                    continue
+
+                jax_path_template, transform_fn, target_shape = mappings[
+                    hf_pattern]
+
+                layer_num_match = re.search(r"layers\.(\d+)", loaded_name)
+                jax_path = jax_path_template
+                if layer_num_match:
+                    jax_path = jax_path_template.replace(
+                        "*", layer_num_match.group(1))
+
+                model_weight = get_param(model_params, jax_path)
+
+                prepared_weight = loaded_weight
+                if isinstance(loaded_weight, tuple):
+                    # Loaded weight is an MXFP4 tuple
+                    blocks_u8, scales_u8 = loaded_weight
+                    # Quantized param (QArray): set qvalue/scale directly and skip regular path
+                    if hasattr(model_weight, "array"):  # QArray check
+                        codes_fp32_t, scales_fp32_t = unpack_mxfp4_to_fp32(blocks_u8, scales_u8)
+                        self._load_mxfp4(
+                            model_weight=model_weight,
+                            codes_fp32_t=codes_fp32_t,
+                            scales_fp32_t=scales_fp32_t,
+                            transform_fn=transform_fn,
+                        )
+                        if is_verbose:
+                            print_param_info(model_weight, loaded_name)
+                        continue
+                    # Not a QArray: dequantize MXFP4 to BF16 full weights
+                    prepared_weight = dequant_mxfp4_to_bf16(blocks_u8, scales_u8)
+
+                # Single regular-tensor load call (BF16 or dequantized MXFP4)
+                cast_type = model_weight.value.dtype
+                self._load_regular_param(
+                    model_weight=model_weight,
+                    loaded_weight=prepared_weight,
+                    cast_type=cast_type,
+                    transform_fn=transform_fn,
+                    target_shape=target_shape,
+                    jax_path_template=jax_path_template,
+                )
+
+                if is_verbose:
+                    print_param_info(model_weight, loaded_name)
+
+        nnx.update(self, model_params)
+
+    def _build_mxfp4_pool(self, names_and_weights_generator, mappings):
+        """Collect MXFP4 weights into a pool, keeping (blocks_u8, scales_u8) tuples.
+
+        Combines *_blocks and *_scales pairs and stores the uint8 tensors together.
+        Non-expert tensors are kept as-is. Raises if any expert bundle is incomplete.
+        """
+        pool: dict[str, torch.Tensor | tuple] = {}
         pending_experts: dict[str, dict[str, torch.Tensor]] = {}
         for loaded_name, loaded_weight in names_and_weights_generator:
             if loaded_name.endswith("_blocks") or loaded_name.endswith("_scales"):
@@ -282,14 +360,12 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
                 else:
                     entry["scales"] = loaded_weight

-                # If we have both parts, dequantize now and place into the main pool
+                # If we have both parts, place the raw pair into the main pool
                 if "blocks" in entry and "scales" in entry:
                     hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", base)
                     if hf_pattern not in mappings:
-                        logger.warning(f"No mapping found for expert tensor: {base}. Skipping.")
-                    else:
-                        deq = dequant_mxfp4_to_bf16(entry["blocks"], entry["scales"])  # torch.bfloat16
-                        pool[base] = deq
+                        raise ValueError(f"No mapping found for expert tensor: {base}")
+                    pool[base] = (entry["blocks"], entry["scales"])
                     # Remove from pending to free memory
                     pending_experts.pop(base, None)
             else:
@@ -304,68 +380,82 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
             raise RuntimeError(
                 "Incomplete MXFP4 expert bundle(s) encountered: " + ", ".join(details)
             )
+        return pool
+
+    def _load_mxfp4(self,
+                    model_weight,
+                    codes_fp32_t,
+                    scales_fp32_t,
+                    transform_fn=None):
+        """Assign decoded MXFP4 codes/scales into a QArray (qvalue/scale)."""
+
+        qv = model_weight.array.qvalue
+        sv = model_weight.array.scale
+        q_dtype = qv.value.dtype
+        s_dtype = sv.value.dtype
+
+        exp_q_shape = tuple(qv.value.shape)
+        exp_s_shape = tuple(sv.value.shape)
+
+        # Apply optional transform (e.g., swap last two dims) before conversion
+        if transform_fn is not None:
+            codes_fp32_t = transform_fn(codes_fp32_t, None)
+            scales_fp32_t = transform_fn(scales_fp32_t, None)
+
+        # Convert from torch.Tensor to numpy before creating JAX arrays
+        codes_fp32_t = codes_fp32_t.detach().cpu().numpy()
+        scales_fp32_t = scales_fp32_t.detach().cpu().numpy()
+
+        codes_jnp = jnp.asarray(codes_fp32_t).astype(q_dtype)
+        scales_jnp = jnp.asarray(scales_fp32_t).astype(s_dtype)
+
+        def get_q_slice(index):
+            return codes_jnp[index]
+
+        def get_s_slice(index):
+            return scales_jnp[index]
+
+        q_sharded = jax.make_array_from_callback(
+            exp_q_shape, NamedSharding(self.mesh, P(*qv.sharding)), get_q_slice)
+        s_sharded = jax.make_array_from_callback(
+            exp_s_shape, NamedSharding(self.mesh, P(*sv.sharding)), get_s_slice)
+
+        model_weight.array.qvalue.value = q_sharded
+        model_weight.array.scale.value = s_sharded
+
+    def _load_regular_param(self,
+                            model_weight,
+                            loaded_weight: torch.Tensor,
+                            cast_type,
+                            transform_fn,
+                            target_shape,
+                            jax_path_template: str):
+        """Assign a regular tensor (non-MXFP4) into the model param with transform applied."""
+        if jax_path_template == "layers.*.attn.sinks_N":
+            # Checkpoint is bf16, but we have to upcast sinks to f32, as required by RPA_v3 kernel
+            weight_np = jnp.array(loaded_weight.to(torch.float32).numpy())
+        else:
+            torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
+            if torch_view_type:
+                weight_np = jnp.array(loaded_weight.view(torch_view_type).numpy()).view(cast_type)
+            else:
+                raise ValueError(
+                    f"Unsupported dtype for tensor conversion: {cast_type}")

-        with jax.default_device(jax.devices("cpu")[0]):
-            for loaded_name, loaded_weight in pool.items():
-                hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", loaded_name)
-                if hf_pattern not in mappings:
-                    logger.warning(
-                        f"No mapping found for checkpoint tensor: {loaded_name}. Skipping."
-                    )
-                    continue
-
-                jax_path_template, transform_fn, target_shape = mappings[
-                    hf_pattern]
-
-                layer_num_match = re.search(r"layers\.(\d+)", loaded_name)
-                jax_path = jax_path_template
-                if layer_num_match:
-                    jax_path = jax_path_template.replace(
-                        "*", layer_num_match.group(1))
-
-                model_weight = get_param(model_params, jax_path)
-                cast_type = model_weight.value.dtype
-
-                if jax_path_template == "layers.*.attn.sinks_N":
-                    # Checkpoint is bf16, but we have to upcast sinks to f32, as required by RPA_v3 kernel
-                    weight_np = jnp.array(
-                        loaded_weight.to(torch.float32).numpy())
-                else:
-                    torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
-                    if torch_view_type:
-                        # Avoid unnecessary upcasting and mem copy by viewing the tensor's
-                        # raw data as integers before converting to a JAX array.
-                        weight_np = jnp.array(
-                            loaded_weight.view(torch_view_type).numpy()).view(
-                                cast_type)
-                    else:
-                        raise ValueError(
-                            f"Unsupported dtype for tensor conversion: {cast_type}"
-                        )
-
-                if transform_fn:
-                    transformed_weight = transform_fn(weight_np, target_shape)
-                else:
-                    transformed_weight = weight_np
-
-                if model_weight.value.shape != transformed_weight.shape:
-                    raise ValueError(
-                        f"Shape mismatch for '{jax_path}': Model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transformation."
-                    )
-
-                def get_slice(index):
-                    return transformed_weight[index]
+        transformed_weight = transform_fn(weight_np, target_shape) if transform_fn else weight_np

-                sharded_array = jax.make_array_from_callback(
-                    transformed_weight.shape,
-                    NamedSharding(self.mesh, P(*model_weight.sharding)),
-                    get_slice)
-                model_weight.value = sharded_array
+        if model_weight.value.shape != transformed_weight.shape:
+            raise ValueError(
+                f"Shape mismatch: model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transform.")

-                if is_verbose:
-                    print_param_info(model_weight, loaded_name)
+        def get_slice(index):
+            return transformed_weight[index]

-        nnx.update(self, model_params)
+        sharded_array = jax.make_array_from_callback(
+            transformed_weight.shape,
+            NamedSharding(self.mesh, P(*model_weight.sharding)),
+            get_slice)
+        model_weight.value = sharded_array

     def __call__(
         self,
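
For orientation, here is a minimal standalone sketch of the blocks/scales pairing that the new _build_mxfp4_pool performs; the helper name pair_mxfp4_entries, the example tensor names, and the shapes below are illustrative only, not part of tpu_inference:

import torch

def pair_mxfp4_entries(named_tensors):
    """Group '<base>_blocks' / '<base>_scales' uint8 pairs; pass other tensors through."""
    pool, pending = {}, {}
    for name, tensor in named_tensors:
        if name.endswith("_blocks") or name.endswith("_scales"):
            base, kind = name.rsplit("_", 1)       # '..._blocks' -> ('...', 'blocks')
            entry = pending.setdefault(base, {})
            entry[kind] = tensor
            if "blocks" in entry and "scales" in entry:
                pool[base] = (entry["blocks"], entry["scales"])  # keep the raw uint8 pair
                pending.pop(base, None)
        else:
            pool[name] = tensor                    # regular (e.g. BF16) tensor, kept as-is
    if pending:
        raise RuntimeError(f"Incomplete MXFP4 bundles: {sorted(pending)}")
    return pool

# Usage: one expert weight split into blocks/scales, plus a regular bias tensor.
tensors = [
    ("model.layers.0.mlp.experts.down_proj_blocks", torch.zeros(8, 16, dtype=torch.uint8)),
    ("model.layers.0.mlp.experts.down_proj_scales", torch.zeros(8, dtype=torch.uint8)),
    ("model.layers.0.mlp.experts.down_proj_bias", torch.zeros(4)),
]
pool = pair_mxfp4_entries(tensors)
assert isinstance(pool["model.layers.0.mlp.experts.down_proj"], tuple)
assert pool["model.layers.0.mlp.experts.down_proj_bias"].dtype == torch.float32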

tpu_inference/models/jax/utils/quantization/mxfp4_utils.py

Lines changed: 31 additions & 8 deletions
@@ -7,6 +7,8 @@
 MXFP4_BLOCK_SIZE: int = 32
 # Exponent-only e8m0 scale bias used by MXFP4 scales
 MXFP4_SCALE_BIAS: int = 127
+# Name used in config.json quantization_config["quant_method"]
+MXFP4_QUANT_METHOD: str = "mxfp4"


 # Precompute a small LUT once; move to device on demand (cheap 16-element copy)
@@ -16,7 +18,7 @@
 ], dtype=torch.float32)


-def _unpack_uint8_to_fp4_values(packed: torch.Tensor) -> torch.Tensor:
+def unpack_mxfp4(packed: torch.Tensor) -> torch.Tensor:
     """Unpack uint8 (..., 16) -> fp4 values (..., 32) using low->high nibble order.

     Returns float32 values corresponding to FP4 codebook entries.
@@ -29,7 +31,7 @@ def _unpack_uint8_to_fp4_values(packed: torch.Tensor) -> torch.Tensor:
     return lut[idx.long()]


-def _e8m0_to_scale(u8: torch.Tensor) -> torch.Tensor:
+def e8m0_to_fp32(u8: torch.Tensor) -> torch.Tensor:
     """Convert e8m0 uint8 exponents to power-of-two scales using MXFP4_SCALE_BIAS.

     Uses ldexp for exact power-of-two scaling: 1.0 * 2**(u8 - bias).
@@ -43,17 +45,38 @@ def dequant_mxfp4_to_bf16(blocks_u8: torch.Tensor, scales_u8: torch.Tensor) -> t
     """Dequantize MXFP4 blocks/scales into bfloat16 values.

     Args:
-        blocks_u8: uint8 tensor shaped [..., Kb, 16], each byte holds 2 FP4 codes.
-        scales_u8: uint8 tensor shaped [..., Kb], exponent-only e8m0 per 32-value block.
+        blocks_u8: uint8 tensor shaped [..., Kb, 16], each byte holds 2 FP4 codes.
+        scales_u8: uint8 tensor shaped [..., Kb], exponent-only e8m0 per 32-value block.

     Returns:
-        torch.bfloat16 tensor with last logical dimension K = Kb * 32.
+        torch.bfloat16 tensor with last logical dimension K = Kb * 32.
     """
     if blocks_u8.dtype != torch.uint8 or scales_u8.dtype != torch.uint8:
-        raise ValueError(f"Expected uint8 inputs, got blocks={blocks_u8.dtype}, scales={scales_u8.dtype}")
+        raise ValueError(f"Expected uint8 inputs, got blocks={blocks_u8.dtype}, scales={scales_u8.dtype}")
     # Unpack FP4 codes to float32 values [..., Kb, 32]
-    fp4_vals = _unpack_uint8_to_fp4_values(blocks_u8)  # (..., Kb, 32)
+    fp4_vals = unpack_mxfp4(blocks_u8)  # (..., Kb, 32)
     # Compute power-of-two scales and apply per block
-    scales = _e8m0_to_scale(scales_u8).unsqueeze(-1)  # (..., Kb, 1)
+    scales = e8m0_to_fp32(scales_u8).unsqueeze(-1)  # (..., Kb, 1)
     full = (fp4_vals * scales).reshape(*fp4_vals.shape[:-2], fp4_vals.shape[-2] * MXFP4_BLOCK_SIZE)
     return full.to(torch.bfloat16)
+
+
+def unpack_mxfp4_to_fp32(blocks_u8: torch.Tensor, scales_u8: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """Decode MXFP4 packed blocks and e8m0 scales to float32 codes and scales.
+
+    Args:
+        blocks_u8: uint8 tensor shaped [..., Kb, 16], each byte packs two FP4 codes.
+        scales_u8: uint8 tensor shaped [..., Kb], exponent-only e8m0 per block.
+
+    Returns:
+        (codes_fp32, scales_fp32), where
+        - codes_fp32 has shape [..., Kb*32] and dtype float32
+        - scales_fp32 has shape [..., Kb] and dtype float32
+    """
+    if blocks_u8.dtype != torch.uint8 or scales_u8.dtype != torch.uint8:
+        raise ValueError(
+            f"Expected uint8 inputs, got blocks={blocks_u8.dtype}, scales={scales_u8.dtype}")
+    fp4_vals = unpack_mxfp4(blocks_u8)  # (..., Kb, 32) float32
+    codes_fp32 = fp4_vals.reshape(*fp4_vals.shape[:-2], fp4_vals.shape[-2] * MXFP4_BLOCK_SIZE)
+    scales_fp32 = e8m0_to_fp32(scales_u8)  # (..., Kb) float32
+    return codes_fp32, scales_fp32
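
As a worked illustration of the decode math these helpers implement (each uint8 packs two FP4 e2m1 codes, low nibble first, and each 32-value block shares one e8m0 scale, i.e. 2**(u8 - 127)), here is a hedged standalone sketch; the LUT is the standard OCP MX FP4 codebook written out for illustration, and the low/high interleaving is an assumption rather than a copy of the repo's unpack layout:

import torch

# Standard OCP MX FP4 (e2m1) codebook, written out here for illustration.
FP4_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], dtype=torch.float32)

def dequant_one_block(block_u8: torch.Tensor, scale_u8: torch.Tensor) -> torch.Tensor:
    """Decode one 16-byte MXFP4 block plus its e8m0 scale byte into 32 bf16 values."""
    lo = block_u8 & 0x0F                              # low nibble decoded first
    hi = block_u8 >> 4
    idx = torch.stack([lo, hi], dim=-1).reshape(-1)   # (32,) codes, low/high interleaved
    vals = FP4_LUT[idx.long()]
    scale = torch.ldexp(torch.ones(()), scale_u8.to(torch.int32) - 127)  # 2**(u8 - bias)
    return (vals * scale).to(torch.bfloat16)

# Byte 0x32 packs codes 0x2 (-> 1.0, low nibble) and 0x3 (-> 1.5, high nibble);
# with scale byte 128 (2**1) the first two decoded values are 2.0 and 3.0.
block = torch.full((16,), 0x32, dtype=torch.uint8)
out = dequant_one_block(block, torch.tensor(128, dtype=torch.uint8))
print(out[:2])  # tensor([2., 3.], dtype=torch.bfloat16)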

tpu_inference/models/jax/utils/quantization/quantization_utils.py

Lines changed: 24 additions & 0 deletions
@@ -71,6 +71,27 @@
     }
 }

+# Default Qwix config for GPT-OSS MXFP4 checkpoints.
+# Notes:
+# - We quantize only the MoE expert weights by default (router stays in BF16).
+# - We use Qwix's abstract-model path so weights can be set directly into QArray
+#   fields during weight loading (similar to DeepSeek's flow).
+# - Activation quantization is not set, but Qwix would pick up the MoE sum if it were activated
+DEFAULT_GPT_OSS_FP4_CONFIG = {
+    "qwix": {
+        "use_abstract_model": True,
+        "scale_dtype": "bfloat16",
+        "rules": [
+            {
+                "module_path": ".*custom_module",
+                "weight_qtype": "float4_e2m1fn",
+                "act_qtype": None,
+                "tile_size": 32,
+            },
+        ],
+    }
+}
+

 def parse_qwix_config_to_rules(
     qwix_config: List[dict]) -> List[qwix.QuantizationRule]:
@@ -400,6 +421,9 @@ def get_default_qwix_quantization_config(
         return DEFAULT_DEEPSEEK_FP8_CONFIG
     elif model_type == "llama4" and quant_method == "compressed-tensors":
         return DEFAULT_LLAMA4_FP8_CONFIG
+    # MXFP4 (GPT-OSS): provide a default configuration to quantize MoE experts via Qwix
+    elif model_type == "gpt_oss" and quant_method == "mxfp4":
+        return DEFAULT_GPT_OSS_FP4_CONFIG


 def update_vllm_config_for_qwix_quantization(vllm_config: "VllmConfig"):
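
To show how the new branch is exercised, here is a small standalone sketch of the model_type/quant_method dispatch; the helper name pick_default_qwix_config is illustrative, while the dict mirrors DEFAULT_GPT_OSS_FP4_CONFIG from the diff above:

from typing import Optional

DEFAULT_GPT_OSS_FP4_CONFIG = {
    "qwix": {
        "use_abstract_model": True,
        "scale_dtype": "bfloat16",
        "rules": [
            {
                "module_path": ".*custom_module",   # MoE experts module only; router stays BF16
                "weight_qtype": "float4_e2m1fn",    # FP4 e2m1 weights
                "act_qtype": None,                  # no activation quantization
                "tile_size": 32,                    # one scale per 32-value MXFP4 block
            },
        ],
    }
}

def pick_default_qwix_config(model_type: str, quant_method: Optional[str]) -> Optional[dict]:
    """Return the default Qwix config for a model/quantization pair, if one applies."""
    if model_type == "gpt_oss" and quant_method == "mxfp4":
        return DEFAULT_GPT_OSS_FP4_CONFIG
    return None  # other model/quant combinations handled elsewhere

assert pick_default_qwix_config("gpt_oss", "mxfp4") is DEFAULT_GPT_OSS_FP4_CONFIG
assert pick_default_qwix_config("gpt_oss", None) is None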
