|
43 | 43 | LoRAAttnProcessor2_0, |
44 | 44 | XFormersAttnProcessor, |
45 | 45 | ) |
46 | | -from diffusers.utils.testing_utils import floats_tensor, require_torch_gpu, slow, torch_device |
| 46 | +from diffusers.utils.testing_utils import floats_tensor, nightly, require_torch_gpu, slow, torch_device |
47 | 47 |
|
48 | 48 |
|
49 | 49 | def create_unet_lora_layers(unet: nn.Module): |
@@ -1464,3 +1464,41 @@ def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self): |
1464 | 1464 | expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535]) |
1465 | 1465 |
|
1466 | 1466 | self.assertTrue(np.allclose(images, expected, atol=1e-3)) |
| 1467 | + |
| 1468 | + @nightly |
| 1469 | + def test_sequential_fuse_unfuse(self): |
| 1470 | + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") |
| 1471 | + |
| 1472 | + # 1. round |
| 1473 | + pipe.load_lora_weights("Pclanglais/TintinIA") |
| 1474 | + pipe.fuse_lora() |
| 1475 | + |
| 1476 | + generator = torch.Generator().manual_seed(0) |
| 1477 | + images = pipe( |
| 1478 | + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 |
| 1479 | + ).images |
| 1480 | + image_slice = images[0, -3:, -3:, -1].flatten() |
| 1481 | + |
| 1482 | + pipe.unfuse_lora() |
| 1483 | + |
| 1484 | + # 2. round |
| 1485 | + pipe.load_lora_weights("ProomptEngineer/pe-balloon-diffusion-style") |
| 1486 | + pipe.fuse_lora() |
| 1487 | + pipe.unfuse_lora() |
| 1488 | + |
| 1489 | + # 3. round |
| 1490 | + pipe.load_lora_weights("ostris/crayon_style_lora_sdxl") |
| 1491 | + pipe.fuse_lora() |
| 1492 | + pipe.unfuse_lora() |
| 1493 | + |
| 1494 | + # 4. back to 1st round |
| 1495 | + pipe.load_lora_weights("Pclanglais/TintinIA") |
| 1496 | + pipe.fuse_lora() |
| 1497 | + |
| 1498 | + generator = torch.Generator().manual_seed(0) |
| 1499 | + images_2 = pipe( |
| 1500 | + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 |
| 1501 | + ).images |
| 1502 | + image_slice_2 = images_2[0, -3:, -3:, -1].flatten() |
| 1503 | + |
| 1504 | + self.assertTrue(np.allclose(image_slice, image_slice_2, atol=1e-3)) |
0 commit comments