diff --git a/tpu_inference/layers/jax/moe/gpt_oss_moe.py b/tpu_inference/layers/jax/moe/gpt_oss_moe.py
index 4fe26dda3..8ee1d33ff 100644
--- a/tpu_inference/layers/jax/moe/gpt_oss_moe.py
+++ b/tpu_inference/layers/jax/moe/gpt_oss_moe.py
@@ -67,6 +67,34 @@ def _swiglu(x: Float, alpha: Float, limit: Float) -> Float:
     return gated_activation * (x_linear + 1)
 
 
+@dataclass(kw_only=True)
+class CombineExperts(nnx.Module):
+    """Module for combining expert outputs with a weighted sum."""
+    dtype: jnp.dtype
+
+    def __call__(self, down_proj_TED: Float, weights_TX: Float,
+                 indices_TX: jax.Array) -> Float:
+        """Combines expert outputs using a weighted sum.
+
+        Args:
+            down_proj_TED: Expert outputs, shape (tokens, experts, hidden_dim)
+            weights_TX: Router weights, shape (tokens, experts_per_token)
+            indices_TX: Selected expert indices, shape (tokens, experts_per_token)
+
+        Returns:
+            Combined output, shape (tokens, hidden_dim)
+        """
+        with jax.named_scope("combine_experts"):
+            indices_for_gather = indices_TX[..., None]
+            gathered_down_proj_TED = jnp.take_along_axis(down_proj_TED,
+                                                         indices_for_gather,
+                                                         axis=1)
+            output_TD = jnp.einsum('TXD,TX -> TD', gathered_down_proj_TED,
+                                   weights_TX)
+
+        return output_TD.astype(self.dtype)
+
+
 @dataclass(kw_only=True)
 class GptOssMoE(nnx.Module):
     """
@@ -114,20 +142,16 @@ def __call__(self, x_TD: Float) -> Float:
             down_proj_TED += self.mlp2_bias_ED.value
 
         # Weighted sum of expert outputs
-        with jax.named_scope("sum"):
-            indices_for_gather = indices_TX[..., None]
-            gathered_down_proj_TED = jnp.take_along_axis(down_proj_TED,
-                                                         indices_for_gather,
-                                                         axis=1)
-            output_TD = jnp.einsum('TXD,TX -> TD', gathered_down_proj_TED,
-                                   weights_TX)
+        output_TD = self.combine_experts(down_proj_TED, weights_TX, indices_TX)
 
-        return output_TD.astype(self.dtype)
+        return output_TD
 
     def __post_init__(self, rngs: nnx.Rngs):
         """Initializes all weights and biases for the MoE block."""
         D, F, E = self.hidden_size, self.intermediate_size_moe, self.num_local_experts
 
+        self.combine_experts = CombineExperts(dtype=self.dtype)
+
         # MLP #1 Weights (Combined Gate and Up-projection) and Bias
         self.mlp1_weight_EDF2 = create_param(
             rngs,
diff --git a/tpu_inference/models/jax/gpt_oss.py b/tpu_inference/models/jax/gpt_oss.py
index ebea0522b..45fc66a6e 100644
--- a/tpu_inference/models/jax/gpt_oss.py
+++ b/tpu_inference/models/jax/gpt_oss.py
@@ -18,6 +18,8 @@
 from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
 from tpu_inference.layers.jax.transformer_block import TransformerBlock
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
+    MXFP4_QUANT_METHOD, dequant_mxfp4_to_bf16, unpack_mxfp4_to_fp32)
 from tpu_inference.models.jax.utils.weight_utils import (
     get_param, model_weights_generator, print_param_info)
 
@@ -80,7 +82,7 @@ def __init__(self,
             hidden_size=hidden_size,
             dtype=dtype,
             rngs=self.rng,
-            vd_sharding=(('data', 'model'), None),
+            vd_sharding=P(('data', 'model'), None),
             random_init=self.random_init,
         )
 
@@ -103,9 +105,9 @@ def __init__(self,
             query_tnh=P(None, 'model', None),
             keyvalue_skh=P(None, 'model', None),
             attn_o_tnh=P(None, 'model', None),
-            dnh_sharding=(None, 'model', None),
-            dkh_sharding=(None, 'model', None),
-            nhd_sharding=('model', None, None),
+            dnh_sharding=P(None, 'model', None),
+            dkh_sharding=P(None, 'model', None),
+            nhd_sharding=P('model', None, None),
             mesh=self.mesh,
         )
 
@@ -118,9 +120,9 @@ def __init__(self,
             dtype=dtype,
             router_act='softmax',
             random_init=self.random_init,
-            activation_ffw_td=('data', None),
-            ed_sharding=('model', None),
-            e_sharding=('model', ),
+            activation_ffw_td=P('data', None),
+            ed_sharding=P('model', None),
+            e_sharding=P('model'),
         )
 
         moe_mlp = GptOssMoE(
@@ -133,10 +135,10 @@ def __init__(self,
             router=router,
             swiglu_limit=swiglu_limit,
             # Sharding configuration
-            activation_ffw_td=('data', None),
-            edf_sharding=('model', None, None),
-            efd_sharding=('model', None, None),
-            ed_sharding=('model', None),
+            activation_ffw_td=P('data', None),
+            edf_sharding=P('model', None, None),
+            efd_sharding=P('model', None, None),
+            ed_sharding=P('model', None),
         )
 
         block = TransformerBlock(
@@ -146,6 +148,7 @@ def __init__(self,
                 epsilon=rms_norm_eps,
                 dtype=dtype,
                 rngs=self.rng,
+                activation_ffw_td=P('data', None),
             ),
             pre_mlp_norm=RMSNorm(
                 dims=hidden_size,
@@ -153,6 +156,7 @@ def __init__(self,
                 epsilon=rms_norm_eps,
                 dtype=dtype,
                 rngs=self.rng,
+                activation_ffw_td=P('data', None),
             ),
             attn=attn,
             custom_module=moe_mlp,
@@ -165,6 +169,7 @@ def __init__(self,
             random_init=self.random_init,
             epsilon=rms_norm_eps,
             dtype=dtype,
+            activation_ffw_td=P('data', None),
         )
 
         self.lm_head = LMhead(
@@ -172,8 +177,8 @@ def __init__(self,
             hidden_size=hidden_size,
             dtype=dtype,
             rngs=self.rng,
-            vd_sharding=(('data', 'model'), None),
-            dv_sharding=(None, ('data', 'model')),
+            vd_sharding=P(('data', 'model'), None),
+            dv_sharding=P(None, ('data', 'model')),
             random_init=self.random_init,
         )
 
@@ -185,13 +190,23 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
         """Loads and transforms all weights from a checkpoint"""
         self.rng = nnx.Rngs(rng)
 
+        # Determine quantization method from HF config (config.json)
+        quant_method = (self.hf_config.quantization_config["quant_method"]
+                        if hasattr(self.hf_config, "quantization_config") else
+                        None)
+
         # Format: 'hf_key': ('jax_model_path', transform_function, target_shape)
         transforms = {
            "transpose_reshape": lambda w, shape: w.T.reshape(shape),
            "reshape": lambda b, shape: b.reshape(shape),
            "transpose": lambda w, _: w.T,
+           "swap_last2": lambda w, _: w.swapaxes(-1, -2),
         }
 
+        # MXFP4 checkpoints swap the last two MoE dims so the packed dim is most-minor
+        swap_mlp_transform = transforms[
+            "swap_last2"] if quant_method == MXFP4_QUANT_METHOD else None
+
         mappings = {
             # Embeddings, Norms, and LM Head
             "model.embed_tokens.weight": ("embedder.input_embedding_table_VD",
@@ -247,11 +262,13 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
             "model.layers.*.mlp.router.bias":
             ("layers.*.custom_module.router.bias_E", None, None),
             "model.layers.*.mlp.experts.gate_up_proj":
-            ("layers.*.custom_module.mlp1_weight_EDF2", None, None),
+            ("layers.*.custom_module.mlp1_weight_EDF2", swap_mlp_transform,
+             None),
             "model.layers.*.mlp.experts.gate_up_proj_bias":
             ("layers.*.custom_module.mlp1_bias_EF2", None, None),
             "model.layers.*.mlp.experts.down_proj":
-            ("layers.*.custom_module.mlp2_weight_EFD", None, None),
+            ("layers.*.custom_module.mlp2_weight_EFD", swap_mlp_transform,
+             None),
             "model.layers.*.mlp.experts.down_proj_bias":
             ("layers.*.custom_module.mlp2_bias_ED", None, None),
         }
@@ -265,8 +282,16 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
             framework="pt",
             download_dir=self.vllm_config.load_config.download_dir)
 
+        # Build a pool of weights with MXFP4 experts combined if needed
+        pool: dict[str, torch.Tensor | tuple] = (self._build_mxfp4_pool(
+            names_and_weights_generator,
+            mappings) if quant_method == MXFP4_QUANT_METHOD else {
+                loaded_name: loaded_weight
+                for loaded_name, loaded_weight in names_and_weights_generator
+            })
+
         with jax.default_device(jax.devices("cpu")[0]):
-            for loaded_name, loaded_weight in names_and_weights_generator:
+            for loaded_name, loaded_weight in pool.items():
                 hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", loaded_name)
                 if hf_pattern not in mappings:
                     logger.warning(
@@ -284,48 +309,162 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
                     "*", layer_num_match.group(1))
                 model_weight = get_param(model_params, jax_path)
 
-                cast_type = model_weight.value.dtype
-                if jax_path_template == "layers.*.attn.sinks_N":
-                    # Checkpoint is bf16, but we have to upcast sinks to f32, as required by RPA_v3 kernel
-                    weight_np = jnp.array(
-                        loaded_weight.to(torch.float32).numpy())
-                else:
-                    torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
-                    if torch_view_type:
-                        # Avoid unnecessary upcasting and mem copy by viewing the tensor's
-                        # raw data as integers before converting to a JAX array.
-                        weight_np = jnp.array(
-                            loaded_weight.view(torch_view_type).numpy()).view(
-                                cast_type)
-                    else:
-                        raise ValueError(
-                            f"Unsupported dtype for tensor conversion: {cast_type}"
+                prepared_weight = loaded_weight
+                if isinstance(loaded_weight, tuple):
+                    # Loaded weight is an MXFP4 (blocks, scales) tuple
+                    blocks_u8, scales_u8 = loaded_weight
+                    # Quantized param (QArray): set qvalue/scale directly and skip the regular path
+                    if hasattr(model_weight, "array"):  # QArray check
+                        codes_fp32_t, scales_fp32_t = unpack_mxfp4_to_fp32(
+                            blocks_u8, scales_u8)
+                        self._load_mxfp4(
+                            model_weight=model_weight,
+                            codes_fp32_t=codes_fp32_t,
+                            scales_fp32_t=scales_fp32_t,
+                            transform_fn=transform_fn,
                         )
+                        if is_verbose:
+                            print_param_info(model_weight, loaded_name)
+                        continue
+                    # Not a QArray: dequantize MXFP4 to full BF16 weights
+                    prepared_weight = dequant_mxfp4_to_bf16(
+                        blocks_u8, scales_u8)
+
+                # Single regular-tensor load call (BF16 or dequantized MXFP4)
+                cast_type = model_weight.value.dtype
+                self._load_regular_param(
+                    model_weight=model_weight,
+                    loaded_weight=prepared_weight,
+                    cast_type=cast_type,
+                    transform_fn=transform_fn,
+                    target_shape=target_shape,
+                    jax_path_template=jax_path_template,
+                )
 
-                if transform_fn:
-                    transformed_weight = transform_fn(weight_np, target_shape)
-                else:
-                    transformed_weight = weight_np
+                if is_verbose:
+                    print_param_info(model_weight, loaded_name)
 
-                if model_weight.value.shape != transformed_weight.shape:
-                    raise ValueError(
-                        f"Shape mismatch for '{jax_path}': Model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transformation."
-                    )
+        nnx.update(self, model_params)
 
-                def get_slice(index):
-                    return transformed_weight[index]
+    def _build_mxfp4_pool(self, names_and_weights_generator, mappings):
+        """Collect weights into a pool, keeping MXFP4 experts as (blocks_u8, scales_u8) tuples.
+
+        Combines *_blocks and *_scales pairs and stores the uint8 tensors together.
+        Non-expert tensors are kept as-is. Raises if any expert bundle is incomplete.
+        """
+        pool: dict[str, torch.Tensor | tuple] = {}
+        pending_experts: dict[str, dict[str, torch.Tensor]] = {}
+        for loaded_name, loaded_weight in names_and_weights_generator:
+            if loaded_name.endswith("_blocks") or loaded_name.endswith(
+                    "_scales"):
+                base = loaded_name[:-7]
+                entry = pending_experts.setdefault(base, {})
+                if loaded_name.endswith("_blocks"):
+                    entry["blocks"] = loaded_weight
+                else:
+                    entry["scales"] = loaded_weight
 
-                sharded_array = jax.make_array_from_callback(
-                    transformed_weight.shape,
-                    NamedSharding(self.mesh, P(*model_weight.sharding)),
-                    get_slice)
-                model_weight.value = sharded_array
+                # If we have both parts, place the raw pair into the main pool
+                if "blocks" in entry and "scales" in entry:
+                    hf_pattern = re.sub(r"layers\.(\d+)", "layers.*", base)
+                    if hf_pattern not in mappings:
+                        raise ValueError(
+                            f"No mapping found for expert tensor: {base}")
+                    pool[base] = (entry["blocks"], entry["scales"])
+                    # Remove from pending to free memory
+                    pending_experts.pop(base, None)
+            else:
+                pool[loaded_name] = loaded_weight
+
+        # Enforce completeness of expert bundles
+        if pending_experts:
+            details = []
+            for base, entry in pending_experts.items():
+                missing = [k for k in ("blocks", "scales") if k not in entry]
+                details.append(
+                    f"{base} (missing: {', '.join(missing) if missing else 'unknown'})"
+                )
+            raise RuntimeError(
+                "Incomplete MXFP4 expert bundle(s) encountered: " +
+                ", ".join(details))
+        return pool
+
+    def _load_mxfp4(self,
+                    model_weight,
+                    codes_fp32_t,
+                    scales_fp32_t,
+                    transform_fn=None):
+        """Assign decoded MXFP4 codes/scales into a QArray (qvalue/scale)."""
+
+        qv = model_weight.array.qvalue
+        sv = model_weight.array.scale
+        q_dtype = qv.value.dtype
+        s_dtype = sv.value.dtype
+
+        exp_q_shape = tuple(qv.value.shape)
+        exp_s_shape = tuple(sv.value.shape)
+
+        # Apply the optional transform (e.g., swap last two dims) before conversion
+        if transform_fn is not None:
+            codes_fp32_t = transform_fn(codes_fp32_t, None)
+            scales_fp32_t = transform_fn(scales_fp32_t, None)
+
+        # Convert from torch.Tensor to numpy before creating JAX arrays
+        codes_fp32_t = codes_fp32_t.detach().cpu().numpy()
+        scales_fp32_t = scales_fp32_t.detach().cpu().numpy()
+
+        codes_jnp = jnp.asarray(codes_fp32_t).astype(q_dtype)
+        scales_jnp = jnp.asarray(scales_fp32_t).astype(s_dtype)
+
+        def get_q_slice(index):
+            return codes_jnp[index]
+
+        def get_s_slice(index):
+            return scales_jnp[index]
+
+        q_sharded = jax.make_array_from_callback(
+            exp_q_shape, NamedSharding(self.mesh, P(*qv.sharding)),
+            get_q_slice)
+        s_sharded = jax.make_array_from_callback(
+            exp_s_shape, NamedSharding(self.mesh, P(*sv.sharding)),
+            get_s_slice)
+
+        model_weight.array.qvalue.value = q_sharded
+        model_weight.array.scale.value = s_sharded
+
+    def _load_regular_param(self, model_weight, loaded_weight: torch.Tensor,
+                            cast_type, transform_fn, target_shape,
+                            jax_path_template: str):
+        """Assign a regular (non-MXFP4) tensor into the model param with the transform applied."""
+        if jax_path_template == "layers.*.attn.sinks_N":
+            # Checkpoint is bf16, but we have to upcast sinks to f32, as required by RPA_v3 kernel
+            weight_np = jnp.array(loaded_weight.to(torch.float32).numpy())
+        else:
+            torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
+            if torch_view_type:
+                weight_np = jnp.array(
+                    loaded_weight.view(torch_view_type).numpy()).view(
+                        cast_type)
+            else:
+                raise ValueError(
+                    f"Unsupported dtype for tensor conversion: {cast_type}")
+
+        transformed_weight = transform_fn(
+            weight_np, target_shape) if transform_fn else weight_np
+
+        if model_weight.value.shape != transformed_weight.shape:
+            raise ValueError(
+                f"Shape mismatch: model expects {model_weight.value.shape}, but got {transformed_weight.shape} after transform."
+            )
 
-                if is_verbose:
-                    print_param_info(model_weight, loaded_name)
+        def get_slice(index):
+            return transformed_weight[index]
 
-        nnx.update(self, model_params)
+        sharded_array = jax.make_array_from_callback(
+            transformed_weight.shape,
+            NamedSharding(self.mesh, P(*model_weight.sharding)), get_slice)
+        model_weight.value = sharded_array
 
     def __call__(
         self,
diff --git a/tpu_inference/models/jax/utils/quantization/mxfp4_utils.py b/tpu_inference/models/jax/utils/quantization/mxfp4_utils.py
new file mode 100644
index 000000000..5db9a8893
--- /dev/null
+++ b/tpu_inference/models/jax/utils/quantization/mxfp4_utils.py
@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+# MXFP4 constants
+MXFP4_BLOCK_SIZE: int = 32
+# Exponent-only e8m0 scale bias used by MXFP4 scales
+MXFP4_SCALE_BIAS: int = 127
+# Name used in config.json quantization_config["quant_method"]
+MXFP4_QUANT_METHOD: str = "mxfp4"
+
+# Precompute a small LUT once; move to device on demand (cheap 16-element copy)
+FP4_LUT = torch.tensor(
+    [
+        0.0,
+        0.5,
+        1.0,
+        1.5,
+        2.0,
+        3.0,
+        4.0,
+        6.0,  # 0b0000-0b0111
+        -0.0,
+        -0.5,
+        -1.0,
+        -1.5,
+        -2.0,
+        -3.0,
+        -4.0,
+        -6.0,  # 0b1000-0b1111
+    ],
+    dtype=torch.float32)
+
+
+def unpack_mxfp4(packed: torch.Tensor) -> torch.Tensor:
+    """Unpack uint8 (..., 16) -> fp4 values (..., 32) using low->high nibble order.
+
+    Returns float32 values corresponding to FP4 codebook entries.
+    """
+    assert packed.dtype == torch.uint8
+    low = packed & 0x0F
+    high = (packed >> 4) & 0x0F
+    idx = torch.stack([low, high], dim=-1).flatten(-2)
+    lut = FP4_LUT.to(packed.device)
+    return lut[idx.long()]
+
+
+def e8m0_to_fp32(u8: torch.Tensor) -> torch.Tensor:
+    """Convert e8m0 uint8 exponents to power-of-two scales using MXFP4_SCALE_BIAS.
+
+    Uses ldexp for exact power-of-two scaling: 1.0 * 2**(u8 - bias).
+    """
+    exponents = (u8.to(torch.int32) - int(MXFP4_SCALE_BIAS)).to(torch.int32)
+    ones = torch.ones_like(u8, dtype=torch.float32)
+    return torch.ldexp(ones, exponents)
+
+
+def dequant_mxfp4_to_bf16(blocks_u8: torch.Tensor,
+                          scales_u8: torch.Tensor) -> torch.Tensor:
+    """Dequantize MXFP4 blocks/scales into bfloat16 values.
+
+    Args:
+        blocks_u8: uint8 tensor shaped [..., Kb, 16], each byte holds 2 FP4 codes.
+        scales_u8: uint8 tensor shaped [..., Kb], exponent-only e8m0 per 32-value block.
+
+    Returns:
+        torch.bfloat16 tensor with last logical dimension K = Kb * 32.
+    """
+    if blocks_u8.dtype != torch.uint8 or scales_u8.dtype != torch.uint8:
+        raise ValueError(
+            f"Expected uint8 inputs, got blocks={blocks_u8.dtype}, scales={scales_u8.dtype}"
+        )
+    # Unpack FP4 codes to float32 values [..., Kb, 32]
+    fp4_vals = unpack_mxfp4(blocks_u8)  # (..., Kb, 32)
+    # Compute power-of-two scales and apply per block
+    scales = e8m0_to_fp32(scales_u8).unsqueeze(-1)  # (..., Kb, 1)
+    full = (fp4_vals * scales).reshape(*fp4_vals.shape[:-2],
+                                       fp4_vals.shape[-2] * MXFP4_BLOCK_SIZE)
+    return full.to(torch.bfloat16)
+
+
+def unpack_mxfp4_to_fp32(
+        blocks_u8: torch.Tensor,
+        scales_u8: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """Decode MXFP4 packed blocks and e8m0 scales to float32 codes and scales.
+
+    Args:
+        blocks_u8: uint8 tensor shaped [..., Kb, 16], each byte packs two FP4 codes.
+        scales_u8: uint8 tensor shaped [..., Kb], exponent-only e8m0 per block.
+
+    Returns:
+        (codes_fp32, scales_fp32), where
+        - codes_fp32 has shape [..., Kb*32] and dtype float32
+        - scales_fp32 has shape [..., Kb] and dtype float32
+    """
+    if blocks_u8.dtype != torch.uint8 or scales_u8.dtype != torch.uint8:
+        raise ValueError(
+            f"Expected uint8 inputs, got blocks={blocks_u8.dtype}, scales={scales_u8.dtype}"
+        )
+    fp4_vals = unpack_mxfp4(blocks_u8)  # (..., Kb, 32) float32
+    codes_fp32 = fp4_vals.reshape(*fp4_vals.shape[:-2],
+                                  fp4_vals.shape[-2] * MXFP4_BLOCK_SIZE)
+    scales_fp32 = e8m0_to_fp32(scales_u8)  # (..., Kb) float32
+    return codes_fp32, scales_fp32
diff --git a/tpu_inference/models/jax/utils/quantization/quantization_utils.py b/tpu_inference/models/jax/utils/quantization/quantization_utils.py
index c1f89efb9..ea6eaec9a 100644
--- a/tpu_inference/models/jax/utils/quantization/quantization_utils.py
+++ b/tpu_inference/models/jax/utils/quantization/quantization_utils.py
@@ -71,6 +71,29 @@
     }
 }
 
+# Default Qwix config for GPT-OSS MXFP4 checkpoints.
+# Notes:
+# - We quantize only the MoE expert weights by default (the router stays in BF16).
+# - We use Qwix's abstract-model path so weights can be set directly into QArray
+#   fields during weight loading (similar to DeepSeek's flow).
+# - Activation quantization is not enabled, but Qwix would pick up the MoE sum if it were.
+DEFAULT_GPT_OSS_FP4_CONFIG = {
+    "qwix": {
+        "use_abstract_model":
+        True,
+        "scale_dtype":
+        "bfloat16",
+        "rules": [
+            {
+                "module_path": ".*custom_module",
+                "weight_qtype": "float4_e2m1fn",
+                "act_qtype": None,
+                "tile_size": 32,
+            },
+        ],
+    }
+}
+
 
 def parse_qwix_config_to_rules(
         qwix_config: List[dict]) -> List[qwix.QuantizationRule]:
@@ -400,6 +423,9 @@ def get_default_qwix_quantization_config(
         return DEFAULT_DEEPSEEK_FP8_CONFIG
     elif model_type == "llama4" and quant_method == "compressed-tensors":
         return DEFAULT_LLAMA4_FP8_CONFIG
+    # MXFP4 (GPT-OSS): provide a default configuration to quantize the MoE experts via Qwix
+    elif model_type == "gpt_oss" and quant_method == "mxfp4":
+        return DEFAULT_GPT_OSS_FP4_CONFIG
 
 
 def update_vllm_config_for_qwix_quantization(vllm_config: "VllmConfig"):
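A quick way to exercise the new MXFP4 helpers is a single-block round trip on CPU. The sketch below is not part of the patch; it only uses the functions defined in mxfp4_utils.py above, and the test values are illustrative. Every packed byte is 0x21, so the low nibble decodes to 0.5 and the high nibble to 1.0, and an e8m0 scale byte of 128 gives a per-block scale of 2**(128 - 127) = 2.0, so the dequantized block should alternate 1.0 and 2.0.

import torch

from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
    dequant_mxfp4_to_bf16, unpack_mxfp4_to_fp32)

# One block: Kb = 1, 16 packed bytes -> 32 FP4 codes.
blocks_u8 = torch.full((1, 16), 0x21, dtype=torch.uint8)
# e8m0 scale of 128 -> 2 ** (128 - 127) = 2.0 for the whole block.
scales_u8 = torch.full((1, ), 128, dtype=torch.uint8)

dequant = dequant_mxfp4_to_bf16(blocks_u8, scales_u8)  # shape (32,), bfloat16
expected = torch.tensor([1.0, 2.0] * 16)
assert torch.equal(dequant.float(), expected)

# The decoded codes times the per-block scale reproduce the same values.
codes, scales = unpack_mxfp4_to_fp32(blocks_u8, scales_u8)
assert codes.shape == (32, ) and scales.shape == (1, )
assert torch.equal(codes * scales, expected)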