@@ -67,6 +67,34 @@ def _swiglu(x: Float, alpha: Float, limit: Float) -> Float:
6767 return gated_activation * (x_linear + 1 )
6868
6969
@dataclass(kw_only=True)
class CombineExperts(nnx.Module):
    """Combines per-expert outputs into a single per-token output.

    Given the dense per-expert projections, the router's selected expert
    indices, and the router weights, produces the weighted sum over each
    token's selected experts.
    """
    # Dtype the combined output is cast to before returning.
    dtype: jnp.dtype

    def __call__(self, down_proj_TED: Float, weights_TX: Float,
                 indices_TX: jax.Array) -> Float:
        """Weighted-sums the selected experts' outputs for every token.

        Args:
          down_proj_TED: Expert outputs, shape (tokens, experts, hidden_dim)
          weights_TX: Router weights, shape (tokens, experts_per_token)
          indices_TX: Selected expert indices, shape (tokens, experts_per_token)

        Returns:
          Combined output, shape (tokens, hidden_dim), cast to `self.dtype`.
        """
        with jax.named_scope("combine_experts"):
            # Gather each token's selected experts: (T, X, D).
            expanded_indices = jnp.expand_dims(indices_TX, axis=-1)
            selected_TXD = jnp.take_along_axis(
                down_proj_TED, expanded_indices, axis=1)
            # Contract over the experts-per-token (X) axis.
            combined_TD = jnp.einsum('TXD,TX -> TD', selected_TXD, weights_TX)

        return combined_TD.astype(self.dtype)
97+
7098@dataclass (kw_only = True )
7199class GptOssMoE (nnx .Module ):
72100 """
@@ -114,20 +142,16 @@ def __call__(self, x_TD: Float) -> Float:
114142 down_proj_TED += self .mlp2_bias_ED .value
115143
116144 # Weighted sum of expert outputs
117- with jax .named_scope ("sum" ):
118- indices_for_gather = indices_TX [..., None ]
119- gathered_down_proj_TED = jnp .take_along_axis (down_proj_TED ,
120- indices_for_gather ,
121- axis = 1 )
122- output_TD = jnp .einsum ('TXD,TX -> TD' , gathered_down_proj_TED ,
123- weights_TX )
145+ output_TD = self .combine_experts (down_proj_TED , weights_TX , indices_TX )
124146
125- return output_TD . astype ( self . dtype )
147+ return output_TD
126148
127149 def __post_init__ (self , rngs : nnx .Rngs ):
128150 """Initializes all weights and biases for the MoE block."""
129151 D , F , E = self .hidden_size , self .intermediate_size_moe , self .num_local_experts
130152
153+ self .combine_experts = CombineExperts (dtype = self .dtype )
154+
131155 # MLP #1 Weights (Combined Gate and Up-projection) and Bias
132156 self .mlp1_weight_EDF2 = create_param (
133157 rngs ,
0 commit comments