 import hashlib
 import os
 from contextlib import contextmanager, nullcontext
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from enum import Enum
 from typing import Dict, List, Optional, Set, Tuple, Union
 
@@ -59,6 +59,9 @@ class GroupOffloadingConfig:
     num_blocks_per_group: Optional[int] = None
     offload_to_disk_path: Optional[str] = None
     stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
+    block_modules: Optional[List[str]] = None
+    exclude_kwargs: Optional[List[str]] = None
+    module_prefix: Optional[str] = ""
 
 
 class ModuleGroup:
@@ -77,7 +80,7 @@ def __init__(
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
         offload_to_disk_path: Optional[str] = None,
-        group_id: Optional[int] = None,
+        group_id: Optional[Union[int, str]] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -322,7 +325,21 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
             self.group.stream.synchronize()
 
         args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
-        kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
+
+        # Some autoencoder models use a feature cache that is passed through submodules and modified in place.
+        # `send_to_device` returns a copy of this cache object, which breaks the in-place updates.
+        # Use `exclude_kwargs` to mark such cache entries.
+        exclude_kwargs = self.config.exclude_kwargs or []
+        if exclude_kwargs:
+            moved_kwargs = send_to_device(
+                {k: v for k, v in kwargs.items() if k not in exclude_kwargs},
+                self.group.onload_device,
+                non_blocking=self.group.non_blocking,
+            )
+            kwargs.update(moved_kwargs)
+        else:
+            kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
+
         return args, kwargs
 
     def post_forward(self, module: torch.nn.Module, output):
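The change above only moves kwargs that are not listed in `exclude_kwargs` and then merges them back with `dict.update`, so excluded entries keep the exact objects the caller passed in. A minimal standalone sketch of that identity-preserving behaviour, using accelerate's `send_to_device` utility and a hypothetical `feat_cache` list (neither the tensors nor the kwarg name come from the diff):

```python
import torch
from accelerate.utils import send_to_device

feat_cache = []  # hypothetical cache that submodules append to in place
kwargs = {"hidden_states": torch.randn(1, 4), "feat_cache": feat_cache}

# Move everything except the excluded key, then merge back into the original dict.
moved = send_to_device({k: v for k, v in kwargs.items() if k != "feat_cache"}, "cpu")
kwargs.update(moved)

# The excluded entry is still the very same list object, so in-place appends stay visible.
assert kwargs["feat_cache"] is feat_cache
```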
@@ -455,6 +472,8 @@ def apply_group_offloading(
     record_stream: bool = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
+    block_modules: Optional[List[str]] = None,
+    exclude_kwargs: Optional[List[str]] = None,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -512,6 +531,13 @@ def apply_group_offloading(
             If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
             the CPU memory is a bottleneck but may counteract the benefits of using streams.
+        block_modules (`List[str]`, *optional*):
+            List of module names that should be treated as blocks for offloading. If provided, only these modules will
+            be considered for block-level offloading. If not provided, the default block detection logic will be used.
+        exclude_kwargs (`List[str]`, *optional*):
+            List of kwarg keys that should not be processed by `send_to_device`. This is useful for mutable state, such
+            as caching lists, that needs to keep its object identity across forward passes. If not provided, it will be
+            inferred from the module's `_skip_keys` attribute if it exists.
 
     Example:
         ```python
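For context, a hedged usage sketch of the two new arguments on a toy model (the module structure and the `feat_cache` key are placeholders, not taken from a real diffusers model, and a CUDA device is assumed for onloading):

```python
import torch
from diffusers.hooks import apply_group_offloading

class Coder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(2)])

class TinyAutoencoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Coder()
        self.decoder = Coder()

model = TinyAutoencoder()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    block_modules=["encoder", "decoder"],  # recurse into these named children as blocks
    exclude_kwargs=["feat_cache"],         # keep this kwarg out of the send_to_device copy
)
```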
@@ -553,6 +579,12 @@ def apply_group_offloading(
 
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)
 
+    if block_modules is None:
+        block_modules = getattr(module, "_group_offload_block_modules", None)
+
+    if exclude_kwargs is None:
+        exclude_kwargs = getattr(module, "_skip_keys", None)
+
     config = GroupOffloadingConfig(
         onload_device=onload_device,
         offload_device=offload_device,
@@ -563,6 +595,8 @@ def apply_group_offloading(
         record_stream=record_stream,
         low_cpu_mem_usage=low_cpu_mem_usage,
         offload_to_disk_path=offload_to_disk_path,
+        block_modules=block_modules,
+        exclude_kwargs=exclude_kwargs,
     )
     _apply_group_offloading(module, config)
 
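The `getattr` fallbacks above mean a model class can ship its own defaults. A sketch of what that could look like (the attribute names `_group_offload_block_modules` and `_skip_keys` come from the diff; the values and the class itself are made up):

```python
import torch

class MyAutoencoder(torch.nn.Module):
    # Picked up when the caller does not pass block_modules / exclude_kwargs explicitly.
    _group_offload_block_modules = ["encoder", "decoder"]
    _skip_keys = ["feat_cache"]

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Linear(8, 8))
        self.decoder = torch.nn.Sequential(torch.nn.Linear(8, 8))
```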
@@ -578,46 +612,66 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig
 
 
 def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
-    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
-    the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
-    """
+    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, as well as to
+    explicitly defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this
+    offloading is done at the top-level blocks and the modules listed in `block_modules`.
 
+    When `block_modules` is provided, only those modules are treated as blocks for offloading, and block offloading is
+    applied recursively to each of them.
+    """
     if config.stream is not None and config.num_blocks_per_group != 1:
         logger.warning(
             f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
         )
         config.num_blocks_per_group = 1
 
-    # Create module groups for ModuleList and Sequential blocks
+    block_modules = set(config.block_modules) if config.block_modules is not None else set()
+
+    # Create module groups for ModuleList and Sequential blocks, and for explicitly defined block modules
     modules_with_group_offloading = set()
     unmatched_modules = []
    matched_module_groups = []
+
     for name, submodule in module.named_children():
-        if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
-            unmatched_modules.append((name, submodule))
+        # Check whether this is an explicitly defined block module
+        if name in block_modules:
+            # Track the submodule using a prefix to avoid filename collisions during disk offload.
+            # Without this, submodules sharing the same model class would be assigned identical
+            # filenames (derived from the class name).
+            prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}."
+            submodule_config = replace(config, module_prefix=prefix)
+
+            _apply_group_offloading_block_level(submodule, submodule_config)
             modules_with_group_offloading.add(name)
-            continue
 
-        for i in range(0, len(submodule), config.num_blocks_per_group):
-            current_modules = submodule[i : i + config.num_blocks_per_group]
-            group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
-            group = ModuleGroup(
-                modules=current_modules,
-                offload_device=config.offload_device,
-                onload_device=config.onload_device,
-                offload_to_disk_path=config.offload_to_disk_path,
-                offload_leader=current_modules[-1],
-                onload_leader=current_modules[0],
-                non_blocking=config.non_blocking,
-                stream=config.stream,
-                record_stream=config.record_stream,
-                low_cpu_mem_usage=config.low_cpu_mem_usage,
-                onload_self=True,
-                group_id=group_id,
-            )
-            matched_module_groups.append(group)
-            for j in range(i, i + len(current_modules)):
-                modules_with_group_offloading.add(f"{name}.{j}")
+        elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+            # Handle ModuleList and Sequential blocks as before
+            for i in range(0, len(submodule), config.num_blocks_per_group):
+                current_modules = list(submodule[i : i + config.num_blocks_per_group])
+                if len(current_modules) == 0:
+                    continue
+
+                group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
+                group = ModuleGroup(
+                    modules=current_modules,
+                    offload_device=config.offload_device,
+                    onload_device=config.onload_device,
+                    offload_to_disk_path=config.offload_to_disk_path,
+                    offload_leader=current_modules[-1],
+                    onload_leader=current_modules[0],
+                    non_blocking=config.non_blocking,
+                    stream=config.stream,
+                    record_stream=config.record_stream,
+                    low_cpu_mem_usage=config.low_cpu_mem_usage,
+                    onload_self=True,
+                    group_id=group_id,
+                )
+                matched_module_groups.append(group)
+                for j in range(i, i + len(current_modules)):
+                    modules_with_group_offloading.add(f"{name}.{j}")
+        else:
+            # This is an unmatched module
+            unmatched_modules.append((name, submodule))
 
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
@@ -632,28 +686,29 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig
     parameters = [param for _, param in parameters]
     buffers = [buffer for _, buffer in buffers]
 
-    # Create a group for the unmatched submodules of the top-level module so that they are on the correct
-    # device when the forward pass is called.
+    # Create a group for the remaining unmatched submodules of the top-level
+    # module so that they are on the correct device when the forward pass is called.
     unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
-    unmatched_group = ModuleGroup(
-        modules=unmatched_modules,
-        offload_device=config.offload_device,
-        onload_device=config.onload_device,
-        offload_to_disk_path=config.offload_to_disk_path,
-        offload_leader=module,
-        onload_leader=module,
-        parameters=parameters,
-        buffers=buffers,
-        non_blocking=False,
-        stream=None,
-        record_stream=False,
-        onload_self=True,
-        group_id=f"{module.__class__.__name__}_unmatched_group",
-    )
-    if config.stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, config=config)
-    else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
+    if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
+        unmatched_group = ModuleGroup(
+            modules=unmatched_modules,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
+            offload_leader=module,
+            onload_leader=module,
+            parameters=parameters,
+            buffers=buffers,
+            non_blocking=False,
+            stream=None,
+            record_stream=False,
+            onload_self=True,
+            group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group",
+        )
+        if config.stream is None:
+            _apply_group_offloading_hook(module, unmatched_group, config=config)
+        else:
+            _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
 def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
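To illustrate why threading `module_prefix` through the recursion matters, here is a simplified re-derivation of the block group ids (not library code; the toy module names are made up). With the prefix, two submodules that share the same block layout get distinct ids; without it, both branches would produce identical `blocks_0_0`-style ids and, under disk offload, colliding filenames.

```python
import torch

def block_group_ids(module, block_modules, prefix=""):
    # Mirrors the naming above: recurse into explicit block modules with an extended prefix,
    # and enumerate ModuleList/Sequential children in groups of one (num_blocks_per_group=1).
    ids = []
    for name, child in module.named_children():
        if name in block_modules:
            ids += block_group_ids(child, block_modules, prefix=f"{prefix}{name}.")
        elif isinstance(child, (torch.nn.ModuleList, torch.nn.Sequential)):
            ids += [f"{prefix}{name}_{i}_{i}" for i in range(len(child))]
    return ids

def make_coder():
    return torch.nn.ModuleDict({"blocks": torch.nn.ModuleList([torch.nn.Linear(4, 4)])})

model = torch.nn.ModuleDict({"encoder": make_coder(), "decoder": make_coder()})
print(block_group_ids(model, {"encoder", "decoder"}))
# ['encoder.blocks_0_0', 'decoder.blocks_0_0']
```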