
Commit 28d5884

Aya Ibrahim authored and facebook-github-bot committed
FAv4 CuteDSL Bench for decode
Summary: Using headq = 8 performs much better, possibly because headq = 5 does not work with the TMA_q used here. Differential Revision: D80830933
1 parent cb6e5ff · commit 28d5884

File tree

1 file changed (+31 −0 lines)


tritonbench/operators/decoding_attention/operator.py

Lines changed: 31 additions & 0 deletions
@@ -84,6 +84,17 @@
     HAS_CUTLASS_BLACKWELL = False


+# [Optional] flash_fwd cute-DSL backend
+HAS_FLASH_CUTE = True
+try:
+    from flash_attn.cute.interface import (
+        flash_attn_func as flash_attn_cute_func,
+    )
+except (ImportError, IOError, AttributeError):
+    HAS_FLASH_CUTE = False
+    flash_attn_cute_func = None  # Define it as None to avoid NameError
+
+
 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
     parser.add_argument("--batch", type=int, help="Batch size")
@@ -619,6 +630,26 @@ def cutlass_blackwell_fmha_decode(
         return lambda: cutlass_blackwell_fmha_func(
             q, k_cache, v_cache, causal=CAUSAL, seqlen_kv=seqlen_kv
         )
+
+    @register_benchmark(enabled=HAS_FLASH_CUTE)
+    def flash_cute_dsl(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+        """Flash Attention implementation using cute-DSL backend."""
+        # For GQA, cute-DSL handles the head expansion internally;
+        # we pass the original KV tensors without manual expansion.
+        q_heads = q.shape[2]
+        kv_heads = k_cache.shape[2]
+        return lambda: flash_attn_cute_func(
+            q, k_cache, v_cache,
+            causal=CAUSAL,
+            pack_gqa=(q_heads != kv_heads),
+        )
+
     @register_benchmark(enabled=HAS_AITER)
     def aiter_paged_fp8kv(
         self,
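
As a usage reference, here is a minimal standalone sketch of the guarded import and the call the new benchmark's lambda makes. It assumes a (batch, seqlen, heads, head_dim) layout, consistent with q.shape[2] holding the head count in the diff above; the tensor sizes, dtype, device, causal=True, and treating the return value as the attention output are illustrative assumptions, not the benchmark's actual configuration.

import torch

# Guarded import, mirroring the diff: fall back cleanly when the cute-DSL
# backend of flash-attn is not installed.
try:
    from flash_attn.cute.interface import flash_attn_func as flash_attn_cute_func
    HAS_FLASH_CUTE = True
except (ImportError, IOError, AttributeError):
    HAS_FLASH_CUTE = False
    flash_attn_cute_func = None

if HAS_FLASH_CUTE:
    # Decode-shaped inputs: one new query token per sequence attending over a
    # pre-filled KV cache. Sizes, dtype, and device are illustrative only.
    batch, kv_len, head_dim = 4, 4096, 128
    q_heads, kv_heads = 8, 1  # GQA: 8 query heads share each KV head (headq = 8)
    q = torch.randn(batch, 1, q_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    k_cache = torch.randn(batch, kv_len, kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    v_cache = torch.randn(batch, kv_len, kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)

    # Same call the benchmark's lambda makes: pack_gqa is enabled only when the
    # query and KV head counts differ.
    out = flash_attn_cute_func(
        q, k_cache, v_cache,
        causal=True,
        pack_gqa=(q_heads != kv_heads),
    )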
