@@ -960,7 +960,8 @@ struct GemmProfilerBackend
}
}

void prepare(int num_tokens, char* workspace, void const* expert_weights, cudaStream_t stream);
void prepare(int num_tokens, char* workspace, void const* expert_weights, cudaStream_t stream,
void const* token_selected_experts_customized = nullptr, bool use_customized_router = false);

std::map<std::string, std::pair<size_t, size_t>> getProfilerWorkspaces(int maxM, bool is_tma_ws);
size_t getWorkspaceSize(int maxM);
@@ -990,7 +991,7 @@ struct GemmProfilerBackend
nvinfer1::DataType mOType{};

// This will be a unique value for every iteration of warmup and actual bench
constexpr static int64_t NUM_ROUTING_SAMPLES = 16;
constexpr static int64_t NUM_ROUTING_SAMPLES = 1;

constexpr static int64_t NUM_FUSION_TYPES = 2;
constexpr static int64_t NUM_SWAP_AB_TYPES = 2;
@@ -1006,8 +1007,9 @@
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType mScalingType{};

private:
void prepareRouting(int num_tokens, char* workspace, cudaStream_t stream);
void prepareQuantParams(int num_tokens, char* workspace, cudaStream_t stream);
void prepareRouting(int num_tokens, char* workspace, cudaStream_t stream,
void const* token_selected_experts_customized = nullptr, bool use_customized_router = false);
void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, bool swap_ab, cudaStream_t stream);
};
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu (19 additions & 8 deletions)
@@ -4288,7 +4288,8 @@ std::map<std::string, std::pair<size_t, size_t>> GemmProfilerBackend::getProfile
return out_map;
}

void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_char, cudaStream_t stream)
void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_char, cudaStream_t stream,
void const* token_selected_experts_customized, bool use_customized_router)
{
auto workspaces = getProfilerWorkspaces(num_tokens, mSM >= 90);
#define GET_WS_PTR_BASE(type, name) \
@@ -4329,10 +4330,19 @@ void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_cha
int const start_expert_id = mNumExpertsPerNode * mParallelismConfig.ep_rank;

uint32_t num_threads = 256;
dim3 grid_dim{(num_tokens + num_threads - 1) / num_threads, NUM_ROUTING_SAMPLES, 1};
prepareFakeRouterBuffers<<<grid_dim, num_threads, 0, stream>>>(
token_selected_experts_base, num_tokens, mK, mNumExperts);
sync_check_cuda_error(stream);
if (use_customized_router)
{
// copy token selected experts to token_selected_experts_base
cudaMemcpyAsync(token_selected_experts_base, token_selected_experts_customized,
num_tokens * mK * sizeof(int), cudaMemcpyDeviceToDevice, stream);
}
else
{
dim3 grid_dim{(num_tokens + num_threads - 1) / num_threads, NUM_ROUTING_SAMPLES, 1};
prepareFakeRouterBuffers<<<grid_dim, num_threads, 0, stream>>>(
token_selected_experts_base, num_tokens, mK, mNumExperts);
sync_check_cuda_error(stream);
}

for (int64_t i = 0; i < NUM_ROUTING_SAMPLES; i++)
{
@@ -4539,15 +4549,16 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr
}
}

void GemmProfilerBackend::prepare(
int num_tokens, char* workspace_ptr_char, void const* expert_weights, cudaStream_t stream)
void GemmProfilerBackend::prepare(int num_tokens, char* workspace_ptr_char, void const* expert_weights,
cudaStream_t stream, void const* token_selected_experts_customized, bool use_customized_router)
{
mSampleIndex = 0;

auto workspace_size = getWorkspaceSize(num_tokens);
populateRandomBuffer(workspace_ptr_char, workspace_size, stream);
auto workspaces = getProfilerWorkspaces(num_tokens, mSM >= 90);

prepareRouting(num_tokens, workspace_ptr_char, stream);
prepareRouting(num_tokens, workspace_ptr_char, stream, token_selected_experts_customized, use_customized_router);
prepareQuantParams(num_tokens, workspace_ptr_char, stream);
for (auto fusion : {TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE})
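The customized path in prepareRouting() copies num_tokens * mK * sizeof(int) bytes device-to-device, so the caller-supplied buffer is expected to hold one int expert id per (token, slot) pair and already reside on the GPU. A minimal sketch of a conforming tensor on the Python side, with illustrative sizes and names that are not part of this change:

import torch

# Illustrative sizes; the real values come from the model configuration.
num_tokens, top_k, num_experts = 128, 8, 64

# One expert id per (token, slot), int32, resident on the GPU so the
# device-to-device copy in prepareRouting() can consume it directly.
token_selected_experts = torch.randint(
    0, num_experts, (num_tokens, top_k),
    dtype=torch.int32, device="cuda")

# Byte count matches the num_tokens * mK * sizeof(int) copy above.
assert token_selected_experts.numel() * token_selected_experts.element_size() \
    == num_tokens * top_k * 4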
cpp/tensorrt_llm/thop/moeOp.cpp (11 additions & 8 deletions)
@@ -662,13 +662,13 @@ class FusedMoeRunner : public torch::CustomClassHolder
}

// TODO Update this to be able to tell if we are profiling swiglu bias
void runGemmProfile(torch::Tensor const& input, torch::Tensor const& fc1_expert_weights,
torch::optional<torch::Tensor> const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights,
torch::optional<torch::Tensor> const& fc2_expert_biases, int64_t const top_k, int64_t const tp_size,
int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx,
int64_t const profile_id, bool const do_preparation, int64_t const activation_type_int,
int64_t const unpadded_hidden_size)
void runGemmProfile(torch::Tensor const& input, torch::optional<torch::Tensor> const& token_final_scales,
torch::Tensor const& fc1_expert_weights, torch::optional<torch::Tensor> const& fc1_expert_biases,
torch::Tensor const& fc2_expert_weights, torch::optional<torch::Tensor> const& fc2_expert_biases,
int64_t const top_k, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank,
int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode,
int64_t const gemm_idx, int64_t const profile_id, bool const do_preparation, int64_t const activation_type_int,
int64_t const unpadded_hidden_size, bool const use_customized_router)
{
std::lock_guard<std::mutex> lock(mMutex);

@@ -748,7 +748,10 @@ class FusedMoeRunner : public torch::CustomClassHolder
auto const cu_malloc_status = cudaMalloc(&mProfileWorkspace, profile_workspace_size);
TORCH_CHECK(cu_malloc_status == cudaSuccess, "Can't allocate profile workspace for MoE GEMM profile.");

mProfiler->prepare(num_rows, mProfileWorkspace, expert_weights_ptr, stream);
void const* token_selected_experts_customized
= token_final_scales.has_value() ? token_final_scales.value().const_data_ptr() : nullptr;
mProfiler->prepare(num_rows, mProfileWorkspace, expert_weights_ptr, stream,
token_selected_experts_customized, use_customized_router);
}

// Profile specific tactic. Assuming at least one preparation phase has been executed already.
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (80 additions & 10 deletions)
@@ -7,6 +7,8 @@
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm import deep_gemm
from tensorrt_llm._torch.modules.fused_moe.routing import (
ROUTING_METHOD_TYPE_TO_CLASS, RoutingMethodType)
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.logger import logger

@@ -27,14 +29,65 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
torch.bmm(a, b, out=out)


def prepare_dummy_token_selected_experts_hook(
input: torch.Tensor,
top_k: int,
num_experts: int,
routing_method_type: int = int(RoutingMethodType.Default),
):
"""
Creates a hook function that generates dummy token_selected_experts for tuning.

Args:
input: Input tensor to determine shape and device
top_k: Number of experts per token
num_experts: Total number of experts
routing_method_type: Type of routing method to use

Returns:
A hook function that can be used with the tuner
"""
tuner = AutoTuner.get()
if not tuner.is_tuning_mode:
return lambda inputs: inputs

# Get routing method
routing_cls_kwargs = {}
routing_method = ROUTING_METHOD_TYPE_TO_CLASS[routing_method_type](
top_k=top_k, **routing_cls_kwargs)

def create_dummy_token_selected_experts(
inputs: List[torch.Tensor], ) -> List[torch.Tensor]:
input_tensor = inputs[0] # First tensor is the input
# Generate dummy routing logits with correct shape
routing_logits_for_tuner = torch.randn(input_tensor.shape[0],
num_experts,
dtype=torch.bfloat16,
device=input_tensor.device)

# Apply routing to get properly shaped token_selected_experts
topk_ids_for_tuner, topk_weights_for_tuner = routing_method.apply(
routing_logits_for_tuner)

# Replace the token_selected_experts tensor (inputs[1]) with our generated one
if len(inputs) > 1:
inputs[1] = topk_ids_for_tuner

return inputs

return create_dummy_token_selected_experts


class MoERunner(TunableRunner):
# avoid overhead of creating a new runner in forward pass
runner_dict = dict()
tuning_config = TuningConfig(
dynamic_tensor_specs=(DynamicTensorSpec(
0, 0, get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2), ),
constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ),
tune_max_num_tokens=8192,
inputs_pre_hook=None, # Will be set dynamically in fused_moe function
)

def __init__(
@@ -125,10 +178,13 @@ def forward(
gemm_idx: int = 0,
tactic: int = -1,
do_preparation: bool = False,
**kwargs,
):
x, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
x, token_selected_experts, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
use_customized_router = True
self.fused_moe_runner.run_gemm_profile(
x,
token_selected_experts,
fc1_expert_weights,
fc1_expert_biases,
fc2_expert_weights,
Expand All @@ -147,6 +203,7 @@ def forward(
do_preparation,
self.activation_type,
self.unpadded_hidden_size,
use_customized_router,
)


@@ -183,6 +240,7 @@ def fused_moe(
tuner_num_tokens: Optional[int] = None,
tuner_top_k: Optional[int] = None,
activation_type: int = int(ActivationType.Swiglu),
routing_method_type: int = int(RoutingMethodType.Default),
unpadded_hidden_size: Optional[int] = None,
out_tensor: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
@@ -200,6 +258,15 @@
tuner_input = input
tuner_top_k = token_selected_experts.size(1)

tuning_config = MoERunner.tuning_config
tuning_config.inputs_pre_hook = prepare_dummy_token_selected_experts_hook(
tuner_input,
tuner_top_k,
fc1_expert_weights.shape[0] *
ep_size, # num_experts from weight tensor shape
routing_method_type,
)

# allocate workspace for profiling
moe_runner = MoERunner(
x_dtype=input.dtype,
@@ -223,27 +290,30 @@
)

MoERunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens

input_tensors = [
tuner_input,
token_selected_experts,
fc1_expert_weights,
fc1_expert_biases,
fc2_expert_weights,
fc2_expert_biases,
]
_, gemm_tactic_1 = tuner.choose_one(
"trtllm::fused_moe::gemm1",
[moe_runner],
MoERunner.tuning_config,
[
tuner_input, fc1_expert_weights, fc1_expert_biases,
fc2_expert_weights, fc2_expert_biases
],
input_tensors,
gemm_idx=1,
ep_size=ep_size,
)

_, gemm_tactic_2 = tuner.choose_one(
"trtllm::fused_moe::gemm2",
[moe_runner],
MoERunner.tuning_config,
[
tuner_input, fc1_expert_weights, fc1_expert_biases,
fc2_expert_weights, fc2_expert_biases
],
input_tensors,
gemm_idx=2,
ep_size=ep_size,
)

run_moe = moe_runner.fused_moe_runner.run_moe_min_latency if min_latency_mode else moe_runner.fused_moe_runner.run_moe
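The inputs_pre_hook installed above only runs in tuning mode and swaps inputs[1] (token_selected_experts) for ids produced by the configured routing method. A self-contained sketch of the same mechanics, using a plain torch.topk router as a stand-in for the ROUTING_METHOD_TYPE_TO_CLASS entry; the function and variable names below are illustrative, not part of this change:

from typing import List
import torch

def make_dummy_experts_hook(top_k: int, num_experts: int):
    # Stand-in for prepare_dummy_token_selected_experts_hook: a plain top-k
    # over random logits takes the place of the real routing method class.
    def pre_hook(inputs: List[torch.Tensor]) -> List[torch.Tensor]:
        x = inputs[0]  # first tensor is the activation input
        logits = torch.randn(x.shape[0], num_experts,
                             dtype=torch.float32, device=x.device)
        topk_ids = torch.topk(logits, top_k, dim=-1).indices.to(torch.int32)
        if len(inputs) > 1:
            inputs[1] = topk_ids  # replace token_selected_experts for profiling
        return inputs
    return pre_hook

# Usage sketch: install on the tuning config before choose_one() is called.
# MoERunner.tuning_config.inputs_pre_hook = make_dummy_experts_hook(8, 64)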
@@ -639,6 +639,7 @@ def forward_chunk(
tuner_num_tokens=tuner_num_tokens,
tuner_top_k=tuner_top_k,
activation_type=self.activation_type,
routing_method_type=self.routing_method.routing_method_type,
unpadded_hidden_size=self.unpadded_hidden_size,
out_tensor=moe_output,
)