157 changes: 106 additions & 51 deletions src/diffusers/hooks/group_offloading.py
@@ -15,7 +15,7 @@
import hashlib
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from dataclasses import dataclass, replace
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Union

@@ -59,6 +59,9 @@ class GroupOffloadingConfig:
num_blocks_per_group: Optional[int] = None
offload_to_disk_path: Optional[str] = None
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
block_modules: Optional[List[str]] = None
exclude_kwargs: Optional[List[str]] = None
module_prefix: Optional[str] = ""


class ModuleGroup:
@@ -77,7 +80,7 @@ def __init__(
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
group_id: Optional[int] = None,
group_id: Optional[Union[int, str]] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -320,7 +323,21 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
self.group.stream.synchronize()

args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)

# Some autoencoder models use a feature cache that is passed through submodules
# and modified in place. `send_to_device` returns a copy of this cache object, which
# breaks the in-place updates. Use `exclude_kwargs` to mark such cache kwargs so they
# are passed through unchanged.
exclude_kwargs = self.config.exclude_kwargs or []
if exclude_kwargs:
moved_kwargs = send_to_device(
{k: v for k, v in kwargs.items() if k not in exclude_kwargs},
self.group.onload_device,
non_blocking=self.group.non_blocking,
)
kwargs.update(moved_kwargs)
else:
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)

return args, kwargs

def post_forward(self, module: torch.nn.Module, output):
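
A minimal, self-contained sketch (not taken from the diff) of why the cache kwargs are excluded, assuming accelerate's `send_to_device`: nested containers are rebuilt on the way to the device, so in-place updates made by submodules land on a copy of the cache instead of the caller's list.

```python
import torch
from accelerate.utils import send_to_device

feat_cache = []  # shared mutable state, normally appended to in place by submodules
kwargs = {"hidden_states": torch.randn(1, 4), "feat_cache": feat_cache}

moved = send_to_device(kwargs, torch.device("cpu"))
moved["feat_cache"].append(torch.zeros(1))  # mutates a rebuilt copy of the list
assert len(feat_cache) == 0                 # the caller's cache never sees the update

# Excluding the cache key (as the hook above does) keeps the original object identity.
exclude_kwargs = ["feat_cache"]
moved = send_to_device(
    {k: v for k, v in kwargs.items() if k not in exclude_kwargs}, torch.device("cpu")
)
kwargs.update(moved)
kwargs["feat_cache"].append(torch.zeros(1))  # same list object the caller holds
assert len(feat_cache) == 1
```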
@@ -453,6 +470,8 @@ def apply_group_offloading(
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
block_modules: Optional[List[str]] = None,
exclude_kwargs: Optional[List[str]] = None,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -510,6 +529,13 @@ def apply_group_offloading(
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
block_modules (`List[str]`, *optional*):
List of module names that should be treated as blocks for offloading. If provided, only these modules will
be considered for block-level offloading. If not provided, the default block detection logic will be used.
exclude_kwargs (`List[str]`, *optional*):
List of kwarg keys that should not be processed by `send_to_device`. This is useful for mutable state,
such as caching lists, that needs to maintain its object identity across forward passes. If not provided,
it will be inferred from the module's `_skip_keys` attribute if it exists.

Example:
```python
@@ -551,6 +577,12 @@ def apply_group_offloading(

_raise_error_if_accelerate_model_or_sequential_hook_present(module)

if block_modules is None:
block_modules = getattr(module, "_group_offload_block_modules", None)

if exclude_kwargs is None:
exclude_kwargs = getattr(module, "_skip_keys", None)

config = GroupOffloadingConfig(
onload_device=onload_device,
offload_device=offload_device,
@@ -561,6 +593,8 @@
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
block_modules=block_modules,
exclude_kwargs=exclude_kwargs,
)
_apply_group_offloading(module, config)
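
A hedged usage sketch of the two new arguments (the checkpoint id and the exact values are illustrative, not taken from this diff). When both are omitted, the `getattr` fallbacks above pick up `_group_offload_block_modules` and `_skip_keys` from the model class automatically.

```python
import torch
from diffusers import AutoencoderKLWan
from diffusers.hooks import apply_group_offloading

vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)

# Explicit values override the class-level defaults; omitting them lets
# apply_group_offloading infer both from `_group_offload_block_modules` / `_skip_keys`.
apply_group_offloading(
    vae,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    block_modules=["quant_conv", "post_quant_conv", "encoder", "decoder"],
    exclude_kwargs=["feat_cache", "feat_idx"],
)
```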

@@ -576,46 +610,66 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConf

def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and explicitly
defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading is
done at the top-level blocks and modules specified in block_modules.

When block_modules is provided, only those modules are treated as blocks for offloading, and block-level
offloading is applied to each of them recursively.
"""
if config.stream is not None and config.num_blocks_per_group != 1:
logger.warning(
f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
)
config.num_blocks_per_group = 1

# Create module groups for ModuleList and Sequential blocks
block_modules = set(config.block_modules) if config.block_modules is not None else set()

# Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules
modules_with_group_offloading = set()
unmatched_modules = []
matched_module_groups = []

for name, submodule in module.named_children():
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
unmatched_modules.append((name, submodule))
# Check if this is an explicitly defined block module
if name in block_modules:
# Track submodule using a prefix to avoid filename collisions during disk offload.
# Without this, submodules sharing the same model class would be assigned identical
# filenames (derived from the class name).
prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}."
submodule_config = replace(config, module_prefix=prefix)

_apply_group_offloading_block_level(submodule, submodule_config)
modules_with_group_offloading.add(name)
continue

for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = submodule[i : i + config.num_blocks_per_group]
group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=group_id,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")
elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
# Handle ModuleList and Sequential blocks as before
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = list(submodule[i : i + config.num_blocks_per_group])
if len(current_modules) == 0:
continue

group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=group_id,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")
else:
# This is an unmatched module
unmatched_modules.append((name, submodule))

# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
@@ -630,28 +684,29 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
parameters = [param for _, param in parameters]
buffers = [buffer for _, buffer in buffers]

# Create a group for the unmatched submodules of the top-level module so that they are on the correct
# device when the forward pass is called.
# Create a group for the remaining unmatched submodules of the top-level
# module so that they are on the correct device when the forward pass is called.
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
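A rough, illustrative sketch of why `module_prefix` is threaded through the recursion (the helper and class names below are made up, not the diff's code): group ids, and therefore disk-offload filenames, stay unique when two recursed submodules of the same class would otherwise produce identically named groups.

```python
# Hypothetical helper mirroring the id scheme above:
# "{prefix}{ClassName}_unmatched_group" for the catch-all group of a recursed submodule.
def unmatched_group_id(prefix: str, submodule) -> str:
    return f"{prefix}{submodule.__class__.__name__}_unmatched_group"

class CausalConv:  # stand-in for a conv class shared by quant_conv and post_quant_conv
    pass

print(unmatched_group_id("quant_conv.", CausalConv()))       # quant_conv.CausalConv_unmatched_group
print(unmatched_group_id("post_quant_conv.", CausalConv()))  # post_quant_conv.CausalConv_unmatched_group
# Without the prefix both would be "CausalConv_unmatched_group" and clash as disk filenames.
```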
1 change: 1 addition & 0 deletions src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -72,6 +72,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel

_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
_group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]

@register_to_config
def __init__(
3 changes: 3 additions & 0 deletions src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -619,6 +619,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
feat_idx[0] += 1
else:
x = self.conv_out(x)

return x


@@ -961,6 +962,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
"""

_supports_gradient_checkpointing = False
_group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]
# keys to ignore when AlignDeviceHook moves inputs/outputs between devices
# these are shared mutable state modified in-place
_skip_keys = ["feat_cache", "feat_idx"]
@@ -1408,6 +1410,7 @@ def forward(
"""
x = sample
posterior = self.encode(x).latent_dist

if sample_posterior:
z = posterior.sample(generator=generator)
else:
5 changes: 5 additions & 0 deletions src/diffusers/models/modeling_utils.py
@@ -531,6 +531,8 @@ def enable_group_offload(
record_stream: bool = False,
low_cpu_mem_usage=False,
offload_to_disk_path: Optional[str] = None,
block_modules: Optional[List[str]] = None,
exclude_kwargs: Optional[List[str]] = None,
) -> None:
r"""
Activates group offloading for the current model.
@@ -570,6 +572,7 @@
f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
f"open an issue at https://github.com/huggingface/diffusers/issues."
)

apply_group_offloading(
module=self,
onload_device=onload_device,
@@ -581,6 +584,8 @@
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
block_modules=block_modules,
exclude_kwargs=exclude_kwargs,
)

def set_attention_backend(self, backend: str) -> None:
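A hedged model-level usage sketch of the new passthrough arguments (the checkpoint id is illustrative). On `AutoencoderKL` the same block list would also be picked up automatically from `_group_offload_block_modules` if the argument were omitted.

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    block_modules=["quant_conv", "post_quant_conv", "encoder", "decoder"],
)
```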