Commit acd8694

[https://nvbugs/5680133][fix] Implement customizable router for cutlass MoE.
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
1 parent 85b4c92 commit acd8694

9 files changed: +101 -39 lines changed

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h (5 additions, 3 deletions)

@@ -960,7 +960,8 @@ struct GemmProfilerBackend
         }
     }

-    void prepare(int num_tokens, char* workspace, void const* expert_weights, cudaStream_t stream);
+    void prepare(int num_tokens, char* workspace, void const* expert_weights, cudaStream_t stream,
+        void const* token_selected_experts_customized = nullptr, bool use_customized_router = false);

     std::map<std::string, std::pair<size_t, size_t>> getProfilerWorkspaces(int maxM, bool is_tma_ws);
     size_t getWorkspaceSize(int maxM);

@@ -990,7 +991,7 @@ struct GemmProfilerBackend
     nvinfer1::DataType mOType{};

     // This will be a unique value for every iteration of warmup and actual bench
-    constexpr static int64_t NUM_ROUTING_SAMPLES = 16;
+    constexpr static int64_t NUM_ROUTING_SAMPLES = 1;

     constexpr static int64_t NUM_FUSION_TYPES = 2;
     constexpr static int64_t NUM_SWAP_AB_TYPES = 2;

@@ -1006,8 +1007,9 @@ struct GemmProfilerBackend
     TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType mScalingType{};

private:
-    void prepareRouting(int num_tokens, char* workspace, cudaStream_t stream);
     void prepareQuantParams(int num_tokens, char* workspace, cudaStream_t stream);
+    void prepareRouting(int num_tokens, char* workspace, cudaStream_t stream,
+        void const* token_selected_experts_customized = nullptr, bool use_customized_router = false);
     void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights,
         TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, bool swap_ab, cudaStream_t stream);
};

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu (19 additions, 8 deletions)

@@ -4288,7 +4288,8 @@ std::map<std::string, std::pair<size_t, size_t>> GemmProfilerBackend::getProfile
     return out_map;
}

-void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_char, cudaStream_t stream)
+void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_char, cudaStream_t stream,
+    void const* token_selected_experts_customized, bool use_customized_router)
{
    auto workspaces = getProfilerWorkspaces(num_tokens, mSM >= 90);
#define GET_WS_PTR_BASE(type, name) \

@@ -4329,10 +4330,19 @@ void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_cha
    int const start_expert_id = mNumExpertsPerNode * mParallelismConfig.ep_rank;

    uint32_t num_threads = 256;
-    dim3 grid_dim{(num_tokens + num_threads - 1) / num_threads, NUM_ROUTING_SAMPLES, 1};
-    prepareFakeRouterBuffers<<<grid_dim, num_threads, 0, stream>>>(
-        token_selected_experts_base, num_tokens, mK, mNumExperts);
-    sync_check_cuda_error(stream);
+    if (use_customized_router)
+    {
+        // copy token selected experts to token_selected_experts_base
+        cudaMemcpyAsync(token_selected_experts_base, token_selected_experts_customized,
+            num_tokens * mK * sizeof(int), cudaMemcpyDeviceToDevice, stream);
+    }
+    else
+    {
+        dim3 grid_dim{(num_tokens + num_threads - 1) / num_threads, NUM_ROUTING_SAMPLES, 1};
+        prepareFakeRouterBuffers<<<grid_dim, num_threads, 0, stream>>>(
+            token_selected_experts_base, num_tokens, mK, mNumExperts);
+        sync_check_cuda_error(stream);
+    }

    for (int64_t i = 0; i < NUM_ROUTING_SAMPLES; i++)
    {

@@ -4539,15 +4549,16 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr
    }
}

-void GemmProfilerBackend::prepare(
-    int num_tokens, char* workspace_ptr_char, void const* expert_weights, cudaStream_t stream)
+void GemmProfilerBackend::prepare(int num_tokens, char* workspace_ptr_char, void const* expert_weights,
+    cudaStream_t stream, void const* token_selected_experts_customized, bool use_customized_router)
{
    mSampleIndex = 0;

    auto workspace_size = getWorkspaceSize(num_tokens);
    populateRandomBuffer(workspace_ptr_char, workspace_size, stream);
+    auto workspaces = getProfilerWorkspaces(num_tokens, mSM >= 90);

-    prepareRouting(num_tokens, workspace_ptr_char, stream);
+    prepareRouting(num_tokens, workspace_ptr_char, stream, token_selected_experts_customized, use_customized_router);
    prepareQuantParams(num_tokens, workspace_ptr_char, stream);
    for (auto fusion : {TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE,
             TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE})
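
Note on the customized path: prepareRouting now copies num_tokens * mK * sizeof(int) bytes straight from the caller-supplied buffer, so the routing tensor handed to the profiler is expected to be a device int32 tensor of shape [num_tokens, top_k]. A minimal sketch of building such a buffer (names and sizes are illustrative, not from the diff, and a CUDA device is assumed):

import torch

num_tokens, top_k, num_experts = 128, 8, 64
# One row of top_k distinct expert ids per token, laid out contiguously.
token_selected_experts = torch.stack(
    [torch.randperm(num_experts)[:top_k] for _ in range(num_tokens)]
).to(dtype=torch.int32, device="cuda")
assert token_selected_experts.shape == (num_tokens, top_k)
assert token_selected_experts.dtype == torch.int32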

cpp/tensorrt_llm/thop/moeOp.cpp (11 additions, 8 deletions)

@@ -660,13 +660,13 @@ class FusedMoeRunner : public torch::CustomClassHolder
    }

    // TODO Update this to be able to tell if we are profiling swiglu bias
-    void runGemmProfile(torch::Tensor const& input, torch::Tensor const& fc1_expert_weights,
-        torch::optional<torch::Tensor> const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights,
-        torch::optional<torch::Tensor> const& fc2_expert_biases, int64_t const top_k, int64_t const tp_size,
-        int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
-        int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx,
-        int64_t const profile_id, bool const do_preparation, int64_t const activation_type_int,
-        int64_t const unpadded_hidden_size)
+    void runGemmProfile(torch::Tensor const& input, torch::optional<torch::Tensor> const& token_final_scales,
+        torch::Tensor const& fc1_expert_weights, torch::optional<torch::Tensor> const& fc1_expert_biases,
+        torch::Tensor const& fc2_expert_weights, torch::optional<torch::Tensor> const& fc2_expert_biases,
+        int64_t const top_k, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank,
+        int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode,
+        int64_t const gemm_idx, int64_t const profile_id, bool const do_preparation, int64_t const activation_type_int,
+        int64_t const unpadded_hidden_size, bool const use_customized_router)
    {
        std::lock_guard<std::mutex> lock(mMutex);

@@ -746,7 +746,10 @@ class FusedMoeRunner : public torch::CustomClassHolder
            auto const cu_malloc_status = cudaMalloc(&mProfileWorkspace, profile_workspace_size);
            TORCH_CHECK(cu_malloc_status == cudaSuccess, "Can't allocate profile workspace for MoE GEMM profile.");

-            mProfiler->prepare(num_rows, mProfileWorkspace, expert_weights_ptr, stream);
+            void const* token_selected_experts_customized
+                = token_final_scales.has_value() ? token_final_scales.value().const_data_ptr() : nullptr;
+            mProfiler->prepare(num_rows, mProfileWorkspace, expert_weights_ptr, stream,
+                token_selected_experts_customized, use_customized_router);
        }

        // Profile specific tactic. Assuming at least one preparation phase has been executed already.

tensorrt_llm/_torch/autotuner.py (2 additions, 1 deletion)

@@ -798,7 +798,8 @@ def _profile_runners(
        best_runner_id, best_tactic = None, None
        # If the inputs_pre_hook is provided, it will be called before profiling.
        if tuning_config.inputs_pre_hook is not None:
-            input_tensors = tuning_config.inputs_pre_hook(input_tensors)
+            input_tensors = tuning_config.inputs_pre_hook(
+                input_tensors, **kwargs)
        for runner_id, runner in enumerate(runners):
            # TODO: use FakeTensor here.
            runner_arg_names = {
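
Because the autotuner now forwards its extra keyword arguments (for example ep_size in the cutlass MoE path below) into inputs_pre_hook, hook implementations must accept **kwargs. A minimal sketch of a compatible hook, assuming nothing beyond what this call site shows:

from typing import List
import torch

def example_inputs_pre_hook(inputs: List[torch.Tensor], **kwargs) -> List[torch.Tensor]:
    # kwargs may carry tuner-level extras such as ep_size; ignore what you don't need.
    return inputs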

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py (6 additions, 6 deletions)

@@ -440,8 +440,8 @@ def generate_permuted_idx_to_expanded_idx(
                permuted_idx_to_expanded_idx.append(self.pad_val)
        return permuted_idx_to_expanded_idx

-    def inputs_pre_hook(self,
-                        inputs: List[torch.Tensor]) -> List[torch.Tensor]:
+    def inputs_pre_hook(self, inputs: List[torch.Tensor],
+                        **kwargs) -> List[torch.Tensor]:
        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others = inputs
        num_tokens = self.infer_num_tokens(a.size(0))
        num_tokens_per_expert = self.generate_num_tokens_per_expert(

@@ -465,8 +465,8 @@ def inputs_pre_hook(self,
            device=num_non_exiting_tiles.device)
        return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others

-    def inputs_pre_hook_finalize_fusion(
-            self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
+    def inputs_pre_hook_finalize_fusion(self, inputs: List[torch.Tensor],
+                                        **kwargs) -> List[torch.Tensor]:
        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs
        num_tokens = self.infer_num_tokens(a.size(0))
        num_tokens_per_expert = self.generate_num_tokens_per_expert(

@@ -1414,8 +1414,8 @@ def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
    def infer_shape_num_tokens(self, input_shapes: List[torch.Size]) -> int:
        return input_shapes[0][0]

-    def inputs_pre_hook(self,
-                        inputs: List[torch.Tensor]) -> List[torch.Tensor]:
+    def inputs_pre_hook(self, inputs: List[torch.Tensor],
+                        **kwargs) -> List[torch.Tensor]:
        x, x_sf, token_selected_experts, token_final_scales, *others = inputs
        num_tokens = token_selected_experts.size(0)
        new_token_final_scales, new_token_selected_experts = torch.randn(

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (31 additions, 10 deletions)

@@ -14,6 +14,7 @@
 from ..modules.multi_stream_utils import do_multi_stream
 from ..modules.swiglu import silu_and_mul_kernel
 from ..utils import (ActivationType, fp4_scale_infer_shape,
+                     gen_balanced_moe_routing_input,
                      get_last_power_of_2_num_tokens_buckets,
                      last_positive_power_of_2)

@@ -24,6 +25,18 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
    torch.bmm(a, b, out=out)


+def inputs_pre_hook(inputs: List[torch.Tensor], ep_size: int,
+                    **kwargs) -> List[torch.Tensor]:
+    x, token_selected_experts, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
+    num_tokens = x.shape[0]
+    num_experts = fc2_expert_weights.shape[0] * ep_size
+    top_k = token_selected_experts.shape[1]
+    router = gen_balanced_moe_routing_input(num_tokens, num_experts, top_k)
+    inputs[1] = router.to(dtype=torch.int32,
+                          device=token_selected_experts.device)
+    return inputs
+
+
class MoERunner(TunableRunner):
    # avoid overhead of creating a new runner in forward pass
    runner_dict = dict()

@@ -32,6 +45,7 @@ class MoERunner(TunableRunner):
            0, 0, get_last_power_of_2_num_tokens_buckets,
            last_positive_power_of_2), ),
        tune_max_num_tokens=8192,
+        inputs_pre_hook=inputs_pre_hook,
    )

    def __init__(

@@ -99,10 +113,13 @@ def forward(
        gemm_idx: int = 0,
        tactic: int = -1,
        do_preparation: bool = False,
+        **kwargs,
    ):
-        x, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
+        x, token_selected_experts, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
+        use_customized_router = True
        self.fused_moe_runner.run_gemm_profile(
            x,
+            token_selected_experts,
            fc1_expert_weights,
            fc1_expert_biases,
            fc2_expert_weights,

@@ -121,6 +138,7 @@ def forward(
            do_preparation,
            self.activation_type,
            self.unpadded_hidden_size,
+            use_customized_router,
        )


@@ -197,27 +215,30 @@ def fused_moe(
    )

    MoERunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
-
+    input_tensors = [
+        tuner_input,
+        token_selected_experts,
+        fc1_expert_weights,
+        fc1_expert_biases,
+        fc2_expert_weights,
+        fc2_expert_biases,
+    ]
    _, gemm_tactic_1 = tuner.choose_one(
        "trtllm::fused_moe::gemm1",
        [moe_runner],
        MoERunner.tuning_config,
-        [
-            tuner_input, fc1_expert_weights, fc1_expert_biases,
-            fc2_expert_weights, fc2_expert_biases
-        ],
+        input_tensors,
        gemm_idx=1,
+        ep_size=ep_size,
    )

    _, gemm_tactic_2 = tuner.choose_one(
        "trtllm::fused_moe::gemm2",
        [moe_runner],
        MoERunner.tuning_config,
-        [
-            tuner_input, fc1_expert_weights, fc1_expert_biases,
-            fc2_expert_weights, fc2_expert_biases
-        ],
+        input_tensors,
        gemm_idx=2,
+        ep_size=ep_size,
    )

    run_moe = moe_runner.fused_moe_runner.run_moe_min_latency if min_latency_mode else moe_runner.fused_moe_runner.run_moe
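
The tuner input list now carries token_selected_experts, and the module-level inputs_pre_hook overwrites it with balanced routing before profiling. A rough standalone usage sketch, assuming the package imports in your environment; the tensor shapes are made up (only x.shape[0], token_selected_experts.shape[1] and fc2_expert_weights.shape[0] matter to the hook) and the biases are passed as None:

import torch
from tensorrt_llm._torch.custom_ops.torch_custom_ops import inputs_pre_hook

num_tokens, hidden, top_k, num_experts = 16, 64, 2, 8
inputs = [
    torch.randn(num_tokens, hidden),                            # x
    torch.zeros(num_tokens, top_k, dtype=torch.int32),          # token_selected_experts
    torch.randn(num_experts, 2 * hidden, hidden),               # fc1_expert_weights
    None,                                                       # fc1_expert_biases
    torch.randn(num_experts, hidden, hidden),                   # fc2_expert_weights
    None,                                                       # fc2_expert_biases
]
inputs = inputs_pre_hook(inputs, ep_size=1)
# inputs[1] now holds a balanced [num_tokens, top_k] int32 routing tensor.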

tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py (2 additions, 2 deletions)

@@ -79,8 +79,8 @@ def prepare_dummy_topk_and_hook(
        routing_logits_for_tuner = routing_logits

    # Define hook to recreate dummy tensors when shape changes during profiling
-    def recreate_dummy_topk_if_needed(
-            inputs: List[torch.Tensor]) -> List[torch.Tensor]:
+    def recreate_dummy_topk_if_needed(inputs: List[torch.Tensor],
+                                      **kwargs) -> List[torch.Tensor]:
        """Recreate dummy topk tensors if token count changed during profiling."""
        current_num_tokens = inputs[hidden_states_index].shape[0]

tensorrt_llm/_torch/utils.py (24 additions, 0 deletions)

@@ -370,3 +370,27 @@ def wrapper(*args, **kwargs):
        return wrapper

    return decorator(func) if func else decorator
+
+
+def gen_balanced_moe_routing_input(num_tokens: int, num_experts: int,
+                                   top_k: int) -> torch.Tensor:
+    """
+    Generate balanced routing input for MoE routing.
+    """
+    token_selected_experts_gen = torch.zeros(num_tokens, top_k)
+    # select k unique experts from num_experts for each token
+    for i in range(num_tokens):
+        token_selected_experts_gen[i] = torch.randperm(num_experts)[:top_k]
+    return token_selected_experts_gen
+
+
+def gen_imbalanced_moe_routing_input(num_tokens: int, num_experts: int,
+                                     top_k: int) -> torch.Tensor:
+    """
+    Generate imbalanced routing input for MoE routing.
+    """
+    token_selected_experts_gen = torch.zeros(num_tokens, top_k)
+    # every token selects the same first top_k experts
+    for i in range(num_tokens):
+        token_selected_experts_gen[i] = torch.arange(0, top_k)
+    return token_selected_experts_gen
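
For illustration, a small sketch contrasting the two new generators (assuming the package is importable): the balanced variant spreads routing slots roughly evenly across experts, while the imbalanced variant sends every token to the first top_k experts.

import torch
from tensorrt_llm._torch.utils import (gen_balanced_moe_routing_input,
                                       gen_imbalanced_moe_routing_input)

num_tokens, num_experts, top_k = 1024, 16, 2
balanced = gen_balanced_moe_routing_input(num_tokens, num_experts, top_k)
imbalanced = gen_imbalanced_moe_routing_input(num_tokens, num_experts, top_k)

# Balanced: each expert gets roughly num_tokens * top_k / num_experts slots.
print(torch.bincount(balanced.flatten().long(), minlength=num_experts))
# Imbalanced: all slots land on experts [0, top_k); the rest stay empty.
print(torch.bincount(imbalanced.flatten().long(), minlength=num_experts))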

tests/unittest/_torch/misc/test_autotuner.py (1 addition, 1 deletion)

@@ -363,7 +363,7 @@ def forward(
        return [gemm_0, gemm_1, gemm_fallback][tactic_id](*inputs)

    @staticmethod
-    def inputs_pre_hook(inputs: List[torch.Tensor]):
+    def inputs_pre_hook(inputs: List[torch.Tensor], **kwargs):
        # always set the first element to be iota in x
        x, w = inputs
        x_hooked = torch.zeros_like(x)
