Conversation

@jvlunteren (Contributor) commented Nov 7, 2025

Purpose

This pull request depends on PR #27993.

It adapts the 3D Triton attention kernel, which is used exclusively for decode operations, to support full CUDA Graphs. The key changes include:

  • The allocation of the intermediate data structures used for the tiled softmax implementation has been moved to the attention metadata builder class.

  • The dynamic selection between the 2D and 3D attention kernels during decode is now based on comparing the batch size against a threshold that corresponds to one of the CUDA Graph capture sizes. This ensures that, for each batch size, exactly one valid kernel choice (either the 2D or the 3D kernel) exists and is captured correctly.

The updated kernel selection logic is applied only when CUDA Graphs are enabled for decoding. This is detected automatically in the attention metadata builder, which sets the threshold values accordingly.
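As an illustration, the decode-time dispatch this enables could look like the following minimal sketch; the function name and the direction of the comparison (3D kernel for batch sizes at or below the threshold) are assumptions for illustration, not code from this PR.

def select_decode_kernel(batch_size: int, seq_threshold_3D: int) -> str:
    """Pick exactly one kernel per batch size, so a CUDA Graph captured for
    that size always replays the same code path.

    Assumption: the 3D (tiled-softmax) kernel serves decode batches at or
    below the threshold, and the 2D kernel serves everything above it.
    """
    return "3D" if batch_size <= seq_threshold_3D else "2D"

Because seq_threshold_3D is snapped to one of the CUDA Graph capture sizes (see the review snippet further down), every captured batch size falls unambiguously on one side of this comparison.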

Test Plan

The unit test ./tests/kernels/attention/test_triton_unified_attention.py has been updated to allocate the intermediate data structures required for the tiled softmax implementation and to explicitly exercise the 2D and 3D attention kernels separately during decoding. Other tests, such as lm_eval, remain compatible and can be used without modification.

Test Result

Unit test results for the updated Triton unified attention kernel (this PR):

python3 -m pytest tests/kernels/attention/test_triton_unified_attention.py

================================================ 512 passed in 71.05s (0:01:11) ================================================


lm_eval results for the updated Triton unified attention kernel (this PR):

VLLM_ATTENTION_BACKEND=TRITON_ATTN lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.800|±  |0.0179|
|     |       |strict-match    |     5|exact_match|↑  |0.788|±  |0.0183|

These results are similar to those obtained with the FlashAttention backend:

lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.790|±  |0.0182|
|     |       |strict-match    |     5|exact_match|↑  |0.768|±  |0.0189|

@tdoublep @bringlein

jvlunteren and others added 7 commits November 3, 2025 11:28
@gemini-code-assist bot left a comment


Code Review

This pull request adapts the 3D Triton attention kernel for CUDA graph compatibility, which is a valuable improvement for performance. The core changes involve moving intermediate tensor allocations out of the kernel and aligning the 2D/3D kernel selection threshold with CUDA graph capture sizes. The refactoring of the Triton kernels for decode-only operations is clean and logical.

My review identifies one critical issue where an empty cudagraph_capture_sizes list would cause a ValueError and crash the server. I've provided a suggestion to handle this edge case gracefully. Otherwise, the changes are well-implemented and align with the stated goals.

Comment on lines +113 to +129
        if self.decode_cudagraph_enabled:
            # Select the CUDA Graph capture size closest to self.seq_threshold_3D
            # as threshold. This ensures that each captured graph covers the
            # correct execution path.
            upd_seq_threshold_3D = min(
                self.vllm_config.compilation_config.cudagraph_capture_sizes,
                key=lambda x: abs(x - self.seq_threshold_3D),
            )

            # If the updated threshold becomes significantly larger than the
            # initial value, it is reset to zero. This enforces the use of the
            # 2D kernel only and ensures that the size of the allocated
            # intermediate structures remains bounded.
            if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D:
                self.seq_threshold_3D = upd_seq_threshold_3D
            else:
                self.seq_threshold_3D = 0

Severity: critical

The code does not handle the case where self.vllm_config.compilation_config.cudagraph_capture_sizes is an empty list. If it is empty, min() will raise a ValueError, causing a server crash on startup. This can happen if a user configures CUDA graphs but provides an empty list for cudagraph_capture_sizes.

We should add a check to ensure cudagraph_capture_sizes is not empty before calling min(). If it is empty, we should fall back to a safe default, such as setting self.seq_threshold_3D = 0 to always use the 2D kernel, which is a safe choice for CUDA graph compatibility.

        if self.decode_cudagraph_enabled:
            capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
            if not capture_sizes:
                # If no CUDA graph capture sizes are specified, we cannot
                # guarantee a static kernel choice. Forcing the 2D kernel
                # is the safest option.
                self.seq_threshold_3D = 0
            else:
                # Select the CUDA Graph capture size closest to self.seq_threshold_3D
                # as threshold. This ensures that each captured graph covers the
                # correct execution path.
                upd_seq_threshold_3D = min(
                    capture_sizes,
                    key=lambda x: abs(x - self.seq_threshold_3D),
                )

                # If the updated threshold becomes significantly larger than the
                # initial value, it is reset to zero. This enforces the use of the
                # 2D kernel only and ensures that the size of the allocated
                # intermediate structures remains bounded.
                if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D:
                    self.seq_threshold_3D = upd_seq_threshold_3D
                else:
                    self.seq_threshold_3D = 0

@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +131 to +138
        self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
        self.softmax_segm_output = torch.empty(
            (
                self.seq_threshold_3D,
                self.num_heads_q,
                self.num_par_softmax_segments,
                self.headdim,
            ),


P1: Preallocate softmax buffers with padded head dimension

The 3D decode kernel writes per-head outputs with a width of triton.next_power_of_2(head_size) (see HEAD_SIZE_PADDED usages in kernel_unified_attention_3d and reduce_segments). However, the metadata builder now preallocates softmax_segm_output with the unpadded self.headdim (lines 131‑137). For models whose head size is not already a power of two (e.g., 80 or 96), the kernel will index past the end of these buffers, corrupting memory or producing incorrect results whenever the 3D path is selected. The buffers should match the padded size used by the kernels.
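As a hedged sketch of the fix this comment asks for, the buffer could be allocated with the padded head dimension the kernels index with; the helper name, dtype, and device argument below are illustrative assumptions, not the PR's actual code.

import torch
import triton


def alloc_softmax_segm_output(
    seq_threshold_3D: int,
    num_heads_q: int,
    num_par_softmax_segments: int,
    head_size: int,
    device: str = "cuda",
) -> torch.Tensor:
    # Pad the last dimension to the next power of two (e.g. 80 -> 128,
    # 96 -> 128) so it matches the HEAD_SIZE_PADDED width the kernels write.
    head_size_padded = triton.next_power_of_2(head_size)
    return torch.empty(
        (seq_threshold_3D, num_heads_q, num_par_softmax_segments, head_size_padded),
        dtype=torch.float32,  # assumption: segment outputs are accumulated in fp32
        device=device,
    )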


@jvlunteren changed the title from "Jvl triton attn upd2" to "[Kernel] Support CUDA Graphs in 3D Triton Attention Kernel" on Nov 7, 2025