
Commit f12d161

rycerzes, github-actions[bot], and DN6 authored
Fix broken group offloading with block_level for models with standalone layers (#12692)
* fix: group offloading to support standalone computational layers in block-level offloading
* test: for models with standalone and deeply nested layers in block-level offloading
* feat: support for block-level offloading in group offloading config
* fix: group offload block modules to AutoencoderKL and AutoencoderKLWan
* fix: update group offloading tests to use AutoencoderKL and adjust input dimensions
* refactor: streamline block offloading logic
* Apply style fixes
* update tests
* update
* fix for failing tests
* clean up
* revert to use skip_keys
* clean up

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent 8d415a6 commit f12d161
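
For context, a minimal usage sketch of what this commit fixes (the checkpoint id is illustrative, not taken from the PR): with the class attributes added below, block-level group offloading now handles the VAE's standalone layers and its in-place feature cache out of the box.

```python
import torch
from diffusers import AutoencoderKLWan

# Illustrative checkpoint; any AutoencoderKLWan weights behave the same way.
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)

# block_modules defaults to the model's _group_offload_block_modules attribute
# (["quant_conv", "post_quant_conv", "encoder", "decoder"]) and exclude_kwargs
# defaults to its _skip_keys (["feat_cache", "feat_idx"]), so the standalone
# conv layers are offloaded and the feature cache keeps its object identity.
vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
)
```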

File tree

7 files changed: +353, -67 lines changed


src/diffusers/hooks/group_offloading.py

Lines changed: 106 additions & 51 deletions
@@ -15,7 +15,7 @@
 import hashlib
 import os
 from contextlib import contextmanager, nullcontext
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from enum import Enum
 from typing import Dict, List, Optional, Set, Tuple, Union

@@ -59,6 +59,9 @@ class GroupOffloadingConfig:
     num_blocks_per_group: Optional[int] = None
     offload_to_disk_path: Optional[str] = None
     stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
+    block_modules: Optional[List[str]] = None
+    exclude_kwargs: Optional[List[str]] = None
+    module_prefix: Optional[str] = ""


 class ModuleGroup:
@@ -77,7 +80,7 @@ def __init__(
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
         offload_to_disk_path: Optional[str] = None,
-        group_id: Optional[int] = None,
+        group_id: Optional[Union[int, str]] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -322,7 +325,21 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
             self.group.stream.synchronize()

         args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
-        kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
+
+        # Some autoencoder models use a feature cache that is passed through submodules
+        # and modified in place. `send_to_device` returns a copy of this cache object,
+        # which breaks the in-place updates, so `exclude_kwargs` marks these cache kwargs.
+        exclude_kwargs = self.config.exclude_kwargs or []
+        if exclude_kwargs:
+            moved_kwargs = send_to_device(
+                {k: v for k, v in kwargs.items() if k not in exclude_kwargs},
+                self.group.onload_device,
+                non_blocking=self.group.non_blocking,
+            )
+            kwargs.update(moved_kwargs)
+        else:
+            kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
+
         return args, kwargs

     def post_forward(self, module: torch.nn.Module, output):
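
A minimal sketch of the identity problem this hunk works around, using accelerate's `send_to_device` (the same helper this file imports); the tensor values are placeholders:

```python
import torch
from accelerate.utils import send_to_device

feat_cache = [None] * 4  # shared mutable state, updated in place by submodules
kwargs = {"hidden_states": torch.ones(1), "feat_cache": feat_cache, "feat_idx": [0]}

# send_to_device rebuilds containers, so the cache loses its identity:
moved = send_to_device(kwargs, "cpu")
assert moved["feat_cache"] is not feat_cache

# The exclude_kwargs path moves everything else and leaves the cache untouched:
exclude = {"feat_cache", "feat_idx"}
moved = send_to_device({k: v for k, v in kwargs.items() if k not in exclude}, "cpu")
kwargs.update(moved)
assert kwargs["feat_cache"] is feat_cache  # in-place updates still propagate
```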
@@ -455,6 +472,8 @@ def apply_group_offloading(
     record_stream: bool = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
+    block_modules: Optional[List[str]] = None,
+    exclude_kwargs: Optional[List[str]] = None,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -512,6 +531,13 @@ def apply_group_offloading(
             If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
             the CPU memory is a bottleneck but may counteract the benefits of using streams.
+        block_modules (`List[str]`, *optional*):
+            List of module names that should be treated as blocks for offloading. If provided, only these modules
+            will be considered for block-level offloading. If not provided, the default block detection logic is used.
+        exclude_kwargs (`List[str]`, *optional*):
+            List of kwarg keys that should not be processed by `send_to_device`. This is useful for mutable state,
+            like caching lists, that needs to maintain its object identity across forward passes. If not provided, it
+            is inferred from the module's `_skip_keys` attribute if one exists.

     Example:
         ```python
@@ -553,6 +579,12 @@ def apply_group_offloading(

     _raise_error_if_accelerate_model_or_sequential_hook_present(module)

+    if block_modules is None:
+        block_modules = getattr(module, "_group_offload_block_modules", None)
+
+    if exclude_kwargs is None:
+        exclude_kwargs = getattr(module, "_skip_keys", None)
+
     config = GroupOffloadingConfig(
         onload_device=onload_device,
         offload_device=offload_device,
@@ -563,6 +595,8 @@ def apply_group_offloading(
         record_stream=record_stream,
         low_cpu_mem_usage=low_cpu_mem_usage,
         offload_to_disk_path=offload_to_disk_path,
+        block_modules=block_modules,
+        exclude_kwargs=exclude_kwargs,
     )
     _apply_group_offloading(module, config)

@@ -578,46 +612,66 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:

 def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
-    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
-    the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
+    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and to
+    explicitly defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this
+    offloading is done at the top-level blocks and the modules specified in block_modules.
+
+    When block_modules is provided, only those modules are treated as blocks for offloading, and block-level
+    offloading is applied to each of them recursively.
     """
     if config.stream is not None and config.num_blocks_per_group != 1:
         logger.warning(
             f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
         )
         config.num_blocks_per_group = 1

-    # Create module groups for ModuleList and Sequential blocks
+    block_modules = set(config.block_modules) if config.block_modules is not None else set()
+
+    # Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules
     modules_with_group_offloading = set()
     unmatched_modules = []
     matched_module_groups = []
+
     for name, submodule in module.named_children():
-        if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
-            unmatched_modules.append((name, submodule))
+        # Check if this is an explicitly defined block module
+        if name in block_modules:
+            # Track submodules using a prefix to avoid filename collisions during disk offload.
+            # Without this, submodules sharing the same model class would be assigned identical
+            # filenames (derived from the class name).
+            prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}."
+            submodule_config = replace(config, module_prefix=prefix)
+
+            _apply_group_offloading_block_level(submodule, submodule_config)
             modules_with_group_offloading.add(name)
-            continue

-        for i in range(0, len(submodule), config.num_blocks_per_group):
-            current_modules = submodule[i : i + config.num_blocks_per_group]
-            group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
-            group = ModuleGroup(
-                modules=current_modules,
-                offload_device=config.offload_device,
-                onload_device=config.onload_device,
-                offload_to_disk_path=config.offload_to_disk_path,
-                offload_leader=current_modules[-1],
-                onload_leader=current_modules[0],
-                non_blocking=config.non_blocking,
-                stream=config.stream,
-                record_stream=config.record_stream,
-                low_cpu_mem_usage=config.low_cpu_mem_usage,
-                onload_self=True,
-                group_id=group_id,
-            )
-            matched_module_groups.append(group)
-            for j in range(i, i + len(current_modules)):
-                modules_with_group_offloading.add(f"{name}.{j}")
+        elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+            # Handle ModuleList and Sequential blocks as before
+            for i in range(0, len(submodule), config.num_blocks_per_group):
+                current_modules = list(submodule[i : i + config.num_blocks_per_group])
+                if len(current_modules) == 0:
+                    continue
+
+                group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
+                group = ModuleGroup(
+                    modules=current_modules,
+                    offload_device=config.offload_device,
+                    onload_device=config.onload_device,
+                    offload_to_disk_path=config.offload_to_disk_path,
+                    offload_leader=current_modules[-1],
+                    onload_leader=current_modules[0],
+                    non_blocking=config.non_blocking,
+                    stream=config.stream,
+                    record_stream=config.record_stream,
+                    low_cpu_mem_usage=config.low_cpu_mem_usage,
+                    onload_self=True,
+                    group_id=group_id,
+                )
+                matched_module_groups.append(group)
+                for j in range(i, i + len(current_modules)):
+                    modules_with_group_offloading.add(f"{name}.{j}")
+        else:
+            # This is an unmatched module
+            unmatched_modules.append((name, submodule))

     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
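
The prefix handling above can be sketched in isolation (`Cfg` is a hypothetical stand-in for `GroupOffloadingConfig`): each recursion level appends the child's name, so group ids, and therefore disk-offload filenames, stay unique even when two submodules share a class.

```python
from dataclasses import dataclass, replace


@dataclass
class Cfg:  # hypothetical stand-in for GroupOffloadingConfig
    module_prefix: str = ""


def descend(cfg: Cfg, name: str) -> Cfg:
    # Mirrors the prefix composition in _apply_group_offloading_block_level
    prefix = f"{cfg.module_prefix}{name}." if cfg.module_prefix else f"{name}."
    return replace(cfg, module_prefix=prefix)


cfg = descend(Cfg(), "decoder")            # module_prefix == "decoder."
cfg = descend(cfg, "mid_block")            # module_prefix == "decoder.mid_block."
print(f"{cfg.module_prefix}resnets_0_0")   # -> decoder.mid_block.resnets_0_0
```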
@@ -632,28 +686,29 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     parameters = [param for _, param in parameters]
     buffers = [buffer for _, buffer in buffers]

-    # Create a group for the unmatched submodules of the top-level module so that they are on the correct
-    # device when the forward pass is called.
+    # Create a group for the remaining unmatched submodules of the top-level
+    # module so that they are on the correct device when the forward pass is called.
     unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
-    unmatched_group = ModuleGroup(
-        modules=unmatched_modules,
-        offload_device=config.offload_device,
-        onload_device=config.onload_device,
-        offload_to_disk_path=config.offload_to_disk_path,
-        offload_leader=module,
-        onload_leader=module,
-        parameters=parameters,
-        buffers=buffers,
-        non_blocking=False,
-        stream=None,
-        record_stream=False,
-        onload_self=True,
-        group_id=f"{module.__class__.__name__}_unmatched_group",
-    )
-    if config.stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, config=config)
-    else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
+    if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
+        unmatched_group = ModuleGroup(
+            modules=unmatched_modules,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
+            offload_leader=module,
+            onload_leader=module,
+            parameters=parameters,
+            buffers=buffers,
+            non_blocking=False,
+            stream=None,
+            record_stream=False,
+            onload_self=True,
+            group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group",
+        )
+        if config.stream is None:
+            _apply_group_offloading_hook(module, unmatched_group, config=config)
+        else:
+            _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


 def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 1 addition & 0 deletions
@@ -74,6 +74,7 @@ class AutoencoderKL(

     _supports_gradient_checkpointing = True
     _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
+    _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]

     @register_to_config
     def __init__(

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 3 additions & 0 deletions
@@ -619,6 +619,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
             feat_idx[0] += 1
         else:
             x = self.conv_out(x)
+
         return x
@@ -961,6 +962,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
     """

     _supports_gradient_checkpointing = False
+    _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]
     # keys to ignore when AlignDeviceHook moves inputs/outputs between devices
     # these are shared mutable state modified in-place
     _skip_keys = ["feat_cache", "feat_idx"]
@@ -1414,6 +1416,7 @@ def forward(
         """
         x = sample
         posterior = self.encode(x).latent_dist
+
         if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
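
The same opt-in works for third-party models. A hypothetical sketch (the model and its names are illustrative, not part of this PR) of declaring both class attributes:

```python
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class TinyAutoencoder(ModelMixin, ConfigMixin):
    # Standalone layers and containers that block-level offloading should pick up,
    # mirroring the attributes added to AutoencoderKL / AutoencoderKLWan above.
    _group_offload_block_modules = ["encoder", "decoder", "quant_conv"]
    # Mutable kwargs that must keep their object identity across forward passes.
    _skip_keys = ["feat_cache", "feat_idx"]

    @register_to_config
    def __init__(self, channels: int = 8):
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Conv2d(3, channels, 3, padding=1))
        self.quant_conv = torch.nn.Conv2d(channels, channels, 1)  # standalone layer
        self.decoder = torch.nn.Sequential(torch.nn.Conv2d(channels, 3, 3, padding=1))

    def forward(self, sample, feat_cache=None, feat_idx=[0]):
        return self.decoder(self.quant_conv(self.encoder(sample)))
```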

src/diffusers/models/modeling_utils.py

Lines changed: 5 additions & 0 deletions
@@ -531,6 +531,8 @@ def enable_group_offload(
         record_stream: bool = False,
         low_cpu_mem_usage=False,
         offload_to_disk_path: Optional[str] = None,
+        block_modules: Optional[List[str]] = None,
+        exclude_kwargs: Optional[List[str]] = None,
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -570,6 +572,7 @@ def enable_group_offload(
                 f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
+
         apply_group_offloading(
             module=self,
             onload_device=onload_device,
@@ -581,6 +584,8 @@ def enable_group_offload(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             offload_to_disk_path=offload_to_disk_path,
+            block_modules=block_modules,
+            exclude_kwargs=exclude_kwargs,
         )

     def set_attention_backend(self, backend: str) -> None:
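
Callers can also bypass the class-level defaults through these new parameters; a sketch reusing the hypothetical TinyAutoencoder from above:

```python
import torch

model = TinyAutoencoder()
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    block_modules=["encoder", "decoder"],  # explicit subset, overrides the default
    exclude_kwargs=["feat_cache"],         # keep the cache's identity intact
)
```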
