 from ..modules.multi_stream_utils import do_multi_stream
 from ..modules.swiglu import silu_and_mul_kernel
 from ..utils import (ActivationType, fp4_scale_infer_shape,
+                     gen_balanced_moe_routing_input,
                      get_last_power_of_2_num_tokens_buckets,
                      last_positive_power_of_2)

@@ -24,6 +25,18 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
     torch.bmm(a, b, out=out)


+def inputs_pre_hook(inputs: List[torch.Tensor], ep_size: int,
+                    **kwargs) -> List[torch.Tensor]:
+    x, token_selected_experts, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
+    num_tokens = x.shape[0]
+    num_experts = fc2_expert_weights.shape[0] * ep_size
+    top_k = token_selected_experts.shape[1]
+    router = gen_balanced_moe_routing_input(num_tokens, num_experts, top_k)
+    inputs[1] = router.to(dtype=torch.int32,
+                          device=token_selected_experts.device)
+    return inputs
+
+
 class MoERunner(TunableRunner):
     # avoid overhead of creating a new runner in forward pass
     runner_dict = dict()
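Note on the pre-hook above: it swaps the captured `token_selected_experts` (inputs[1]) for a synthetic, load-balanced routing, so each GEMM tactic is profiled under an even per-expert token count rather than whatever skewed routing the capture happened to record. For illustration only, a minimal stand-in for what `gen_balanced_moe_routing_input` is assumed to return, a [num_tokens, top_k] tensor of expert indices with near-uniform expert load (the real helper lives in `..utils` and may differ):

import torch

def balanced_routing_sketch(num_tokens: int, num_experts: int,
                            top_k: int) -> torch.Tensor:
    # Hypothetical stand-in: round-robin expert assignment so every expert
    # receives roughly num_tokens * top_k / num_experts token slots.
    flat = torch.arange(num_tokens * top_k, dtype=torch.int64) % num_experts
    return flat.reshape(num_tokens, top_k)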
@@ -32,6 +45,7 @@ class MoERunner(TunableRunner):
             0, 0, get_last_power_of_2_num_tokens_buckets,
             last_positive_power_of_2), ),
         tune_max_num_tokens=8192,
+        inputs_pre_hook=inputs_pre_hook,
     )

     def __init__(
@@ -99,10 +113,13 @@ def forward(
         gemm_idx: int = 0,
         tactic: int = -1,
         do_preparation: bool = False,
+        **kwargs,
     ):
-        x, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
+        x, token_selected_experts, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
+        use_customized_router = True
         self.fused_moe_runner.run_gemm_profile(
             x,
+            token_selected_experts,
             fc1_expert_weights,
             fc1_expert_biases,
             fc2_expert_weights,
@@ -121,6 +138,7 @@ def forward(
             do_preparation,
             self.activation_type,
             self.unpadded_hidden_size,
+            use_customized_router,
         )

@@ -197,27 +215,30 @@ def fused_moe(
     )

     MoERunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
-
+    input_tensors = [
+        tuner_input,
+        token_selected_experts,
+        fc1_expert_weights,
+        fc1_expert_biases,
+        fc2_expert_weights,
+        fc2_expert_biases,
+    ]
     _, gemm_tactic_1 = tuner.choose_one(
         "trtllm::fused_moe::gemm1",
         [moe_runner],
         MoERunner.tuning_config,
-        [
-            tuner_input, fc1_expert_weights, fc1_expert_biases,
-            fc2_expert_weights, fc2_expert_biases
-        ],
+        input_tensors,
         gemm_idx=1,
+        ep_size=ep_size,
     )

     _, gemm_tactic_2 = tuner.choose_one(
         "trtllm::fused_moe::gemm2",
         [moe_runner],
         MoERunner.tuning_config,
-        [
-            tuner_input, fc1_expert_weights, fc1_expert_biases,
-            fc2_expert_weights, fc2_expert_biases
-        ],
+        input_tensors,
         gemm_idx=2,
+        ep_size=ep_size,
     )

     run_moe = moe_runner.fused_moe_runner.run_moe_min_latency if min_latency_mode else moe_runner.fused_moe_runner.run_moe
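Passing `ep_size=ep_size` into `choose_one` is presumably what delivers the value to both `inputs_pre_hook` and `MoERunner.forward`, since both accept it via their keyword arguments. A rough sketch of that flow under that assumption (simplified illustration, not the actual AutoTuner internals):

def profile_candidate(runner, tuning_config, inputs, gemm_idx, tactic, **kwargs):
    # If the config declares a pre-hook, let it rewrite the inputs first
    # (here: replace inputs[1] with balanced routing, using ep_size).
    if getattr(tuning_config, "inputs_pre_hook", None) is not None:
        inputs = tuning_config.inputs_pre_hook(inputs, **kwargs)
    # The same extra kwargs reach the runner's forward via **kwargs.
    return runner(inputs, gemm_idx=gemm_idx, tactic=tactic, **kwargs)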