From b687ad6d46b0df1e70feb2a7fb80dead4ef1ae1d Mon Sep 17 00:00:00 2001
From: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Date: Wed, 26 Nov 2025 12:28:17 -0800
Subject: [PATCH] Use the router gemm op for nemotron MOE

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
---
 .../_torch/auto_deploy/models/patches/nemotron_h.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
index 8248ab209f2..5c43d5924aa 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
@@ -93,14 +93,15 @@ def _nemotron_h_topk_router_forward(self, hidden_states):
     Forward pass for NemotronHTopkRouter using the optimized noaux_tc_op kernel.

     This replaces the original forward method which used pure PyTorch operations
-    with a fused CUDA kernel that performs:
-    1. Sigmoid activation of logits
-    2. Group-based expert selection
-    3. Top-k selection within selected groups
-    4. Normalized weight computation
+    with optimized CUDA kernels:
     """
     hidden_states = hidden_states.view(-1, self.config.hidden_size)
-    router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
+    if self.weight.dtype == torch.float32:
+        router_logits = F.linear(hidden_states.type(torch.float32), self.weight)
+    else:
+        router_logits = torch.ops.trtllm.dsv3_router_gemm_op(
+            hidden_states, self.weight.t(), bias=None, out_dtype=torch.float32
+        )

     # Use the fused noaux_tc_op kernel which applies sigmoid internally
     # and performs group-based top-k selection with normalization
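
For context, the dtype-based dispatch this patch introduces can be sketched as a standalone helper. This is a minimal illustration, not part of the patch: it assumes the TensorRT-LLM custom-op registry is loaded so that torch.ops.trtllm.dsv3_router_gemm_op (taken verbatim from the diff) is available, and the helper name route_router_logits is hypothetical.

import torch
import torch.nn.functional as F


def route_router_logits(hidden_states: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Compute float32 router logits, mirroring the dispatch added in this patch."""
    # Flatten tokens to (num_tokens, hidden_size); hidden_size is taken from the
    # router weight shape (num_experts, hidden_size) in this sketch.
    hidden_states = hidden_states.view(-1, weight.shape[-1])
    if weight.dtype == torch.float32:
        # Plain PyTorch GEMM when the router weight is already float32.
        return F.linear(hidden_states.type(torch.float32), weight)
    # Fused TensorRT-LLM router GEMM for lower-precision weights; it expects the
    # transposed weight and emits float32 logits directly.
    return torch.ops.trtllm.dsv3_router_gemm_op(
        hidden_states, weight.t(), bias=None, out_dtype=torch.float32
    )

Either branch returns float32 logits, so the downstream fused noaux_tc_op path referenced in the diff context is unaffected by which GEMM is used.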