1919import torch
2020from parameterized import parameterized
2121
22+ from diffusers import AutoencoderKL
2223from diffusers .hooks import HookRegistry , ModelHook
2324from diffusers .models import ModelMixin
2425from diffusers .pipelines .pipeline_utils import DiffusionPipeline
@@ -149,78 +150,6 @@ def post_forward(self, module, output):
149150 return output
150151
151152
152- # Model simulating VAE structure with standalone computational layers
153- class DummyVAELikeModel (ModelMixin ):
154- def __init__ (self , in_features : int , hidden_features : int , out_features : int ) -> None :
155- super ().__init__ ()
156-
157- # Encoder container (not ModuleList/Sequential at top level)
158- self .encoder = torch .nn .Sequential (
159- torch .nn .Linear (in_features , hidden_features ),
160- torch .nn .ReLU (),
161- )
162-
163- # Standalone Conv2d layer (simulates quant_conv)
164- self .quant_conv = torch .nn .Conv2d (1 , 1 , kernel_size = 1 )
165-
166- # Decoder container with nested ModuleList
167- self .decoder = DecoderWithNestedBlocks (hidden_features , hidden_features )
168-
169- # Standalone Conv2d layer (simulates post_quant_conv)
170- self .post_quant_conv = torch .nn .Conv2d (1 , 1 , kernel_size = 1 )
171-
172- # Output projection
173- self .linear_out = torch .nn .Linear (hidden_features , out_features )
174-
175- def forward (self , x : torch .Tensor ) -> torch .Tensor :
176- # Encode
177- x = self .encoder (x )
178-
179- # Reshape for conv operations
180- batch_size = x .shape [0 ]
181- x_reshaped = x .view (batch_size , 1 , - 1 , 1 )
182-
183- # Apply standalone conv layers
184- x_reshaped = self .quant_conv (x_reshaped )
185- x_reshaped = self .post_quant_conv (x_reshaped )
186-
187- # Reshape back
188- x = x_reshaped .view (batch_size , - 1 )
189-
190- # Decode
191- x = self .decoder (x )
192-
193- # Output
194- x = self .linear_out (x )
195- return x
196-
197-
198- class DecoderWithNestedBlocks (torch .nn .Module ):
199- def __init__ (self , in_features : int , out_features : int ) -> None :
200- super ().__init__ ()
201-
202- # Container modules (not ModuleList/Sequential)
203- self .conv_in = torch .nn .Linear (in_features , in_features )
204-
205- # Nested ModuleList (like VAE's decoder.up_blocks)
206- self .up_blocks = torch .nn .ModuleList (
207- [torch .nn .Linear (in_features , in_features ), torch .nn .Linear (in_features , in_features )]
208- )
209-
210- # Non-computational layer
211- self .norm = torch .nn .LayerNorm (in_features )
212-
213- self .conv_out = torch .nn .Linear (in_features , out_features )
214-
215- def forward (self , x : torch .Tensor ) -> torch .Tensor :
216- x = self .conv_in (x )
217- for block in self .up_blocks :
218- x = block (x )
219- x = self .norm (x )
220- x = self .conv_out (x )
221- return x
222-
223-
224153# Model with only standalone computational layers at top level
225154class DummyModelWithStandaloneLayers (ModelMixin ):
226155 def __init__ (self , in_features : int , hidden_features : int , out_features : int ) -> None :
@@ -503,45 +432,25 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
503432 cumulated_absmax , 1e-5 , f"Output differences for { name } exceeded threshold: { cumulated_absmax :.5f} "
504433 )
505434
506- def test_vae_like_model_with_standalone_conv_layers (self ):
507- """Test that models with standalone Conv2d layers (like VAE) work with block-level offloading."""
508- if torch .device (torch_device ).type not in ["cuda" , "xpu" ]:
509- return
510-
511- model = DummyVAELikeModel (in_features = 64 , hidden_features = 128 , out_features = 64 )
512-
513- model_ref = DummyVAELikeModel (in_features = 64 , hidden_features = 128 , out_features = 64 )
514- model_ref .load_state_dict (model .state_dict (), strict = True )
515- model_ref .to (torch_device )
516-
517- model .enable_group_offload (torch_device , offload_type = "block_level" , num_blocks_per_group = 1 , use_stream = True )
518-
519- x = torch .randn (2 , 64 ).to (torch_device )
520-
521- with torch .no_grad ():
522- out_ref = model_ref (x )
523- out = model (x )
524-
525- self .assertTrue (torch .allclose (out_ref , out , atol = 1e-5 ), "Outputs do not match for VAE-like model." )
526-
527435 def test_vae_like_model_without_streams (self ):
528436 """Test VAE-like model with block-level offloading but without streams."""
529437 if torch .device (torch_device ).type not in ["cuda" , "xpu" ]:
530438 return
531439
532- model = DummyVAELikeModel (in_features = 64 , hidden_features = 128 , out_features = 64 )
440+ config = self .get_autoencoder_kl_config ()
441+ model = AutoencoderKL (** config )
533442
534- model_ref = DummyVAELikeModel ( in_features = 64 , hidden_features = 128 , out_features = 64 )
443+ model_ref = AutoencoderKL ( ** config )
535444 model_ref .load_state_dict (model .state_dict (), strict = True )
536445 model_ref .to (torch_device )
537446
538447 model .enable_group_offload (torch_device , offload_type = "block_level" , num_blocks_per_group = 1 , use_stream = False )
539448
540- x = torch .randn (2 , 64 ).to (torch_device )
449+ x = torch .randn (2 , 3 , 32 , 32 ).to (torch_device )
541450
542451 with torch .no_grad ():
543- out_ref = model_ref (x )
544- out = model (x )
452+ out_ref = model_ref (x ). sample
453+ out = model (x ). sample
545454
546455 self .assertTrue (
547456 torch .allclose (out_ref , out , atol = 1e-5 ), "Outputs do not match for VAE-like model without streams."
@@ -597,19 +506,20 @@ def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str)
597506 if torch .device (torch_device ).type not in ["cuda" , "xpu" ]:
598507 return
599508
600- model = DummyVAELikeModel (in_features = 64 , hidden_features = 128 , out_features = 64 )
509+ config = self .get_autoencoder_kl_config ()
510+ model = AutoencoderKL (** config )
601511
602- model_ref = DummyVAELikeModel ( in_features = 64 , hidden_features = 128 , out_features = 64 )
512+ model_ref = AutoencoderKL ( ** config )
603513 model_ref .load_state_dict (model .state_dict (), strict = True )
604514 model_ref .to (torch_device )
605515
606516 model .enable_group_offload (torch_device , offload_type = offload_type , num_blocks_per_group = 1 , use_stream = True )
607517
608- x = torch .randn (2 , 64 ).to (torch_device )
518+ x = torch .randn (2 , 3 , 32 , 32 ).to (torch_device )
609519
610520 with torch .no_grad ():
611- out_ref = model_ref (x )
612- out = model (x )
521+ out_ref = model_ref (x ). sample
522+ out = model (x ). sample
613523
614524 self .assertTrue (
615525 torch .allclose (out_ref , out , atol = 1e-5 ),
@@ -621,20 +531,21 @@ def test_multiple_invocations_with_vae_like_model(self):
621531 if torch .device (torch_device ).type not in ["cuda" , "xpu" ]:
622532 return
623533
624- model = DummyVAELikeModel (in_features = 64 , hidden_features = 128 , out_features = 64 )
534+ config = self .get_autoencoder_kl_config ()
535+ model = AutoencoderKL (** config )
625536
626- model_ref = DummyVAELikeModel ( in_features = 64 , hidden_features = 128 , out_features = 64 )
537+ model_ref = AutoencoderKL ( ** config )
627538 model_ref .load_state_dict (model .state_dict (), strict = True )
628539 model_ref .to (torch_device )
629540
630541 model .enable_group_offload (torch_device , offload_type = "block_level" , num_blocks_per_group = 1 , use_stream = True )
631542
632- x = torch .randn (2 , 64 ).to (torch_device )
543+ x = torch .randn (2 , 3 , 32 , 32 ).to (torch_device )
633544
634545 with torch .no_grad ():
635- for i in range (5 ):
636- out_ref = model_ref (x )
637- out = model (x )
546+ for i in range (2 ):
547+ out_ref = model_ref (x ). sample
548+ out = model (x ). sample
638549 self .assertTrue (torch .allclose (out_ref , out , atol = 1e-5 ), f"Outputs do not match at iteration { i } ." )
639550
640551 def test_nested_container_parameters_offloading (self ):
@@ -660,3 +571,18 @@ def test_nested_container_parameters_offloading(self):
660571 torch .allclose (out_ref , out , atol = 1e-5 ),
661572 f"Outputs do not match at iteration { i } for nested parameters." ,
662573 )
574+
575+ def get_autoencoder_kl_config (self , block_out_channels = None , norm_num_groups = None ):
576+ block_out_channels = block_out_channels or [2 , 4 ]
577+ norm_num_groups = norm_num_groups or 2
578+ init_dict = {
579+ "block_out_channels" : block_out_channels ,
580+ "in_channels" : 3 ,
581+ "out_channels" : 3 ,
582+ "down_block_types" : ["DownEncoderBlock2D" ] * len (block_out_channels ),
583+ "up_block_types" : ["UpDecoderBlock2D" ] * len (block_out_channels ),
584+ "latent_channels" : 4 ,
585+ "norm_num_groups" : norm_num_groups ,
586+ "layers_per_block" : 1 ,
587+ }
588+ return init_dict
0 commit comments