Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions vllm_ascend/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@

from vllm.model_executor.models.qwen3_next import ( # isort: skip
Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM,
Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock,
fused_gdn_gating)
Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock)

from vllm_ascend.ops.fla import fused_sigmoid_gating_delta_rule_update
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This import will cause a runtime error because fused_sigmoid_gating_delta_rule_update is not defined in vllm_ascend.ops.fla. Furthermore, this function is not used anywhere in the file. This line should be removed.

from vllm_ascend.ops.fused_gdn_gating import fused_gdn_gating

class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):

Expand Down
67 changes: 67 additions & 0 deletions vllm_ascend/ops/fused_gdn_gating.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import torch

Check failure on line 2 in vllm_ascend/ops/fused_gdn_gating.py

View workflow job for this annotation

GitHub Actions / lint / pre-commit

Skipping analyzing "triton": module is installed, but missing library stubs or py.typed marker [import-untyped]

Check failure on line 2 in vllm_ascend/ops/fused_gdn_gating.py

View workflow job for this annotation

GitHub Actions / lint / pre-commit

Skipping analyzing "triton": module is installed, but missing library stubs or py.typed marker [import-untyped]

Check failure on line 2 in vllm_ascend/ops/fused_gdn_gating.py

View workflow job for this annotation

GitHub Actions / lint / pre-commit

Skipping analyzing "triton": module is installed, but missing library stubs or py.typed marker [import-untyped]

Check failure on line 2 in vllm_ascend/ops/fused_gdn_gating.py

View workflow job for this annotation

GitHub Actions / lint / pre-commit

Skipping analyzing "triton": module is installed, but missing library stubs or py.typed marker [import-untyped]

Check failure on line 2 in vllm_ascend/ops/fused_gdn_gating.py

View workflow job for this annotation

GitHub Actions / lint / pre-commit

Skipping analyzing "triton": module is installed, but missing library stubs or py.typed marker [import-untyped]
import triton

Check failure on line 3 in vllm_ascend/ops/fused_gdn_gating.py

View workflow job for this annotation

GitHub Actions / lint / pre-commit

Skipping analyzing "triton.language": module is installed, but missing library stubs or py.typed marker [import-untyped]

Check failure on line 3 in vllm_ascend/ops/fused_gdn_gating.py

View workflow job for this annotation

GitHub Actions / lint / pre-commit

Skipping analyzing "triton.language": module is installed, but missing library stubs or py.typed marker [import-untyped]

Check failure on line 3 in vllm_ascend/ops/fused_gdn_gating.py

View workflow job for this annotation

GitHub Actions / lint / pre-commit

Skipping analyzing "triton.language": module is installed, but missing library stubs or py.typed marker [import-untyped]

Check failure on line 3 in vllm_ascend/ops/fused_gdn_gating.py

View workflow job for this annotation

GitHub Actions / lint / pre-commit

Skipping analyzing "triton.language": module is installed, but missing library stubs or py.typed marker [import-untyped]

Check failure on line 3 in vllm_ascend/ops/fused_gdn_gating.py

View workflow job for this annotation

GitHub Actions / lint / pre-commit

Skipping analyzing "triton.language": module is installed, but missing library stubs or py.typed marker [import-untyped]
import triton.language as tl

# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
@triton.jit
def fused_gdn_gating_kernel(
    g,
    A_log,
    a,
    dt_bias,
    seq_len,
    NUM_HEADS: tl.constexpr,
    NUM_BATCHES: tl.constexpr,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
    BLK_BATCHES: tl.constexpr,
):
    # Computes g = -exp(A_log) * softplus(a + dt_bias) for one
    # (BLK_BATCHES x BLK_HEADS) tile per program. Program axis 0 walks batch
    # groups, axis 1 the sequence position, axis 2 head groups.
    pid_batch = tl.program_id(0)
    pid_seq = tl.program_id(1)
    pid_head = tl.program_id(2)

    # Row / head indices covered by this program.
    rows = pid_batch * BLK_BATCHES + tl.arange(0, BLK_BATCHES)
    heads = pid_head * BLK_HEADS + tl.arange(0, BLK_HEADS)

    # Guard the ragged tail along both the head and the batch axis.
    heads_ok = heads < NUM_HEADS
    tile_ok = heads_ok[None, :] & (rows[:, None] < NUM_BATCHES)

    # Flat element offsets into `a` / `g`: the formula treats both buffers as
    # densely packed (batch, seq_len, NUM_HEADS) with heads varying fastest.
    tile_off = (rows[:, None] * seq_len * NUM_HEADS + pid_seq * NUM_HEADS +
                heads[None, :])

    # A_log and dt_bias are indexed by head only; `a` is indexed per tile.
    a_log_vals = tl.load(A_log + heads, mask=heads_ok)
    a_vals = tl.load(a + tile_off, mask=tile_ok)
    bias_vals = tl.load(dt_bias + heads, mask=heads_ok)

    # Upcast to fp32 before the math: per the original note, with a model
    # loaded in fp16 and no float conversion here, A might become -inf.
    pre_act = a_vals.to(tl.float32) + bias_vals.to(tl.float32)[None, :]

    # Thresholded softplus: above `threshold` (in units of beta * x) the
    # input itself is used instead of log(1 + exp(.)).
    scaled = beta * pre_act
    soft = (1 / beta) * tl.log(1 + tl.exp(scaled))
    softplus_val = tl.where(scaled <= threshold, soft, pre_act)

    gate = -tl.exp(a_log_vals.to(tl.float32)) * softplus_val
    # Cast to the element type of the output buffer on the way out.
    tl.store(g + tile_off, gate.to(g.dtype.element_ty), mask=tile_ok)


