diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index b095c79dc954..906c94b86a1b 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -1365,8 +1365,10 @@ def fused_gdn_gating_kernel( blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) # compute beta_output = sigmoid(b) - blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32))) - tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask) + blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) + tl.store( + beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask + ) def fused_gdn_gating( @@ -1387,7 +1389,7 @@ def fused_gdn_gating( seq_len = 1 grid = (batch, seq_len, triton.cdiv(num_heads, 8)) g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) - beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device) + beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device) fused_gdn_gating_kernel[grid]( g, beta_output,