
Commit 664c492

Merge branch 'main' into fix/broken-group-offloading-using-block_level
2 parents: e71d91e + d176f61


7 files changed: +173 / -45 lines changed


docs/source/en/optimization/attention_backends.md

Lines changed: 2 additions & 0 deletions
@@ -139,12 +139,14 @@ Refer to the table below for a complete list of available attention backends and
 | `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
 | `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
 | `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
+| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from kernels |
 | `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
 | `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
 | `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
 | `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
 | `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
 | `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
+| `sage_hub` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) from kernels |
 | `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
 | `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
 | `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |
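Note on usage: the two added rows document backends whose kernels are fetched from the Hugging Face Hub through the `kernels` package rather than a local installation. As a rough sketch of how a user would opt in (the model id below is a placeholder, and `set_attention_backend` is the selection helper this doc describes elsewhere), selecting them looks something like:

import torch
from diffusers import DiffusionPipeline

# Sketch only: "<your-model-id>" is a placeholder checkpoint.
pipe = DiffusionPipeline.from_pretrained("<your-model-id>", torch_dtype=torch.bfloat16).to("cuda")

# Use the Hub-provided FlashAttention-2 kernel (downloaded via `kernels` on first use).
pipe.transformer.set_attention_backend("flash_hub")

# Or the Hub-provided SageAttention kernel (INT8 QK quantized attention).
# pipe.transformer.set_attention_backend("sage_hub")

Both require `pip install kernels`, mirroring the requirement check added in attention_dispatch.py below.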

src/diffusers/models/attention_dispatch.py

Lines changed: 80 additions & 8 deletions
@@ -18,7 +18,7 @@
 import math
 from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -160,16 +160,13 @@ def wrap(func):
 # - CP with sage attention, flex, xformers, other missing backends
 # - Add support for normal and CP training with backends that don't support it yet
 
-_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
-_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
-_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
-
 
 class AttentionBackendName(str, Enum):
     # EAGER = "eager"
 
     # `flash-attn`
     FLASH = "flash"
+    FLASH_HUB = "flash_hub"
     FLASH_VARLEN = "flash_varlen"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
@@ -191,6 +188,7 @@ class AttentionBackendName(str, Enum):
 
     # `sageattention`
     SAGE = "sage"
+    SAGE_HUB = "sage_hub"
     SAGE_VARLEN = "sage_varlen"
     _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
     _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
@@ -264,7 +262,13 @@ class _HubKernelConfig:
     # TODO: temporary revision for now. Remove when merged upstream into `main`.
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
-    )
+    ),
+    AttentionBackendName.FLASH_HUB: _HubKernelConfig(
+        repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
+    ),
+    AttentionBackendName.SAGE_HUB: _HubKernelConfig(
+        repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
+    ),
 }
 
 
@@ -420,8 +424,8 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
         )
 
-    # TODO: add support Hub variant of FA3 varlen later
-    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+    # TODO: add support Hub variant of varlen later
+    elif backend in [AttentionBackendName._FLASH_3_HUB, AttentionBackendName.FLASH_HUB, AttentionBackendName.SAGE_HUB]:
         if not is_kernels_available():
             raise RuntimeError(
                 f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
@@ -1350,6 +1354,38 @@ def _flash_attention(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _flash_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
+    out = func(
+        q=query,
+        k=key,
+        v=value,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    if return_lse:
+        out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
@@ -1431,6 +1467,7 @@ def _flash_attention_3(
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_3_HUB,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
 )
 def _flash_attention_3_hub(
     query: torch.Tensor,
@@ -1444,6 +1481,9 @@ def _flash_attention_3_hub(
     return_attn_probs: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if _parallel_config:
+        raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
+
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
     out = func(
         q=query,
@@ -1938,6 +1978,38 @@ def _sage_attention(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.SAGE_HUB,
+    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _sage_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
+    if _parallel_config is None:
+        out = func(
+            q=query,
+            k=key,
+            v=value,
+            tensor_layout="NHD",
+            is_causal=is_causal,
+            sm_scale=scale,
+            return_lse=return_lse,
+        )
+    if return_lse:
+        out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.SAGE_VARLEN,
     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
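Note on the mechanism: each new `_HubKernelConfig` entry names a kernel repository on the Hub plus the attribute to pull out of it, and the registered functions then read the resolved callable from `kernel_fn`. The resolution step itself is not part of this diff; a minimal sketch of what it presumably amounts to, assuming the `kernels` package's `get_kernel` loader (the hypothetical helper name `_resolve_hub_kernel` is mine, not the library's):

from kernels import get_kernel  # installed via `pip install kernels`, per the error message above


def _resolve_hub_kernel(config):
    # Download (or reuse the cached copy of) the kernel repo pinned by the config.
    kernel_module = get_kernel(config.repo_id, revision=config.revision or "main")
    # Grab the callable named by `function_attr`, e.g. "flash_attn_func" or "sageattn".
    return getattr(kernel_module, config.function_attr)


# e.g. flash_fn = _resolve_hub_kernel(_HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB])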

src/diffusers/models/transformers/transformer_chronoedit.py

Lines changed: 10 additions & 6 deletions
@@ -67,7 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
     return key_img, value_img
 
 
-# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor
+# modified from diffusers.models.transformers.transformer_wan.WanAttnProcessor
 class WanAttnProcessor:
     _attention_backend = None
     _parallel_config = None
@@ -137,7 +137,8 @@ def apply_rotary_emb(
                 dropout_p=0.0,
                 is_causal=False,
                 backend=self._attention_backend,
-                parallel_config=self._parallel_config,
+                # Reference: https://github.com/huggingface/diffusers/pull/12660
+                parallel_config=None,
             )
             hidden_states_img = hidden_states_img.flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)
@@ -150,7 +151,8 @@ def apply_rotary_emb(
             dropout_p=0.0,
             is_causal=False,
             backend=self._attention_backend,
-            parallel_config=self._parallel_config,
+            # Reference: https://github.com/huggingface/diffusers/pull/12660
+            parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
         )
         hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
@@ -568,9 +570,11 @@ class ChronoEditTransformer3DModel(
         "blocks.0": {
             "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
         },
-        "blocks.*": {
-            "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
-        },
+        # Reference: https://github.com/huggingface/diffusers/pull/12660
+        # We need to disable the splitting of encoder_hidden_states because
+        # the image_encoder consistently generates 257 tokens for image_embed. This causes
+        # the shape of encoder_hidden_states—whose token count is always 769 (512 + 257)
+        # after concatenation—to be indivisible by the number of devices in the CP.
         "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
     }
 
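The long comment in the last hunk boils down to a divisibility constraint: context parallelism splits the sequence dimension evenly across ranks, and ChronoEdit's encoder_hidden_states always carries 512 prompt tokens plus 257 image-embed tokens. A standalone illustration (plain Python, nothing from diffusers):

# 512 prompt tokens + 257 image_embed tokens = 769 tokens after concatenation.
encoder_tokens = 512 + 257
for world_size in (2, 4, 8):
    remainder = encoder_tokens % world_size
    print(f"CP world size {world_size}: 769 % {world_size} = {remainder} -> evenly splittable: {remainder == 0}")
# 769 is odd (and prime), so no common CP world size divides it, hence the `blocks.*` split is removed.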

src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@
 """
 
 
-class BriaFiboPipeline(DiffusionPipeline):
+class BriaFiboPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     r"""
     Args:
         transformer (`BriaFiboTransformer2DModel`):
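Adding `FluxLoraLoaderMixin` to the base classes gives `BriaFiboPipeline` the standard Flux-style LoRA entry points such as `load_lora_weights` and `unload_lora_weights`. A hedged usage sketch (the checkpoint and LoRA ids below are placeholders, not anything referenced by this commit):

import torch
from diffusers import BriaFiboPipeline  # assumes the pipeline is exported at the package root

pipe = BriaFiboPipeline.from_pretrained("<bria-fibo-checkpoint>", torch_dtype=torch.bfloat16).to("cuda")

# Provided by FluxLoraLoaderMixin: load a Flux-format LoRA into the transformer.
pipe.load_lora_weights("<lora-repo-or-path>", adapter_name="style")

# ... run inference, then detach the adapter if it is no longer needed.
pipe.unload_lora_weights()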

src/diffusers/schedulers/scheduling_dpmsolver_sde.py

Lines changed: 13 additions & 2 deletions
@@ -488,9 +488,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
         t = t.reshape(sigma.shape)
         return t
 
-    # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    # Copied from diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler._convert_to_karras
     def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
-        """Constructs the noise schedule of Karras et al. (2022)."""
+        """
+        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+        Models](https://huggingface.co/papers/2206.00364).
+
+        Args:
+            in_sigmas (`torch.Tensor`):
+                The input sigma values to be converted.
+
+        Returns:
+            `torch.Tensor`:
+                The converted sigma values following the Karras noise schedule.
+        """
 
         sigma_min: float = in_sigmas[-1].item()
         sigma_max: float = in_sigmas[0].item()
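For reference, the schedule described by the expanded docstring is the rho-warped interpolation from Karras et al. (2022), with rho = 7.0 as used across the diffusers schedulers. A self-contained illustration of the construction (illustrative only, not the scheduler's exact method body):

import torch


def karras_sigmas(sigma_min: float, sigma_max: float, num_steps: int, rho: float = 7.0) -> torch.Tensor:
    # Interpolate linearly in rho-warped space, then undo the warp:
    # sigma_i = (sigma_max^(1/rho) + t_i * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    ramp = torch.linspace(0, 1, num_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho


# Example: a 10-step schedule decaying from sigma_max=14.6 down to sigma_min=0.03.
print(karras_sigmas(sigma_min=0.03, sigma_max=14.6, num_steps=10))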