# Host-side launcher for fused_gdn_gating_kernel, which evaluates
# g = -exp(A_log) * softplus(a + dt_bias) (softplus parameterised by `beta`
# and `threshold`). Allocates the float32 output `g` and picks the launch grid.
def fused_gdn_gating(
A_log: torch.Tensor,
a: torch.Tensor,
dt_bias: torch.Tensor,
beta: float = 1.0,
threshold: float = 20.0,
) -> torch.Tensor:
# `a` must be 2-D: its shape is unpacked into exactly (batch, num_heads).
batch, num_heads = a.shape
# Sequence length is fixed to 1 here, so grid axis 1 always has size 1.
seq_len = 1
# Default tiling: one batch row per program along grid axis 0.
NUM_BATCH_GROUPS = batch
BLK_BATCHES = 1
# For batch > 40, give each program a power-of-two number of rows so that
# grid axis 0 needs at most 32 programs.
if batch > 40:
BLK_BATCHES = triton.next_power_of_2(triton.cdiv(batch, 32))
NUM_BATCH_GROUPS = triton.cdiv(batch, BLK_BATCHES)

# Grid axis 2 tiles the heads in groups of 8; this must stay equal to the
# positional `8` passed below, which lands on the kernel's BLK_HEADS.
grid = (NUM_BATCH_GROUPS, seq_len, triton.cdiv(num_heads, 8))
# Output is always float32, whatever the dtype of `a`.
# NOTE(review): the kernel addresses `a` and `g` with dense row-major offsets
# and takes no strides; confirm callers only pass a contiguous `a`, since
# empty_like is given no explicit memory format here.
g = torch.empty_like(a, dtype=torch.float32)
fused_gdn_gating_kernel[grid](g,
A_log,
a,
dt_bias,
seq_len,
num_heads,
batch,
beta,
threshold,
8,
BLK_BATCHES=BLK_BATCHES,
num_warps=1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using num_warps=1 is likely to be suboptimal for performance. Triton kernels, especially memory-bound ones, benefit from having multiple warps to hide memory latency (memory-level parallelism). While this kernel has some compute-intensive operations (exp, log), it still involves significant data movement. A single warp might underutilize the hardware's execution units and memory bandwidth.

I recommend increasing num_warps (e.g., to 4 or 8) and benchmarking to find the optimal value for your target hardware.

Suggested change
num_warps=1)
num_warps=4)

return g
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The Triton kernel fused_gdn_gating_kernel implicitly assumes that the input tensor a and output tensor g are contiguous, as it performs pointer arithmetic without considering strides. However, a could be non-contiguous depending on the operations that produced it. Using torch.empty_like(a) for g will preserve the memory layout of a, so g could also be non-contiguous. This can lead to incorrect memory access and wrong results.

To fix this, you should ensure a is contiguous before using it and create g as a new contiguous tensor.

def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    """Launch ``fused_gdn_gating_kernel`` and return the gating tensor.

    The kernel evaluates ``g = -exp(A_log) * softplus(a + dt_bias)``, with
    the softplus parameterised by ``beta`` and ``threshold``.

    Args:
        A_log: Tensor read by the kernel at per-head offsets.
        a: 2-D tensor of shape ``(batch, num_heads)``; a contiguous copy is
            taken if it is not already contiguous.
        dt_bias: Tensor read by the kernel at per-head offsets.
        beta: Softplus scale factor.
        threshold: Value of ``beta * x`` above which the kernel uses ``x``
            directly instead of ``log(1 + exp(.))``.

    Returns:
        A new float32 tensor of shape ``(batch, num_heads)`` on ``a``'s
        device.
    """
    # The kernel indexes with dense row-major offsets and takes no strides,
    # so both the input and the freshly allocated output must be contiguous.
    a = a.contiguous()
    n_rows, n_heads = a.shape

    seq_len = 1  # single position; grid axis 1 therefore has size 1
    heads_per_program = 8  # head-tile width, forwarded as BLK_HEADS

    # One row per program for small batches; beyond 40 rows, hand each
    # program a power-of-two chunk so grid axis 0 needs at most 32 programs.
    rows_per_program = (
        triton.next_power_of_2(triton.cdiv(n_rows, 32)) if n_rows > 40 else 1
    )
    row_groups = triton.cdiv(n_rows, rows_per_program)

    launch_grid = (row_groups, seq_len, triton.cdiv(n_heads, heads_per_program))
    g = torch.empty((n_rows, n_heads), dtype=torch.float32, device=a.device)

    fused_gdn_gating_kernel[launch_grid](
        g,
        A_log,
        a,
        dt_bias,
        seq_len,
        n_heads,
        n_rows,
        beta,
        threshold,
        BLK_HEADS=heads_per_program,
        BLK_BATCHES=rows_per_program,
        num_warps=1,
    )
    return g

Loading