Commit 2f2730f

[GPT OSS] Add support for both BF16 and MXFP4
Signed-off-by: Jordan Dotzel <amishacorns@users.noreply.github.com>
1 parent 6e96676 commit 2f2730f

3 files changed, +256 -88 lines changed

tpu_inference/models/jax/gpt_oss.py

Lines changed: 169 additions & 73 deletions
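Background note (editorial): the helpers imported from mxfp4_utils in this diff (MXFP4_QUANT_METHOD, unpack_mxfp4_to_fp32, dequant_mxfp4_to_bf16) are not shown here. The sketch below illustrates how generic MXFP4 (OCP microscaling) dequantization typically works: each uint8 in a `*_blocks` tensor packs two 4-bit E2M1 codes, and each group of 32 values shares one E8M0 scale from the matching `*_scales` tensor. The function name, nibble order, and block size of 32 are assumptions for illustration, not the repo's implementation.

import numpy as np

# Lookup table for the 16 E2M1 codes (1 sign, 2 exponent, 1 mantissa bit).
_E2M1_VALUES = np.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=np.float32)


def dequant_mxfp4_sketch(blocks_u8: np.ndarray,
                         scales_u8: np.ndarray) -> np.ndarray:
    """Dequantize packed MXFP4: (..., n_bytes) codes + (..., n_blocks) E8M0 scales."""
    lo = _E2M1_VALUES[blocks_u8 & 0x0F]         # first code of each byte (assumed order)
    hi = _E2M1_VALUES[(blocks_u8 >> 4) & 0x0F]  # second code of each byte
    codes = np.stack([lo, hi], axis=-1).reshape(*blocks_u8.shape[:-1], -1)
    # E8M0 scale: 2 ** (stored biased exponent - 127), one per 32-element block.
    scales = np.exp2(scales_u8.astype(np.float32) - 127.0)
    codes = codes.reshape(*scales.shape, 32)
    return (codes * scales[..., None]).reshape(*blocks_u8.shape[:-1], -1)


# Example: a row of 64 values -> 32 packed bytes and 2 block scales.
blocks = np.random.randint(0, 256, size=(64 // 2,), dtype=np.uint8)
scales = np.full((64 // 32,), 127, dtype=np.uint8)   # scale of 2**0 = 1.0
print(dequant_mxfp4_sketch(blocks, scales).shape)    # (64,)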
@@ -10,9 +10,6 @@
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from vllm.config import VllmConfig
-from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
-    dequant_mxfp4_to_bf16,
-)
 
 from tpu_inference.layers.jax.attention.gpt_oss_attention import (
     AttentionMetadata, GptOssAttention)
@@ -21,6 +18,8 @@
 from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
 from tpu_inference.layers.jax.transformer_block import TransformerBlock
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
+    MXFP4_QUANT_METHOD, dequant_mxfp4_to_bf16, unpack_mxfp4_to_fp32)
 from tpu_inference.models.jax.utils.weight_utils import (
     get_param, model_weights_generator, print_param_info)

@@ -188,6 +187,11 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
         """Loads and transforms all weights from a checkpoint"""
         self.rng = nnx.Rngs(rng)
 
+        # Determine quantization method from HF config (config.json)
+        quant_method = (self.hf_config.quantization_config["quant_method"]
+                        if hasattr(self.hf_config, "quantization_config") else
+                        None)
+
         # Format: 'hf_key': ('jax_model_path', transform_function, target_shape)
         transforms = {
             "transpose_reshape": lambda w, shape: w.T.reshape(shape),
@@ -196,6 +200,10 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
             "swap_last2": lambda w, _: w.swapaxes(-1, -2),
         }
 
+        # MXFP4 checkpoints swap last two dims for MoE to place packed dim at most minor
+        swap_mlp_transform = transforms[
+            "swap_last2"] if quant_method == MXFP4_QUANT_METHOD else None
+
         mappings = {
             # Embeddings, Norms, and LM Head
             "model.embed_tokens.weight": ("embedder.input_embedding_table_VD",
@@ -251,11 +259,13 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
             "model.layers.*.mlp.router.bias":
             ("layers.*.custom_module.router.bias_E", None, None),
             "model.layers.*.mlp.experts.gate_up_proj":
-            ("layers.*.custom_module.mlp1_weight_EDF2", transforms["swap_last2"], None),
+            ("layers.*.custom_module.mlp1_weight_EDF2", swap_mlp_transform,
+             None),
             "model.layers.*.mlp.experts.gate_up_proj_bias":
             ("layers.*.custom_module.mlp1_bias_EF2", None, None),
             "model.layers.*.mlp.experts.down_proj":
-            ("layers.*.custom_module.mlp2_weight_EFD", transforms["swap_last2"], None),
+            ("layers.*.custom_module.mlp2_weight_EFD", swap_mlp_transform,
+             None),
             "model.layers.*.mlp.experts.down_proj_bias":
             ("layers.*.custom_module.mlp2_bias_ED", None, None),
         }
@@ -269,41 +279,13 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
             framework="pt",
             download_dir=self.vllm_config.load_config.download_dir)
 
-        # Single pass: build a unified pool. Combine MXFP4 expert blocks/scales
-        # into a dequantized bf16 tensor as soon as both are seen.
-        pool: dict[str, torch.Tensor] = {}
-        pending_experts: dict[str, dict[str, torch.Tensor]] = {}
-        for loaded_name, loaded_weight in names_and_weights_generator:
-            if loaded_name.endswith("_blocks") or loaded_name.endswith("_scales"):
-                base = loaded_name[:-7]
-                entry = pending_experts.setdefault(base, {})
-                if loaded_name.endswith("_blocks"):
-                    entry["blocks"] = loaded_weight
-                else:
-                    entry["scales"] = loaded_weight
-
-                # If we have both parts, dequantize now and place into the main pool
-                if "blocks" in entry and "scales" in entry:
-                    hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", base)
-                    if hf_pattern not in mappings:
-                        logger.warning(f"No mapping found for expert tensor: {base}. Skipping.")
-                    else:
-                        deq = dequant_mxfp4_to_bf16(entry["blocks"], entry["scales"])  # torch.bfloat16
-                        pool[base] = deq
-                    # Remove from pending to free memory
-                    pending_experts.pop(base, None)
-            else:
-                pool[loaded_name] = loaded_weight
-
-        # Enforce completeness of expert bundles
-        if pending_experts:
-            details = []
-            for base, entry in pending_experts.items():
-                missing = [k for k in ("blocks", "scales") if k not in entry]
-                details.append(f"{base} (missing: {', '.join(missing) if missing else 'unknown'})")
-            raise RuntimeError(
-                "Incomplete MXFP4 expert bundle(s) encountered: " + ", ".join(details)
-            )
+        # Build a pool of weights with MXFP4 experts combined if neededs
+        pool: dict[str, torch.Tensor | tuple] = (self._build_mxfp4_pool(
+            names_and_weights_generator,
+            mappings) if quant_method == MXFP4_QUANT_METHOD else {
+                loaded_name: loaded_weight
+                for loaded_name, loaded_weight in names_and_weights_generator
+            })
 
         with jax.default_device(jax.devices("cpu")[0]):
             for loaded_name, loaded_weight in pool.items():
@@ -324,48 +306,162 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
                     "*", layer_num_match.group(1))
 
                 model_weight = get_param(model_params, jax_path)
-                cast_type = model_weight.value.dtype
 
-                if jax_path_template == "layers.*.attn.sinks_N":
-                    # Checkpoint is bf16, but we have to upcast sinks to f32, as required by RPA_v3 kernel
-                    weight_np = jnp.array(
-                        loaded_weight.to(torch.float32).numpy())
-                else:
-                    torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
-                    if torch_view_type:
-                        # Avoid unnecessary upcasting and mem copy by viewing the tensor's
-                        # raw data as integers before converting to a JAX array.
-                        weight_np = jnp.array(
-                            loaded_weight.view(torch_view_type).numpy()).view(
-                                cast_type)
-                    else:
-                        raise ValueError(
-                            f"Unsupported dtype for tensor conversion: {cast_type}"
+                prepared_weight = loaded_weight
+                if isinstance(loaded_weight, tuple):
+                    # Loaded weight is an MXFP4 tuple
+                    blocks_u8, scales_u8 = loaded_weight
+                    # Quantized param (QArray): set qvalue/scale directly and skip regular path
+                    if hasattr(model_weight, "array"):  # QArray check
+                        codes_fp32_t, scales_fp32_t = unpack_mxfp4_to_fp32(
+                            blocks_u8, scales_u8)
+                        self._load_mxfp4(
+                            model_weight=model_weight,
+                            codes_fp32_t=codes_fp32_t,
+                            scales_fp32_t=scales_fp32_t,
+                            transform_fn=transform_fn,
                         )
+                        if is_verbose:
+                            print_param_info(model_weight, loaded_name)
+                        continue
+                    # Not a QArray: dequantize MXFP4 to BF16 full weights
+                    prepared_weight = dequant_mxfp4_to_bf16(
+                        blocks_u8, scales_u8)
+
+                # Single regular-tensor load call (BF16 or dequantized MXFP4)
+                cast_type = model_weight.value.dtype
+                self._load_regular_param(
+                    model_weight=model_weight,
+                    loaded_weight=prepared_weight,
+                    cast_type=cast_type,
+                    transform_fn=transform_fn,
+                    target_shape=target_shape,
+                    jax_path_template=jax_path_template,
+                )
+
+                if is_verbose:
+                    print_param_info(model_weight, loaded_name)
+
+        nnx.update(self, model_params)
+
+    def _build_mxfp4_pool(self, names_and_weights_generator, mappings):
+        """Collect MXFP4 weights into a pool keeping tuples (blocks_u8, scales_u8).
 
-                if transform_fn:
-                    transformed_weight = transform_fn(weight_np, target_shape)
+        Combines *_blocks and *_scales pairs and stores uint8 tensors together.
+        Non-expert tensors are kept as-is. Raises if any expert bundle is incomplete.
+        """
+        pool: dict[str, torch.Tensor | tuple] = {}
+        pending_experts: dict[str, dict[str, torch.Tensor]] = {}
+        for loaded_name, loaded_weight in names_and_weights_generator:
+            if loaded_name.endswith("_blocks") or loaded_name.endswith(
+                    "_scales"):
+                base = loaded_name[:-7]
+                entry = pending_experts.setdefault(base, {})
+                if loaded_name.endswith("_blocks"):
+                    entry["blocks"] = loaded_weight
                 else:
-                    transformed_weight = weight_np
+                    entry["scales"] = loaded_weight
 
-                if model_weight.value.shape != transformed_weight.shape:
-                    raise ValueError(
-                        f"Shape mismatch for '{jax_path}': Model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transformation."
-                    )
+                # If we have both parts, place raw pair into the main pool
+                if "blocks" in entry and "scales" in entry:
+                    hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", base)
+                    if hf_pattern not in mappings:
+                        raise ValueError(
+                            f"No mapping found for expert tensor: {base}")
+                    pool[base] = (entry["blocks"], entry["scales"])
+                    # Remove from pending to free memory
+                    pending_experts.pop(base, None)
+            else:
+                pool[loaded_name] = loaded_weight
+
+        # Enforce completeness of expert bundles
+        if pending_experts:
+            details = []
+            for base, entry in pending_experts.items():
+                missing = [k for k in ("blocks", "scales") if k not in entry]
+                details.append(
+                    f"{base} (missing: {', '.join(missing) if missing else 'unknown'})"
+                )
+            raise RuntimeError(
+                "Incomplete MXFP4 expert bundle(s) encountered: " +
+                ", ".join(details))
+        return pool
+
+    def _load_mxfp4(self,
+                    model_weight,
+                    codes_fp32_t,
+                    scales_fp32_t,
+                    transform_fn=None):
+        """Assign decoded MXFP4 codes/scales into a QArray (qvalue/scale)."""
+
+        qv = model_weight.array.qvalue
+        sv = model_weight.array.scale
+        q_dtype = qv.value.dtype
+        s_dtype = sv.value.dtype
+
+        exp_q_shape = tuple(qv.value.shape)
+        exp_s_shape = tuple(sv.value.shape)
+
+        # Apply optional transform (e.g., swap last two dims) before conversion
+        if transform_fn is not None:
+            codes_fp32_t = transform_fn(codes_fp32_t, None)
+            scales_fp32_t = transform_fn(scales_fp32_t, None)
+
+        # Convert from torch.Tensor to numpy before creating JAX arrays
+        codes_fp32_t = codes_fp32_t.detach().cpu().numpy()
+        scales_fp32_t = scales_fp32_t.detach().cpu().numpy()
+
+        codes_jnp = jnp.asarray(codes_fp32_t).astype(q_dtype)
+        scales_jnp = jnp.asarray(scales_fp32_t).astype(s_dtype)
+
+        def get_q_slice(index):
+            return codes_jnp[index]
+
+        def get_s_slice(index):
+            return scales_jnp[index]
+
+        q_sharded = jax.make_array_from_callback(
+            exp_q_shape, NamedSharding(self.mesh, P(*qv.sharding)),
+            get_q_slice)
+        s_sharded = jax.make_array_from_callback(
+            exp_s_shape, NamedSharding(self.mesh, P(*sv.sharding)),
+            get_s_slice)
+
+        model_weight.array.qvalue.value = q_sharded
+        model_weight.array.scale.value = s_sharded
+
+    def _load_regular_param(self, model_weight, loaded_weight: torch.Tensor,
+                            cast_type, transform_fn, target_shape,
+                            jax_path_template: str):
+        """Assign a regular tensor (non-MXFP4) into the model param with transform applied."""
+        if jax_path_template == "layers.*.attn.sinks_N":
+            # Checkpoint is bf16, but we have to upcast sinks to f32, as required by RPA_v3 kernel
+            weight_np = jnp.array(loaded_weight.to(torch.float32).numpy())
+        else:
+            torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
+            if torch_view_type:
+                weight_np = jnp.array(
+                    loaded_weight.view(torch_view_type).numpy()).view(
+                        cast_type)
+            else:
+                raise ValueError(
+                    f"Unsupported dtype for tensor conversion: {cast_type}")
 
-                def get_slice(index):
-                    return transformed_weight[index]
+        transformed_weight = transform_fn(
+            weight_np, target_shape) if transform_fn else weight_np
 
-                sharded_array = jax.make_array_from_callback(
-                    transformed_weight.shape,
-                    NamedSharding(self.mesh, P(*model_weight.sharding)),
-                    get_slice)
-                model_weight.value = sharded_array
+        if model_weight.value.shape != transformed_weight.shape:
+            raise ValueError(
+                f"Shape mismatch: model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transform."
+            )
 
-                if is_verbose:
-                    print_param_info(model_weight, loaded_name)
+        def get_slice(index):
+            return transformed_weight[index]
 
-        nnx.update(self, model_params)
+        sharded_array = jax.make_array_from_callback(
+            transformed_weight.shape,
+            NamedSharding(self.mesh, P(*model_weight.sharding)), get_slice)
+        model_weight.value = sharded_array
 
     def __call__(
         self,

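Editorial note: both _load_mxfp4 and _load_regular_param finish by materializing host data onto the device mesh via jax.make_array_from_callback with a NamedSharding. The standalone sketch below shows that pattern in isolation; the mesh axis names, array sizes, and partition spec are illustrative assumptions, not taken from this repo.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

# Build a tiny 1 x N mesh from whatever devices are available (axis names are made up).
n_dev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()).reshape(1, n_dev), axis_names=("data", "model"))

# Host-resident weight; the last dim is split across the "model" axis.
host_weight = jnp.arange(8 * 4 * n_dev, dtype=jnp.float32).reshape(8, 4 * n_dev)
sharding = NamedSharding(mesh, P(None, "model"))


def get_slice(index):
    # `index` is a tuple of slices selecting one device's shard of the global array.
    return host_weight[index]


sharded = jax.make_array_from_callback(host_weight.shape, sharding, get_slice)
print(sharded.shape, sharded.sharding)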