From e0747719ce1195d744cdf30ca5801eaee6adfa9a Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Fri, 31 Oct 2025 20:58:10 -0700
Subject: [PATCH 1/2] update aiter

---
 submodules/aiter                          |  2 +-
 .../operators/flash_attention/operator.py | 18 ++++++++++++++++++
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/submodules/aiter b/submodules/aiter
index 5000ee01f..afdfdf4b1 160000
--- a/submodules/aiter
+++ b/submodules/aiter
@@ -1 +1 @@
-Subproject commit 5000ee01f1f6770378752b63fef2ea232025e793
+Subproject commit afdfdf4b1aebf5e68206028f90d9ec83a4a17fb9
diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index 34b4428dc..88723c7ae 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -78,6 +78,9 @@
 
 from .test_fmha_utils import make_packed_qkv
 
+with try_import("HAS_AITER"):
+    from aiter.ops.triton.mha import flash_attn_func as aiter_flash_attn_func
+
 HAS_CUDA_124 = (
     torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.4"
 )
@@ -362,6 +365,21 @@ def sdpa_flash_attention(q, k, v):
             v,
         )
 
+
+    @register_benchmark(enabled=is_hip() and HAS_AITER)
+    def aiter(self, q, k, v):
+        def _inner():
+            return aiter_flash_attn_func(
+                q,
+                k,
+                v,
+                softmax_scale=self.sm_scale,
+                causal=self.causal,
+                deterministic=self.deterministic,
+            )
+        return _inner
+
+
     if IS_B200:
         # Only enable calling this benchmark directly.
         @register_benchmark(enabled=False)

From 1324d106fe092144a58ca30f0529f6a6cffe852f Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Fri, 31 Oct 2025 21:19:43 -0700
Subject: [PATCH 2/2] fix lint

---
 tritonbench/operators/flash_attention/operator.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index 88723c7ae..cc5ea49a1 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -365,7 +365,6 @@ def sdpa_flash_attention(q, k, v):
             v,
         )
 
-
     @register_benchmark(enabled=is_hip() and HAS_AITER)
     def aiter(self, q, k, v):
        def _inner():
             return aiter_flash_attn_func(
                 q,
                 k,
                 v,
                 softmax_scale=self.sm_scale,
                 causal=self.causal,
                 deterministic=self.deterministic,
             )
-        return _inner
 
+        return _inner
 
     if IS_B200:
         # Only enable calling this benchmark directly.
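
For context, the call that the new `aiter` benchmark wraps can be exercised standalone roughly as below. This is a minimal sketch, not part of the patches: only the import path and the keyword arguments (`softmax_scale`, `causal`, `deterministic`) are taken from the diff; the ROCm/aiter environment, the fp16 dtype, and the (batch, seqlen, nheads, head_dim) layout are assumptions for illustration. The benchmark itself is gated on `is_hip() and HAS_AITER`, so it only registers on AMD GPUs where the aiter submodule is installed.

    # Hypothetical standalone use of the aiter Triton flash-attention kernel.
    # Assumes a ROCm build of PyTorch with the aiter package importable.
    import torch
    from aiter.ops.triton.mha import flash_attn_func as aiter_flash_attn_func

    # Assumed problem size and layout: (batch, seqlen, nheads, head_dim), fp16.
    B, S, H, D = 4, 2048, 32, 128
    q, k, v = (
        torch.randn(B, S, H, D, device="cuda", dtype=torch.float16) for _ in range(3)
    )

    # Keyword arguments mirror what the benchmark passes (self.sm_scale,
    # self.causal, self.deterministic); 1/sqrt(head_dim) is the usual scale.
    out = aiter_flash_attn_func(
        q,
        k,
        v,
        softmax_scale=D**-0.5,
        causal=True,
        deterministic=False,
    )
    print(out.shape)

Within tritonbench the same path would normally be driven through the operator harness (something like `python run.py --op flash_attention --only aiter`; flag names are from memory of the harness and worth double-checking).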