
Conversation


@OsirisDuan OsirisDuan commented Nov 20, 2025

What this PR does / why we need it?

Does this PR introduce any user-facing change?

How was this patch tested?

Signed-off-by: Ascendyh <hw7osiris@outlook.com>
@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment
Code Review

This pull request introduces a fused Triton kernel for GDN gating to optimize performance on Ascend hardware. The overall approach is sound, but I've identified a few critical issues that need to be addressed. There is a breaking import in qwen3_next.py that will cause a runtime error. More importantly, the new Triton kernel in fused_gdn_gating.py does not correctly handle non-contiguous tensors, which could lead to incorrect calculations. I've provided detailed comments and code suggestions to resolve these problems. I also included a suggestion to investigate a performance tuning parameter that appears suboptimal.

fused_gdn_gating)
Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock)

from vllm_ascend.ops.fla import fused_sigmoid_gating_delta_rule_update
Contributor

critical

This import will cause a runtime error because fused_sigmoid_gating_delta_rule_update is not defined in vllm_ascend.ops.fla. Furthermore, this function is not used anywhere in the file. This line should be removed.

Comment on lines 38 to 67
def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    batch, num_heads = a.shape
    seq_len = 1
    NUM_BATCH_GROUPS = batch
    BLK_BATCHES = 1
    if batch > 40:
        BLK_BATCHES = triton.next_power_of_2(triton.cdiv(batch, 32))
        NUM_BATCH_GROUPS = triton.cdiv(batch, BLK_BATCHES)

    grid = (NUM_BATCH_GROUPS, seq_len, triton.cdiv(num_heads, 8))
    g = torch.empty_like(a, dtype=torch.float32)
    fused_gdn_gating_kernel[grid](g,
                                  A_log,
                                  a,
                                  dt_bias,
                                  seq_len,
                                  num_heads,
                                  batch,
                                  beta,
                                  threshold,
                                  8,
                                  BLK_BATCHES=BLK_BATCHES,
                                  num_warps=1)
    return g
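The batch-grouping arithmetic in this wrapper can be traced in plain Python. The sketch below is illustrative only; `cdiv` and `next_power_of_2` stand in for the Triton helpers of the same names, and `grid_for` is a hypothetical name mirroring the wrapper's grid computation:

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, as triton.cdiv computes it.
    return -(-a // b)

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, as triton.next_power_of_2 computes it.
    return 1 << (n - 1).bit_length()

def grid_for(batch: int, num_heads: int) -> tuple:
    # Mirrors the wrapper: small batches launch one program per row;
    # batches above 40 are grouped into power-of-two blocks of rows.
    num_batch_groups, blk_batches = batch, 1
    if batch > 40:
        blk_batches = next_power_of_2(cdiv(batch, 32))
        num_batch_groups = cdiv(batch, blk_batches)
    # seq_len is fixed to 1; heads are tiled in groups of 8.
    return (num_batch_groups, 1, cdiv(num_heads, 8))

print(grid_for(16, 32))   # -> (16, 1, 4)
print(grid_for(100, 32))  # -> (25, 1, 4): BLK_BATCHES = 4, 25 row groups
```

So the launch grid stays small even for large batches, at the cost of each program iterating over `BLK_BATCHES` rows.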
Contributor

critical

The Triton kernel fused_gdn_gating_kernel implicitly assumes that the input tensor a and output tensor g are contiguous, as it performs pointer arithmetic without considering strides. However, a could be non-contiguous depending on the operations that produced it. Using torch.empty_like(a) for g will preserve the memory layout of a, so g could also be non-contiguous. This can lead to incorrect memory access and wrong results.

To fix this, you should ensure a is contiguous before using it and create g as a new contiguous tensor.

def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    a = a.contiguous()
    batch, num_heads = a.shape
    seq_len = 1
    NUM_BATCH_GROUPS = batch
    BLK_BATCHES = 1
    if batch > 40:
        BLK_BATCHES = triton.next_power_of_2(triton.cdiv(batch, 32))
        NUM_BATCH_GROUPS = triton.cdiv(batch, BLK_BATCHES)

    grid = (NUM_BATCH_GROUPS, seq_len, triton.cdiv(num_heads, 8))
    g = torch.empty((batch, num_heads), dtype=torch.float32, device=a.device)
    fused_gdn_gating_kernel[grid](g,
                                  A_log,
                                  a,
                                  dt_bias,
                                  seq_len,
                                  num_heads,
                                  batch,
                                  beta,
                                  threshold,
                                  8,
                                  BLK_BATCHES=BLK_BATCHES,
                                  num_warps=1)
    return g
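Assuming the kernel computes the usual GDN gate, g = -exp(A_log) * softplus(a + dt_bias), as in the upstream Qwen3-Next implementation (an assumption here, since the kernel body is not shown), a plain PyTorch reference like the sketch below is useful for unit-testing the Triton version, including against non-contiguous inputs:

```python
import torch
import torch.nn.functional as F

def fused_gdn_gating_ref(A_log: torch.Tensor,
                         a: torch.Tensor,
                         dt_bias: torch.Tensor,
                         beta: float = 1.0,
                         threshold: float = 20.0) -> torch.Tensor:
    # Reference gate: g = -exp(A_log) * softplus(a + dt_bias), in float32.
    # F.softplus's beta/threshold arguments match the kernel parameters:
    # above `threshold`, softplus falls back to the linear function for
    # numerical stability.
    x = a.float() + dt_bias.float()
    return -torch.exp(A_log.float()) * F.softplus(x, beta=beta,
                                                  threshold=threshold)
```

A test could then compare `fused_gdn_gating(A_log, a, dt_bias)` against `fused_gdn_gating_ref(A_log, a, dt_bias)` with `torch.testing.assert_close`, passing both a contiguous `a` and a transposed (non-contiguous) `a` to exercise the bug described above.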

                                  threshold,
                                  8,
                                  BLK_BATCHES=BLK_BATCHES,
                                  num_warps=1)
Contributor

high

Using num_warps=1 is likely to be suboptimal for performance. Triton kernels, especially memory-bound ones, benefit from having multiple warps to hide memory latency (memory-level parallelism). While this kernel has some compute-intensive operations (exp, log), it still involves significant data movement. A single warp might underutilize the hardware's execution units and memory bandwidth.

I recommend increasing num_warps (e.g., to 4 or 8) and benchmarking to find the optimal value for your target hardware.

Suggested change:
-                                  num_warps=1)
+                                  num_warps=4)

@OsirisDuan OsirisDuan changed the title Add fused gdn gating [task] Add fused gdn gating triton kernel Nov 20, 2025