diff --git a/submodules/aiter b/submodules/aiter
index 5000ee01f..afdfdf4b1 160000
--- a/submodules/aiter
+++ b/submodules/aiter
@@ -1 +1 @@
-Subproject commit 5000ee01f1f6770378752b63fef2ea232025e793
+Subproject commit afdfdf4b1aebf5e68206028f90d9ec83a4a17fb9
diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index 34b4428dc..cc5ea49a1 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -78,6 +78,9 @@
 from .test_fmha_utils import make_packed_qkv
 
+with try_import("HAS_AITER"):
+    from aiter.ops.triton.mha import flash_attn_func as aiter_flash_attn_func
+
 HAS_CUDA_124 = (
     torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.4"
 )
 
@@ -362,6 +365,20 @@ def sdpa_flash_attention(q, k, v):
             v,
         )
 
+    @register_benchmark(enabled=is_hip() and HAS_AITER)
+    def aiter(self, q, k, v):
+        def _inner():
+            return aiter_flash_attn_func(
+                q,
+                k,
+                v,
+                softmax_scale=self.sm_scale,
+                causal=self.causal,
+                deterministic=self.deterministic,
+            )
+
+        return _inner
+
     if IS_B200:
         # Only enable calling this benchmark directly.
         @register_benchmark(enabled=False)