
Commit b557388

[GPT-OSS] Separate expert sum to not quantize with Qwix
Signed-off-by: Jordan Dotzel <amishacorns@users.noreply.github.com>
1 parent 59029c9 commit b557388

File tree

1 file changed: +32 -8 lines changed

tpu_inference/layers/jax/moe/gpt_oss_moe.py

Lines changed: 32 additions & 8 deletions
@@ -67,6 +67,34 @@ def _swiglu(x: Float, alpha: Float, limit: Float) -> Float:
     return gated_activation * (x_linear + 1)
 
 
+@dataclass(kw_only=True)
+class CombineExperts(nnx.Module):
+    """Module for combining expert outputs with weighted sum."""
+    dtype: jnp.dtype
+
+    def __call__(self, down_proj_TED: Float, weights_TX: Float,
+                 indices_TX: jax.Array) -> Float:
+        """Combines expert outputs using weighted sum.
+
+        Args:
+            down_proj_TED: Expert outputs, shape (tokens, experts, hidden_dim)
+            weights_TX: Router weights, shape (tokens, experts_per_token)
+            indices_TX: Selected expert indices, shape (tokens, experts_per_token)
+
+        Returns:
+            Combined output, shape (tokens, hidden_dim)
+        """
+        with jax.named_scope("combine_experts"):
+            indices_for_gather = indices_TX[..., None]
+            gathered_down_proj_TED = jnp.take_along_axis(down_proj_TED,
+                                                         indices_for_gather,
+                                                         axis=1)
+            output_TD = jnp.einsum('TXD,TX -> TD', gathered_down_proj_TED,
+                                   weights_TX)
+
+        return output_TD.astype(self.dtype)
+
+
 @dataclass(kw_only=True)
 class GptOssMoE(nnx.Module):
     """
@@ -114,20 +142,16 @@ def __call__(self, x_TD: Float) -> Float:
         down_proj_TED += self.mlp2_bias_ED.value
 
         # Weighted sum of expert outputs
-        with jax.named_scope("sum"):
-            indices_for_gather = indices_TX[..., None]
-            gathered_down_proj_TED = jnp.take_along_axis(down_proj_TED,
-                                                         indices_for_gather,
-                                                         axis=1)
-            output_TD = jnp.einsum('TXD,TX -> TD', gathered_down_proj_TED,
-                                   weights_TX)
+        output_TD = self.combine_experts(down_proj_TED, weights_TX, indices_TX)
 
-        return output_TD.astype(self.dtype)
+        return output_TD
 
     def __post_init__(self, rngs: nnx.Rngs):
         """Initializes all weights and biases for the MoE block."""
         D, F, E = self.hidden_size, self.intermediate_size_moe, self.num_local_experts
 
+        self.combine_experts = CombineExperts(dtype=self.dtype)
+
         # MLP #1 Weights (Combined Gate and Up-projection) and Bias
         self.mlp1_weight_EDF2 = create_param(
             rngs,
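
Per the commit title, the point of the refactor appears to be quantization scoping: with the weighted sum factored into its own CombineExperts module, Qwix quantization applied to GptOssMoE's expert projections presumably no longer sweeps up the final gather-and-sum einsum, which can stay in the activation dtype. A side effect visible in the diff is that the cast to self.dtype now happens inside CombineExperts, so GptOssMoE.__call__ returns output_TD directly.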
