|
84 | 84 | HAS_CUTLASS_BLACKWELL = False |
85 | 85 |
|
86 | 86 |
|
| 87 | +# [Optional] flash_fwd cute-DSL backend |
| 88 | +HAS_FLASH_CUTE = True |
| 89 | +try: |
| 90 | + from flash_attn.cute.interface import ( |
| 91 | + flash_attn_func as flash_attn_cute_func |
| 92 | + ) |
| 93 | +except (ImportError, IOError, AttributeError): |
| 94 | + HAS_FLASH_CUTE = False |
| 95 | + flash_attn_cute_func = None # Define it as None to avoid NameError |
| 96 | + |
| 97 | + |
87 | 98 | def parse_op_args(args: List[str]): |
88 | 99 | parser = argparse.ArgumentParser() |
89 | 100 | parser.add_argument("--batch", type=int, help="Batch size") |
@@ -619,6 +630,26 @@ def cutlass_blackwell_fmha_decode( |
619 | 630 | return lambda: cutlass_blackwell_fmha_func( |
620 | 631 | q, k_cache, v_cache, causal=CAUSAL, seqlen_kv=seqlen_kv |
621 | 632 | ) |
| 633 | + |
| 634 | + @register_benchmark(enabled=HAS_FLASH_CUTE) |
| 635 | + def flash_cute_dsl( |
| 636 | + self, |
| 637 | + q: torch.Tensor, |
| 638 | + k_cache: torch.Tensor, |
| 639 | + v_cache: torch.Tensor, |
| 640 | + cache_seqlens: torch.Tensor, |
| 641 | + ) -> Callable: |
| 642 | + """Flash Attention implementation using cute-DSL backend.""" |
| 643 | + # For GQA, cute-DSL handles the head expansion internally |
| 644 | + # We pass the original KV tensors without manual expansion |
| 645 | + q_heads = q.shape[2] |
| 646 | + kv_heads = k_cache.shape[2] |
| 647 | + return lambda: flash_attn_cute_func(
| 648 | + q, k_cache, v_cache, |
| 649 | + causal=CAUSAL, |
| 650 | + pack_gqa=(q_heads != kv_heads) |
| 651 | + ) |
| 652 | + |
622 | 653 | @register_benchmark(enabled=HAS_AITER) |
623 | 654 | def aiter_paged_fp8kv( |
624 | 655 | self, |
|
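A minimal standalone sketch of the new path, for local sanity-checking: the guarded import mirrors the fallback added above, and the call shows the GQA case with distinct query/KV head counts via pack_gqa. The shapes, dtype, and the presence of flash-attn's cute-DSL interface on a CUDA device are illustrative assumptions, not part of this change.

import torch

# Guarded import, mirroring the fallback in this PR.
try:
    from flash_attn.cute.interface import flash_attn_func as flash_attn_cute_func
    HAS_FLASH_CUTE = True
except (ImportError, IOError, AttributeError):
    HAS_FLASH_CUTE = False
    flash_attn_cute_func = None

if HAS_FLASH_CUTE and torch.cuda.is_available():
    # Illustrative decode-style GQA shapes: 32 query heads sharing 8 KV heads.
    batch, seqlen_q, seqlen_kv, head_dim = 2, 1, 4096, 128
    q = torch.randn(batch, seqlen_q, 32, head_dim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(batch, seqlen_kv, 8, head_dim, device="cuda", dtype=torch.bfloat16)
    v = torch.randn_like(k)
    # pack_gqa lets the kernel expand KV heads internally when head counts differ;
    # the exact return layout may vary across flash-attn versions.
    out = flash_attn_cute_func(q, k, v, causal=True, pack_gqa=(q.shape[2] != k.shape[2]))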