
Commit e71d91e

fix: update group offloading tests to use AutoencoderKL and adjust input dimensions
1 parent fb8a741 commit e71d91e

File tree: 1 file changed (+35 −109)


tests/hooks/test_group_offloading.py

Lines changed: 35 additions & 109 deletions
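This commit swaps the hand-rolled DummyVAELikeModel/DecoderWithNestedBlocks fixtures for a real, tiny AutoencoderKL, so the group-offloading tests run against the module layout an actual VAE has (standalone quant_conv/post_quant_conv plus nested down/up blocks). As a minimal sketch of the pattern under test, mirroring the calls that appear in the diff below — the config values come from the get_autoencoder_kl_config helper added at the bottom of the file, and the hard-coded "cuda" device is an assumption (the tests use torch_device):

import torch
from diffusers import AutoencoderKL

# Tiny two-block VAE config, copied from the helper added in this commit.
model = AutoencoderKL(
    block_out_channels=[2, 4],
    in_channels=3,
    out_channels=3,
    down_block_types=["DownEncoderBlock2D"] * 2,
    up_block_types=["UpDecoderBlock2D"] * 2,
    latent_channels=4,
    norm_num_groups=2,
    layers_per_block=1,
)

# Offload parameters in per-block groups; each group is moved onto the
# accelerator only when its block runs, then offloaded again.
model.enable_group_offload("cuda", offload_type="block_level", num_blocks_per_group=1, use_stream=True)

x = torch.randn(2, 3, 32, 32).to("cuda")  # NCHW input matching in_channels=3
with torch.no_grad():
    out = model(x).sample  # forward returns an output object; .sample holds the tensor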
@@ -19,6 +19,7 @@
 import torch
 from parameterized import parameterized

+from diffusers import AutoencoderKL
 from diffusers.hooks import HookRegistry, ModelHook
 from diffusers.models import ModelMixin
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
@@ -149,78 +150,6 @@ def post_forward(self, module, output):
         return output


-# Model simulating VAE structure with standalone computational layers
-class DummyVAELikeModel(ModelMixin):
-    def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
-        super().__init__()
-
-        # Encoder container (not ModuleList/Sequential at top level)
-        self.encoder = torch.nn.Sequential(
-            torch.nn.Linear(in_features, hidden_features),
-            torch.nn.ReLU(),
-        )
-
-        # Standalone Conv2d layer (simulates quant_conv)
-        self.quant_conv = torch.nn.Conv2d(1, 1, kernel_size=1)
-
-        # Decoder container with nested ModuleList
-        self.decoder = DecoderWithNestedBlocks(hidden_features, hidden_features)
-
-        # Standalone Conv2d layer (simulates post_quant_conv)
-        self.post_quant_conv = torch.nn.Conv2d(1, 1, kernel_size=1)
-
-        # Output projection
-        self.linear_out = torch.nn.Linear(hidden_features, out_features)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # Encode
-        x = self.encoder(x)
-
-        # Reshape for conv operations
-        batch_size = x.shape[0]
-        x_reshaped = x.view(batch_size, 1, -1, 1)
-
-        # Apply standalone conv layers
-        x_reshaped = self.quant_conv(x_reshaped)
-        x_reshaped = self.post_quant_conv(x_reshaped)
-
-        # Reshape back
-        x = x_reshaped.view(batch_size, -1)
-
-        # Decode
-        x = self.decoder(x)
-
-        # Output
-        x = self.linear_out(x)
-        return x
-
-
-class DecoderWithNestedBlocks(torch.nn.Module):
-    def __init__(self, in_features: int, out_features: int) -> None:
-        super().__init__()
-
-        # Container modules (not ModuleList/Sequential)
-        self.conv_in = torch.nn.Linear(in_features, in_features)
-
-        # Nested ModuleList (like VAE's decoder.up_blocks)
-        self.up_blocks = torch.nn.ModuleList(
-            [torch.nn.Linear(in_features, in_features), torch.nn.Linear(in_features, in_features)]
-        )
-
-        # Non-computational layer
-        self.norm = torch.nn.LayerNorm(in_features)
-
-        self.conv_out = torch.nn.Linear(in_features, out_features)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.conv_in(x)
-        for block in self.up_blocks:
-            x = block(x)
-        x = self.norm(x)
-        x = self.conv_out(x)
-        return x
-
-
 # Model with only standalone computational layers at top level
 class DummyModelWithStandaloneLayers(ModelMixin):
     def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
@@ -503,45 +432,25 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
                 cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
             )

-    def test_vae_like_model_with_standalone_conv_layers(self):
-        """Test that models with standalone Conv2d layers (like VAE) work with block-level offloading."""
-        if torch.device(torch_device).type not in ["cuda", "xpu"]:
-            return
-
-        model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
-
-        model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
-        model_ref.load_state_dict(model.state_dict(), strict=True)
-        model_ref.to(torch_device)
-
-        model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
-
-        x = torch.randn(2, 64).to(torch_device)
-
-        with torch.no_grad():
-            out_ref = model_ref(x)
-            out = model(x)
-
-        self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model.")
-
     def test_vae_like_model_without_streams(self):
         """Test VAE-like model with block-level offloading but without streams."""
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
            return

-        model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
+        config = self.get_autoencoder_kl_config()
+        model = AutoencoderKL(**config)

-        model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
+        model_ref = AutoencoderKL(**config)
         model_ref.load_state_dict(model.state_dict(), strict=True)
         model_ref.to(torch_device)

         model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False)

-        x = torch.randn(2, 64).to(torch_device)
+        x = torch.randn(2, 3, 32, 32).to(torch_device)

         with torch.no_grad():
-            out_ref = model_ref(x)
-            out = model(x)
+            out_ref = model_ref(x).sample
+            out = model(x).sample

         self.assertTrue(
             torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
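One detail worth flagging in the hunk above: the removed dummy model returned a plain tensor, but AutoencoderKL's forward returns an output object (a DecoderOutput, as I read the current diffusers API), hence the added .sample accesses:

out = model(x)      # output object, not a tensor
recon = out.sample  # the reconstructed tensor, (2, 3, 32, 32) for this input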
@@ -597,19 +506,20 @@ def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str)
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return

-        model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
+        config = self.get_autoencoder_kl_config()
+        model = AutoencoderKL(**config)

-        model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
+        model_ref = AutoencoderKL(**config)
         model_ref.load_state_dict(model.state_dict(), strict=True)
         model_ref.to(torch_device)

         model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)

-        x = torch.randn(2, 64).to(torch_device)
+        x = torch.randn(2, 3, 32, 32).to(torch_device)

         with torch.no_grad():
-            out_ref = model_ref(x)
-            out = model(x)
+            out_ref = model_ref(x).sample
+            out = model(x).sample

         self.assertTrue(
             torch.allclose(out_ref, out, atol=1e-5),
@@ -621,20 +531,21 @@ def test_multiple_invocations_with_vae_like_model(self):
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return

-        model = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
+        config = self.get_autoencoder_kl_config()
+        model = AutoencoderKL(**config)

-        model_ref = DummyVAELikeModel(in_features=64, hidden_features=128, out_features=64)
+        model_ref = AutoencoderKL(**config)
         model_ref.load_state_dict(model.state_dict(), strict=True)
         model_ref.to(torch_device)

         model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

-        x = torch.randn(2, 64).to(torch_device)
+        x = torch.randn(2, 3, 32, 32).to(torch_device)

         with torch.no_grad():
-            for i in range(5):
-                out_ref = model_ref(x)
-                out = model(x)
+            for i in range(2):
+                out_ref = model_ref(x).sample
+                out = model(x).sample
                 self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.")

     def test_nested_container_parameters_offloading(self):
@@ -660,3 +571,18 @@ def test_nested_container_parameters_offloading(self):
                     torch.allclose(out_ref, out, atol=1e-5),
                     f"Outputs do not match at iteration {i} for nested parameters.",
                 )
+
+    def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
+        block_out_channels = block_out_channels or [2, 4]
+        norm_num_groups = norm_num_groups or 2
+        init_dict = {
+            "block_out_channels": block_out_channels,
+            "in_channels": 3,
+            "out_channels": 3,
+            "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+            "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+            "latent_channels": 4,
+            "norm_num_groups": norm_num_groups,
+            "layers_per_block": 1,
+        }
+        return init_dict
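For reference, a sketch of the encode/decode round trip this tiny config supports, with the init dict inlined so it runs standalone. Assuming the usual AutoencoderKL behavior of one 2x downsample per non-final down block, the two-entry block_out_channels gives a single downsample, so a 32x32 input maps to 16x16 latents:

import torch
from diffusers import AutoencoderKL

# Inlined copy of the dict that get_autoencoder_kl_config builds above.
config = {
    "block_out_channels": [2, 4], "in_channels": 3, "out_channels": 3,
    "down_block_types": ["DownEncoderBlock2D"] * 2, "up_block_types": ["UpDecoderBlock2D"] * 2,
    "latent_channels": 4, "norm_num_groups": 2, "layers_per_block": 1,
}
vae = AutoencoderKL(**config)

x = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    latents = vae.encode(x).latent_dist.sample()  # (2, 4, 16, 16): latent_channels=4, one downsample
    recon = vae.decode(latents).sample            # (2, 3, 32, 32) reconstruction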
