     QuantizationConfig, QuantizeMethodBase)
 
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
-from tpu_inference.layers.vllm.fused_moe import jax_fused_moe_func_padded
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
 from tpu_inference.layers.vllm.linear_common import (
     reorder_concatenated_tensor_for_sharding,
     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
@@ -191,8 +191,12 @@ def select_gemm_impl(
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
 
-        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+
+        if self.moe.has_bias:
+            w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+            w2_bias = t2j(layer.w2_bias, use_dlpack=False)
 
         if self.use_kernel and layer.use_ep:
             # Kernel expects:
@@ -208,25 +212,34 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
             w13_reshaped = w13_weight.reshape(num_experts, 2,
                                               intermediate_size, hidden_size)
-            w13_weight = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
+            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
 
             # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
             w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
 
             # Apply EP sharding
             w13_weight = jax.device_put(
-                w13_weight,
+                w13_weight_transposed,
                 Format(Layout((0, 1, 2, 3)),
                        NamedSharding(self.mesh, P("model", None, None, None))))
-            w2_weight_transposed = jax.device_put(
+            w2_weight = jax.device_put(
                 w2_weight_transposed,
                 Format(Layout((0, 1, 2)),
                        NamedSharding(self.mesh, P("model", None, None))))
 
-            layer.w13_weight = Parameter(torch_view(w13_weight),
-                                         requires_grad=False)
-            layer.w2_weight = Parameter(torch_view(w2_weight_transposed),
-                                        requires_grad=False)
+            if self.moe.has_bias:
+                w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+
+                # Apply EP sharding
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
+
         else:
             # Original logic for non-kernel path
             if layer.use_ep:
@@ -238,6 +251,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                     w2_weight,
                     Format(Layout((0, 1, 2)),
                            NamedSharding(self.mesh, P("model", None, None))))
+
+                if self.moe.has_bias:
+                    w13_bias = jax.device_put(
+                        w13_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P("model", None))))
+                    w2_bias = jax.device_put(
+                        w2_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P("model", None))))
+
             else:
                 intermediate_size = w13_weight.shape[1] // 2
                 assert intermediate_size == w2_weight.shape[-1]
@@ -255,11 +279,27 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                     Format(Layout((0, 1, 2)),
                            NamedSharding(self.mesh, P(None, None, "model"))))
 
+                if self.moe.has_bias:
+                    w13_bias = jax.device_put(
+                        w13_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P(None, "model"))))
+                    w2_bias = jax.device_put(
+                        w2_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P(None, None))))
+
         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
         layer.w2_weight = Parameter(torch_view(w2_weight),
                                     requires_grad=False)
 
+        if self.moe.has_bias:
+            layer.w13_bias = Parameter(torch_view(w13_bias),
+                                       requires_grad=False)
+            layer.w2_bias = Parameter(torch_view(w2_bias),
+                                      requires_grad=False)
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -290,6 +330,9 @@ def apply(
         if scoring_func != "softmax":
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
+        # TODO(kyuyeunk): Remove this check once MoE bias support has landed.
+        if self.moe.has_bias:
+            raise NotImplementedError("Bias is not currently supported.")
 
         if self.use_kernel and layer.use_ep:
             output = fused_ep_moe(
@@ -305,7 +348,7 @@ def apply(
         else:
             # Use the original implementation
             _fused_moe_func = functools.partial(
-                jax.jit(jax_fused_moe_func_padded,
+                jax.jit(fused_moe_func_padded,
                         static_argnames=[
                             "topk", "global_num_experts", "renormalize",
                             "reduce_results", "mesh", "use_ep"
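
Note on the sharding pattern used for the new bias tensors: the expert-parallel branches place each bias with the expert dimension split across the "model" mesh axis and the remaining dimension replicated (P("model", None)). Below is a minimal, standalone JAX sketch of that placement. It is illustrative only, not code from this PR: the mesh construction, shapes, and variable names are assumptions, and it omits the Format(Layout(...)) wrapper the layer code also uses to pin a device-local layout.

# Standalone sketch of the expert-parallel bias placement (assumed toy shapes).
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over all local devices, named "model" as in the diff.
mesh = Mesh(np.array(jax.devices()), ("model",))

# Toy bias of shape (num_experts, hidden_size); expert count kept divisible
# by the device count so the expert dimension shards evenly.
num_experts = 8 * jax.device_count()
hidden_size = 128
w2_bias = jnp.zeros((num_experts, hidden_size), dtype=jnp.bfloat16)

# Same placement as P("model", None) in the EP branches: dim 0 (experts) is
# split across the "model" axis, dim 1 (hidden) is replicated on every device.
w2_bias = jax.device_put(w2_bias, NamedSharding(mesh, P("model", None)))
print(w2_bias.sharding)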