Commit e9f570c

[GPT OSS] Add support for both BF16 and MXFP4
Signed-off-by: Jordan Dotzel <amishacorns@users.noreply.github.com>
1 parent 8e19d70 commit e9f570c

3 files changed: +304 −37 lines changed

tpu_inference/models/jax/gpt_oss.py

Lines changed: 173 additions & 37 deletions
@@ -18,6 +18,8 @@
 from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
 from tpu_inference.layers.jax.transformer_block import TransformerBlock
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
+    MXFP4_QUANT_METHOD, dequant_mxfp4_to_bf16, unpack_mxfp4_to_fp32)
 from tpu_inference.models.jax.utils.weight_utils import (
     get_param, model_weights_generator, print_param_info)

@@ -185,13 +187,23 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
         """Loads and transforms all weights from a checkpoint"""
         self.rng = nnx.Rngs(rng)

+        # Determine quantization method from HF config (config.json)
+        quant_method = (self.hf_config.quantization_config["quant_method"]
+                        if hasattr(self.hf_config, "quantization_config") else
+                        None)
+
         # Format: 'hf_key': ('jax_model_path', transform_function, target_shape)
         transforms = {
             "transpose_reshape": lambda w, shape: w.T.reshape(shape),
             "reshape": lambda b, shape: b.reshape(shape),
             "transpose": lambda w, _: w.T,
+            "swap_last2": lambda w, _: w.swapaxes(-1, -2),
         }

+        # MXFP4 checkpoints swap the last two dims of MoE weights to place the packed dim at the most-minor position
+        swap_mlp_transform = transforms[
+            "swap_last2"] if quant_method == MXFP4_QUANT_METHOD else None
+
         mappings = {
             # Embeddings, Norms, and LM Head
             "model.embed_tokens.weight": ("embedder.input_embedding_table_VD",
@@ -247,11 +259,13 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
             "model.layers.*.mlp.router.bias":
             ("layers.*.custom_module.router.bias_E", None, None),
             "model.layers.*.mlp.experts.gate_up_proj":
-            ("layers.*.custom_module.mlp1_weight_EDF2", None, None),
+            ("layers.*.custom_module.mlp1_weight_EDF2", swap_mlp_transform,
+             None),
             "model.layers.*.mlp.experts.gate_up_proj_bias":
             ("layers.*.custom_module.mlp1_bias_EF2", None, None),
             "model.layers.*.mlp.experts.down_proj":
-            ("layers.*.custom_module.mlp2_weight_EFD", None, None),
+            ("layers.*.custom_module.mlp2_weight_EFD", swap_mlp_transform,
+             None),
             "model.layers.*.mlp.experts.down_proj_bias":
             ("layers.*.custom_module.mlp2_bias_ED", None, None),
         }
@@ -265,8 +279,16 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
             framework="pt",
             download_dir=self.vllm_config.load_config.download_dir)

+        # Build a pool of weights with MXFP4 experts combined if needed
+        pool: dict[str, torch.Tensor | tuple] = (self._build_mxfp4_pool(
+            names_and_weights_generator,
+            mappings) if quant_method == MXFP4_QUANT_METHOD else {
+                loaded_name: loaded_weight
+                for loaded_name, loaded_weight in names_and_weights_generator
+            })
+
         with jax.default_device(jax.devices("cpu")[0]):
-            for loaded_name, loaded_weight in names_and_weights_generator:
+            for loaded_name, loaded_weight in pool.items():
                 hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", loaded_name)
                 if hf_pattern not in mappings:
                     logger.warning(
@@ -284,48 +306,162 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
                     "*", layer_num_match.group(1))

                 model_weight = get_param(model_params, jax_path)
-                cast_type = model_weight.value.dtype

-                if jax_path_template == "layers.*.attn.sinks_N":
-                    # Checkpoint is bf16, but we have to upcast sinks to f32, as required by RPA_v3 kernel
-                    weight_np = jnp.array(
-                        loaded_weight.to(torch.float32).numpy())
-                else:
-                    torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
-                    if torch_view_type:
-                        # Avoid unnecessary upcasting and mem copy by viewing the tensor's
-                        # raw data as integers before converting to a JAX array.
-                        weight_np = jnp.array(
-                            loaded_weight.view(torch_view_type).numpy()).view(
-                                cast_type)
-                    else:
-                        raise ValueError(
-                            f"Unsupported dtype for tensor conversion: {cast_type}"
+                prepared_weight = loaded_weight
+                if isinstance(loaded_weight, tuple):
+                    # Loaded weight is an MXFP4 tuple
+                    blocks_u8, scales_u8 = loaded_weight
+                    # Quantized param (QArray): set qvalue/scale directly and skip regular path
+                    if hasattr(model_weight, "array"):  # QArray check
+                        codes_fp32_t, scales_fp32_t = unpack_mxfp4_to_fp32(
+                            blocks_u8, scales_u8)
+                        self._load_mxfp4(
+                            model_weight=model_weight,
+                            codes_fp32_t=codes_fp32_t,
+                            scales_fp32_t=scales_fp32_t,
+                            transform_fn=transform_fn,
                         )
+                        if is_verbose:
+                            print_param_info(model_weight, loaded_name)
+                        continue
+                    # Not a QArray: dequantize MXFP4 to BF16 full weights
+                    prepared_weight = dequant_mxfp4_to_bf16(
+                        blocks_u8, scales_u8)
+
+                # Single regular-tensor load call (BF16 or dequantized MXFP4)
+                cast_type = model_weight.value.dtype
+                self._load_regular_param(
+                    model_weight=model_weight,
+                    loaded_weight=prepared_weight,
+                    cast_type=cast_type,
+                    transform_fn=transform_fn,
+                    target_shape=target_shape,
+                    jax_path_template=jax_path_template,
+                )

-                if transform_fn:
-                    transformed_weight = transform_fn(weight_np, target_shape)
-                else:
-                    transformed_weight = weight_np
+                if is_verbose:
+                    print_param_info(model_weight, loaded_name)

-                if model_weight.value.shape != transformed_weight.shape:
-                    raise ValueError(
-                        f"Shape mismatch for '{jax_path}': Model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transformation."
-                    )
+        nnx.update(self, model_params)

-                def get_slice(index):
-                    return transformed_weight[index]
+    def _build_mxfp4_pool(self, names_and_weights_generator, mappings):
+        """Collect MXFP4 weights into a pool keeping tuples (blocks_u8, scales_u8).
+
+        Combines *_blocks and *_scales pairs and stores uint8 tensors together.
+        Non-expert tensors are kept as-is. Raises if any expert bundle is incomplete.
+        """
+        pool: dict[str, torch.Tensor | tuple] = {}
+        pending_experts: dict[str, dict[str, torch.Tensor]] = {}
+        for loaded_name, loaded_weight in names_and_weights_generator:
+            if loaded_name.endswith("_blocks") or loaded_name.endswith(
+                    "_scales"):
+                base = loaded_name[:-7]
+                entry = pending_experts.setdefault(base, {})
+                if loaded_name.endswith("_blocks"):
+                    entry["blocks"] = loaded_weight
+                else:
+                    entry["scales"] = loaded_weight

-                sharded_array = jax.make_array_from_callback(
-                    transformed_weight.shape,
-                    NamedSharding(self.mesh, P(*model_weight.sharding)),
-                    get_slice)
-                model_weight.value = sharded_array
+                # If we have both parts, place raw pair into the main pool
+                if "blocks" in entry and "scales" in entry:
+                    hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", base)
+                    if hf_pattern not in mappings:
+                        raise ValueError(
+                            f"No mapping found for expert tensor: {base}")
+                    pool[base] = (entry["blocks"], entry["scales"])
+                    # Remove from pending to free memory
+                    pending_experts.pop(base, None)
+            else:
+                pool[loaded_name] = loaded_weight
+
+        # Enforce completeness of expert bundles
+        if pending_experts:
+            details = []
+            for base, entry in pending_experts.items():
+                missing = [k for k in ("blocks", "scales") if k not in entry]
+                details.append(
+                    f"{base} (missing: {', '.join(missing) if missing else 'unknown'})"
+                )
+            raise RuntimeError(
+                "Incomplete MXFP4 expert bundle(s) encountered: " +
+                ", ".join(details))
+        return pool
+
+    def _load_mxfp4(self,
+                    model_weight,
+                    codes_fp32_t,
+                    scales_fp32_t,
+                    transform_fn=None):
+        """Assign decoded MXFP4 codes/scales into a QArray (qvalue/scale)."""
+
+        qv = model_weight.array.qvalue
+        sv = model_weight.array.scale
+        q_dtype = qv.value.dtype
+        s_dtype = sv.value.dtype
+
+        exp_q_shape = tuple(qv.value.shape)
+        exp_s_shape = tuple(sv.value.shape)
+
+        # Apply optional transform (e.g., swap last two dims) before conversion
+        if transform_fn is not None:
+            codes_fp32_t = transform_fn(codes_fp32_t, None)
+            scales_fp32_t = transform_fn(scales_fp32_t, None)
+
+        # Convert from torch.Tensor to numpy before creating JAX arrays
+        codes_fp32_t = codes_fp32_t.detach().cpu().numpy()
+        scales_fp32_t = scales_fp32_t.detach().cpu().numpy()
+
+        codes_jnp = jnp.asarray(codes_fp32_t).astype(q_dtype)
+        scales_jnp = jnp.asarray(scales_fp32_t).astype(s_dtype)
+
+        def get_q_slice(index):
+            return codes_jnp[index]
+
+        def get_s_slice(index):
+            return scales_jnp[index]
+
+        q_sharded = jax.make_array_from_callback(
+            exp_q_shape, NamedSharding(self.mesh, P(*qv.sharding)),
+            get_q_slice)
+        s_sharded = jax.make_array_from_callback(
+            exp_s_shape, NamedSharding(self.mesh, P(*sv.sharding)),
+            get_s_slice)
+
+        model_weight.array.qvalue.value = q_sharded
+        model_weight.array.scale.value = s_sharded
+
+    def _load_regular_param(self, model_weight, loaded_weight: torch.Tensor,
+                            cast_type, transform_fn, target_shape,
+                            jax_path_template: str):
+        """Assign a regular tensor (non-MXFP4) into the model param with transform applied."""
+        if jax_path_template == "layers.*.attn.sinks_N":
+            # Checkpoint is bf16, but we have to upcast sinks to f32, as required by RPA_v3 kernel
+            weight_np = jnp.array(loaded_weight.to(torch.float32).numpy())
+        else:
+            torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
+            if torch_view_type:
+                weight_np = jnp.array(
+                    loaded_weight.view(torch_view_type).numpy()).view(
+                        cast_type)
+            else:
+                raise ValueError(
+                    f"Unsupported dtype for tensor conversion: {cast_type}")
+
+        transformed_weight = transform_fn(
+            weight_np, target_shape) if transform_fn else weight_np
+
+        if model_weight.value.shape != transformed_weight.shape:
+            raise ValueError(
+                f"Shape mismatch: model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transform."
+            )

-                if is_verbose:
-                    print_param_info(model_weight, loaded_name)
+        def get_slice(index):
+            return transformed_weight[index]

-        nnx.update(self, model_params)
+        sharded_array = jax.make_array_from_callback(
+            transformed_weight.shape,
+            NamedSharding(self.mesh, P(*model_weight.sharding)), get_slice)
+        model_weight.value = sharded_array

     def __call__(
         self,
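
Note on the loader changes above: _build_mxfp4_pool pairs each *_blocks tensor with its *_scales sibling under a shared base key; both suffixes are seven characters long, which is why base = loaded_name[:-7] recovers that key. Below is a simplified, standalone sketch of that pairing (the tensor names and shapes are hypothetical and the mapping check is omitted), not the actual loader code:

import torch

# Hypothetical checkpoint entries; shapes are illustrative only.
entries = [
    ("model.layers.0.mlp.experts.gate_up_proj_blocks",
     torch.zeros((2, 4, 3, 16), dtype=torch.uint8)),  # packed FP4 codes
    ("model.layers.0.mlp.experts.gate_up_proj_scales",
     torch.zeros((2, 4, 3), dtype=torch.uint8)),  # e8m0 block scales
    ("model.layers.0.input_layernorm.weight",
     torch.zeros(8, dtype=torch.bfloat16)),  # regular BF16 tensor
]

pool, pending = {}, {}
for name, weight in entries:
    if name.endswith(("_blocks", "_scales")):
        base = name[:-7]  # "_blocks" and "_scales" are both 7 chars
        entry = pending.setdefault(base, {})
        entry["blocks" if name.endswith("_blocks") else "scales"] = weight
        # Once both halves arrive, store them together under the base key
        if "blocks" in entry and "scales" in entry:
            pool[base] = (entry["blocks"], entry["scales"])
            pending.pop(base)
    else:
        pool[name] = weight

assert set(pool) == {
    "model.layers.0.mlp.experts.gate_up_proj",
    "model.layers.0.input_layernorm.weight",
}
assert isinstance(pool["model.layers.0.mlp.experts.gate_up_proj"], tuple)

The loader then either hands such a tuple to _load_mxfp4 (when the target parameter is a QArray) or dequantizes it to BF16 with dequant_mxfp4_to_bf16 before taking the regular load path.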
tpu_inference/models/jax/utils/quantization/mxfp4_utils.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+# MXFP4 constants
+MXFP4_BLOCK_SIZE: int = 32
+# Exponent-only e8m0 scale bias used by MXFP4 scales
+MXFP4_SCALE_BIAS: int = 127
+# Name used in config.json quantization_config["quant_method"]
+MXFP4_QUANT_METHOD: str = "mxfp4"
+
+# Precompute a small LUT once; move to device on demand (cheap 16-element copy)
+FP4_LUT = torch.tensor(
+    [
+        0.0,
+        0.5,
+        1.0,
+        1.5,
+        2.0,
+        3.0,
+        4.0,
+        6.0,  # 0b0000-0b0111
+        -0.0,
+        -0.5,
+        -1.0,
+        -1.5,
+        -2.0,
+        -3.0,
+        -4.0,
+        -6.0,  # 0b1000-0b1111
+    ],
+    dtype=torch.float32)
+
+
+def unpack_mxfp4(packed: torch.Tensor) -> torch.Tensor:
+    """Unpack uint8 (..., 16) -> fp4 values (..., 32) using low->high nibble order.
+
+    Returns float32 values corresponding to FP4 codebook entries.
+    """
+    assert packed.dtype == torch.uint8
+    low = packed & 0x0F
+    high = (packed >> 4) & 0x0F
+    idx = torch.stack([low, high], dim=-1).flatten(-2)
+    lut = FP4_LUT.to(packed.device)
+    return lut[idx.long()]
+
+
+def e8m0_to_fp32(u8: torch.Tensor) -> torch.Tensor:
+    """Convert e8m0 uint8 exponents to power-of-two scales using MXFP4_SCALE_BIAS.
+
+    Uses ldexp for exact power-of-two scaling: 1.0 * 2**(u8 - bias).
+    """
+    exponents = (u8.to(torch.int32) - int(MXFP4_SCALE_BIAS)).to(torch.int32)
+    ones = torch.ones_like(u8, dtype=torch.float32)
+    return torch.ldexp(ones, exponents)
+
+
+def dequant_mxfp4_to_bf16(blocks_u8: torch.Tensor,
+                          scales_u8: torch.Tensor) -> torch.Tensor:
+    """Dequantize MXFP4 blocks/scales into bfloat16 values.
+
+    Args:
+        blocks_u8: uint8 tensor shaped [..., Kb, 16], each byte holds 2 FP4 codes.
+        scales_u8: uint8 tensor shaped [..., Kb], exponent-only e8m0 per 32-value block.
+
+    Returns:
+        torch.bfloat16 tensor with last logical dimension K = Kb * 32.
+    """
+    if blocks_u8.dtype != torch.uint8 or scales_u8.dtype != torch.uint8:
+        raise ValueError(
+            f"Expected uint8 inputs, got blocks={blocks_u8.dtype}, scales={scales_u8.dtype}"
+        )
+    # Unpack FP4 codes to float32 values [..., Kb, 32]
+    fp4_vals = unpack_mxfp4(blocks_u8)  # (..., Kb, 32)
+    # Compute power-of-two scales and apply per block
+    scales = e8m0_to_fp32(scales_u8).unsqueeze(-1)  # (..., Kb, 1)
+    full = (fp4_vals * scales).reshape(*fp4_vals.shape[:-2],
+                                       fp4_vals.shape[-2] * MXFP4_BLOCK_SIZE)
+    return full.to(torch.bfloat16)
+
+
+def unpack_mxfp4_to_fp32(
+        blocks_u8: torch.Tensor,
+        scales_u8: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """Decode MXFP4 packed blocks and e8m0 scales to float32 codes and scales.
+
+    Args:
+        blocks_u8: uint8 tensor shaped [..., Kb, 16], each byte packs two FP4 codes.
+        scales_u8: uint8 tensor shaped [..., Kb], exponent-only e8m0 per block.
+
+    Returns:
+        (codes_fp32, scales_fp32), where
+        - codes_fp32 has shape [..., Kb*32] and dtype float32
+        - scales_fp32 has shape [..., Kb] and dtype float32
+    """
+    if blocks_u8.dtype != torch.uint8 or scales_u8.dtype != torch.uint8:
+        raise ValueError(
+            f"Expected uint8 inputs, got blocks={blocks_u8.dtype}, scales={scales_u8.dtype}"
+        )
+    fp4_vals = unpack_mxfp4(blocks_u8)  # (..., Kb, 32) float32
+    codes_fp32 = fp4_vals.reshape(*fp4_vals.shape[:-2],
+                                  fp4_vals.shape[-2] * MXFP4_BLOCK_SIZE)
+    scales_fp32 = e8m0_to_fp32(scales_u8)  # (..., Kb) float32
+    return codes_fp32, scales_fp32
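
A minimal sanity-check sketch for these helpers, assuming the module lands at tpu_inference/models/jax/utils/quantization/mxfp4_utils.py (the path is inferred from the import added in gpt_oss.py) and that only torch is needed:

import torch

from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
    MXFP4_BLOCK_SIZE, dequant_mxfp4_to_bf16, unpack_mxfp4, unpack_mxfp4_to_fp32)

# Nibble order is low -> high: byte 0x2F decodes to [LUT[0xF], LUT[0x2]] = [-6.0, 1.0].
assert unpack_mxfp4(torch.tensor([0x2F], dtype=torch.uint8)).tolist() == [-6.0, 1.0]

# Fake expert weight: 2 rows, 3 blocks of 32 values each (16 packed bytes per block).
rows, kb = 2, 3
blocks_u8 = torch.randint(0, 256, (rows, kb, 16), dtype=torch.uint8)
scales_u8 = torch.full((rows, kb), 127, dtype=torch.uint8)  # e8m0 exponent 127 -> scale 1.0

full_bf16 = dequant_mxfp4_to_bf16(blocks_u8, scales_u8)
assert full_bf16.shape == (rows, kb * MXFP4_BLOCK_SIZE)  # (2, 96)
assert full_bf16.dtype == torch.bfloat16

codes_fp32, scales_fp32 = unpack_mxfp4_to_fp32(blocks_u8, scales_u8)
assert codes_fp32.shape == (rows, kb * MXFP4_BLOCK_SIZE)
assert scales_fp32.shape == (rows, kb)
# With unit scales, the dequantized values equal the raw FP4 codebook values,
# which are all exactly representable in bfloat16.
assert torch.equal(full_bf16.to(torch.float32), codes_fp32)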
