[Kernel] Support CUDA Graphs in 3D Triton Attention Kernel #28306
Conversation
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Code Review
This pull request adapts the 3D Triton attention kernel for CUDA graph compatibility, which is a valuable improvement for performance. The core changes involve moving intermediate tensor allocations out of the kernel and aligning the 2D/3D kernel selection threshold with CUDA graph capture sizes. The refactoring of the Triton kernels for decode-only operations is clean and logical.
My review identifies one critical issue where an empty cudagraph_capture_sizes list would cause a ValueError and crash the server. I've provided a suggestion to handle this edge case gracefully. Otherwise, the changes are well-implemented and align with the stated goals.
if self.decode_cudagraph_enabled:
    # Select the CUDA Graph capture size closest to self.seq_threshold_3D
    # as threshold. This ensures that each captured graph covers the
    # correct execution path.
    upd_seq_threshold_3D = min(
        self.vllm_config.compilation_config.cudagraph_capture_sizes,
        key=lambda x: abs(x - self.seq_threshold_3D),
    )

    # If the updated threshold becomes significantly larger than the
    # initial value, it is reset to zero. This enforces the use of the
    # 2D kernel only and ensures that the size of the allocated
    # intermediate structures remains bounded.
    if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D:
        self.seq_threshold_3D = upd_seq_threshold_3D
    else:
        self.seq_threshold_3D = 0
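To make the snapping concrete, here is a minimal standalone sketch with hypothetical capture sizes and an assumed initial threshold; the real values come from vLLM's compilation config:

capture_sizes = [1, 2, 4, 8, 16, 32, 64, 128]  # hypothetical example sizes
seq_threshold_3D = 20                          # hypothetical initial threshold

# Snap the threshold to the nearest capture size so no captured graph
# straddles the 2D/3D kernel boundary.
snapped = min(capture_sizes, key=lambda x: abs(x - seq_threshold_3D))
assert snapped == 16  # |16 - 20| = 4 beats |32 - 20| = 12

# Bound the preallocated intermediate buffers: fall back to the 2D-only
# path (threshold 0) if snapping inflated the threshold more than 4x.
seq_threshold_3D = snapped if snapped <= 4 * seq_threshold_3D else 0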
The code does not handle the case where self.vllm_config.compilation_config.cudagraph_capture_sizes is an empty list. If it is empty, min() will raise a ValueError, causing a server crash on startup. This can happen if a user configures CUDA graphs but provides an empty list for cudagraph_capture_sizes.
We should add a check to ensure cudagraph_capture_sizes is not empty before calling min(). If it is empty, we should fall back to a safe default, such as setting self.seq_threshold_3D = 0 to always use the 2D kernel, which is a safe choice for CUDA graph compatibility.
if self.decode_cudagraph_enabled:
    capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
    if not capture_sizes:
        # If no CUDA graph capture sizes are specified, we cannot
        # guarantee a static kernel choice. Forcing the 2D kernel
        # is the safest option.
        self.seq_threshold_3D = 0
    else:
        # Select the CUDA Graph capture size closest to self.seq_threshold_3D
        # as threshold. This ensures that each captured graph covers the
        # correct execution path.
        upd_seq_threshold_3D = min(
            capture_sizes,
            key=lambda x: abs(x - self.seq_threshold_3D),
        )
        # If the updated threshold becomes significantly larger than the
        # initial value, it is reset to zero. This enforces the use of the
        # 2D kernel only and ensures that the size of the allocated
        # intermediate structures remains bounded.
        if upd_seq_threshold_3D <= 4 * self.seq_threshold_3D:
            self.seq_threshold_3D = upd_seq_threshold_3D
        else:
            self.seq_threshold_3D = 0
💡 Codex Review
Here are some automated review suggestions for this pull request.
self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
self.softmax_segm_output = torch.empty(
    (
        self.seq_threshold_3D,
        self.num_heads_q,
        self.num_par_softmax_segments,
        self.headdim,
    ),
Preallocate softmax buffers with padded head dimension
The 3D decode kernel writes per-head outputs with a width of triton.next_power_of_2(head_size) (see HEAD_SIZE_PADDED usages in kernel_unified_attention_3d and reduce_segments). However, the metadata builder now preallocates softmax_segm_output with the unpadded self.headdim (lines 131‑137). For models whose head size is not already a power of two (e.g., 80 or 96), the kernel will index past the end of these buffers, corrupting memory or producing incorrect results whenever the 3D path is selected. The buffers should match the padded size used by the kernels.
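A minimal sketch of the fix this comment suggests, assuming the buffer width should follow the kernels' HEAD_SIZE_PADDED convention; the dtype and device below are placeholders for illustration, not the PR's actual values:

import torch
import triton

# Pad the head dimension to the next power of two so the preallocated
# buffer width matches HEAD_SIZE_PADDED used by the 3D kernel and
# reduce_segments (e.g., head size 80 pads to 128).
head_size_padded = triton.next_power_of_2(self.headdim)
self.softmax_segm_output = torch.empty(
    (
        self.seq_threshold_3D,
        self.num_heads_q,
        self.num_par_softmax_segments,
        head_size_padded,  # previously self.headdim (unpadded)
    ),
    dtype=torch.float32,  # placeholder dtype; match the kernel's accumulator
    device=self.device,   # placeholder device attribute
)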
Purpose
This pull request depends on PR #27993.
It adapts the 3D Triton attention kernel, which is used exclusively for decode operations, to support full CUDA Graphs. The key changes include:
The allocation of the intermediate data structures used for the tiled softmax implementation has been moved to the attention metadata builder class.
The dynamic selection between the 2D and 3D attention kernels during decode is now based on comparing the batch size against a threshold that coincides with one of the CUDA Graph capture sizes. This ensures that, for each captured batch size, exactly one valid kernel choice (either the 2D or the 3D kernel) exists and is captured correctly (see the sketch after this list).
The updated kernel selection logic is only applied when CUDA Graphs are enabled for decoding. This is now automatically detected in the attention metadata builder, which sets the appropriate threshold values accordingly.
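For illustration, a rough sketch of the resulting dispatch rule; the function name and the direction of the comparison are assumptions made for this sketch, not code from the PR, and the actual logic lives in the attention metadata builder:

def select_decode_kernel(batch_size: int, seq_threshold_3D: int) -> str:
    # Hypothetical helper: with the threshold snapped to a CUDA Graph
    # capture size, every captured batch size maps to exactly one kernel,
    # so the choice is stable across graph capture and replay.
    if seq_threshold_3D > 0 and batch_size <= seq_threshold_3D:
        return "3D"  # tiled-softmax decode kernel
    return "2D"

# With a snapped threshold of 16 (hypothetical), capture sizes up to 16
# replay the 3D kernel and larger capture sizes replay the 2D kernel.
assert select_decode_kernel(8, 16) == "3D"
assert select_decode_kernel(32, 16) == "2D"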
Test Plan
The unit test ./tests/kernels/attention/test_triton_unified_attention.py has been updated to include the allocation of the intermediate data structures required for the tiled softmax implementation. It has also been modified to explicitly test the separate use of the 2D and 3D attention kernels during decoding. Other tests, such as lm_eval, remain compatible and can be used without modification.

Test Result
Unit test results for the updated Triton unified attention kernel (this PR):

lm_eval results for the updated Triton unified attention kernel (this PR) are similar to those obtained with FlashAttention.

@tdoublep @bringlein