submodules/aiter (2 changes: 1 addition & 1 deletion)
Submodule aiter updated 2771 files
tritonbench/operators/flash_attention/operator.py (17 changes: 17 additions & 0 deletions)
@@ -78,6 +78,9 @@

from .test_fmha_utils import make_packed_qkv

with try_import("HAS_AITER"):
from aiter.ops.triton.mha import flash_attn_func as aiter_flash_attn_func

HAS_CUDA_124 = (
torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.4"
)
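
Note on the guard above: `with try_import("HAS_AITER"):` follows tritonbench's optional-dependency pattern, where the import is attempted inside a context manager that records success or failure in a module-level flag, so the benchmark can be gated on `HAS_AITER` without hard-failing when aiter is not installed. Below is a minimal sketch of that pattern, not tritonbench's actual helper; the frame-inspection trick used to reach the caller's module globals is an illustrative assumption.

```python
# Hedged sketch of an optional-import guard in the spirit of try_import("HAS_AITER");
# the real helper's signature and flag-setting mechanism may differ.
import contextlib
import inspect


@contextlib.contextmanager
def try_import(flag_name: str):
    """Set a boolean flag in the caller's module: True if the guarded
    imports succeed, False if any of them raise ImportError."""
    # Depth 2 skips this generator frame and contextlib's __enter__ wrapper
    # to reach the module that wrote the `with` statement (an assumption).
    caller_globals = inspect.stack()[2].frame.f_globals
    try:
        yield
        caller_globals[flag_name] = True
    except ImportError:  # also covers ModuleNotFoundError
        caller_globals[flag_name] = False


# Usage mirroring the diff: HAS_AITER ends up True or False either way.
with try_import("HAS_AITER"):
    from aiter.ops.triton.mha import flash_attn_func as aiter_flash_attn_func
```
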
@@ -362,6 +365,20 @@ def sdpa_flash_attention(q, k, v):
v,
)

@register_benchmark(enabled=is_hip() and HAS_AITER)
def aiter(self, q, k, v):
def _inner():
return aiter_flash_attn_func(
q,
k,
v,
softmax_scale=self.sm_scale,
causal=self.causal,
deterministic=self.deterministic,
)

return _inner

if IS_B200:
# Only enable calling this benchmark directly.
@register_benchmark(enabled=False)
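
For context on how the harness consumes the new method: benchmarks registered with `@register_benchmark` return a zero-argument callable (`_inner` above), which the framework invokes repeatedly when measuring, and `enabled=is_hip() and HAS_AITER` is evaluated at registration time so the aiter variant only appears on ROCm builds with aiter available. The sketch below is a rough, hypothetical timing loop illustrating why a closure is returned instead of calling the kernel directly; it is not tritonbench's actual measurement code, and the warmup/rep defaults are made up.

```python
# Hypothetical sketch of timing a returned benchmark closure; tritonbench's
# real harness and its defaults will differ.
import time

import torch


def time_closure(fn, warmup: int = 25, rep: int = 100) -> float:
    """Return the mean latency of `fn()` in milliseconds."""
    for _ in range(warmup):
        fn()  # warm up caches and JIT-compile the Triton kernel
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(rep):
        fn()  # the closure re-runs the kernel on the same q, k, v
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / rep * 1e3
```
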