This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit ef6b7e6

virginiafdez authored
Allow the quantized flag to be passed to the LatentDiffusionInferer methods; it is then forwarded to VQVAE.encode_stage_2_inputs when autoencoder_model is a VQVAE. (#481)
The flag is set randomly during testing (when the autoencoder is a VAE it should not matter), the tests were run, and reformatting was applied. controlnet.py is changed for reformatting purposes only. Co-authored-by: virginiafdez <virginia.fernandez@kcl.ac.uk>
1 parent 866e282 commit ef6b7e6
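
What the new flag controls at the VQ-VAE level, as a minimal sketch. This is illustrative only: `vqvae` and `images` stand for an already-trained generative.networks.nets.VQVAE and a suitably shaped input batch, neither of which is defined in this commit.

# Hedged sketch; `vqvae` and `images` are assumed to exist already.
# quantized=True (the default) reproduces the previous behaviour: the latents come from the codebook.
z_quantized = vqvae.encode_stage_2_inputs(images, quantized=True)

# quantized=False skips the quantization step, so the LDM sees the continuous encoder output.
z_continuous = vqvae.encode_stage_2_inputs(images, quantized=False)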

File tree: 3 files changed, +53 -15 lines changed

generative/inferers/inferer.py

Lines changed: 31 additions & 5 deletions
@@ -23,7 +23,7 @@
 from monai.transforms import CenterSpatialCrop, SpatialPad
 from monai.utils import optional_import

-from generative.networks.nets import SPADEAutoencoderKL, SPADEDiffusionModelUNet
+from generative.networks.nets import VQVAE, SPADEAutoencoderKL, SPADEDiffusionModelUNet

 tqdm, has_tqdm = optional_import("tqdm", name="tqdm")

@@ -362,6 +362,7 @@ def __call__(
         condition: torch.Tensor | None = None,
         mode: str = "crossattn",
         seg: torch.Tensor | None = None,
+        quantized: bool = True,
     ) -> torch.Tensor:
         """
         Implements the forward pass for a supervised training iteration.
@@ -375,9 +376,14 @@ def __call__(
             condition: conditioning for network input.
             mode: Conditioning mode for the network.
             seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+            quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM
+                are quantized or not.
         """
         with torch.no_grad():
-            latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+            autoencode = autoencoder_model.encode_stage_2_inputs
+            if isinstance(autoencoder_model, VQVAE):
+                autoencode = partial(autoencoder_model.encode_stage_2_inputs, quantized=quantized)
+            latent = autoencode(inputs) * self.scale_factor

         if self.ldm_latent_shape is not None:
             latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)
@@ -496,6 +502,7 @@ def get_likelihood(
         resample_latent_likelihoods: bool = False,
         resample_interpolation_mode: str = "nearest",
         seg: torch.Tensor | None = None,
+        quantized: bool = True,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Computes the log-likelihoods of the latent representations of the input.
@@ -517,12 +524,18 @@ def get_likelihood(
                 or 'trilinear;
             seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
                 is instance of SPADEAutoencoderKL, segmentation must be provided.
+            quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM
+                are quantized or not.
         """
         if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
             raise ValueError(
                 f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
             )
-        latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+
+        autoencode = autoencoder_model.encode_stage_2_inputs
+        if isinstance(autoencoder_model, VQVAE):
+            autoencode = partial(autoencoder_model.encode_stage_2_inputs, quantized=quantized)
+        latents = autoencode(inputs) * self.scale_factor

         if self.ldm_latent_shape is not None:
             latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)
@@ -882,6 +895,7 @@ def __call__(
         condition: torch.Tensor | None = None,
         mode: str = "crossattn",
         seg: torch.Tensor | None = None,
+        quantized: bool = True,
     ) -> torch.Tensor:
         """
         Implements the forward pass for a supervised training iteration.
@@ -897,9 +911,14 @@ def __call__(
             condition: conditioning for network input.
             mode: Conditioning mode for the network.
             seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+            quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM
+                are quantized or not.
         """
         with torch.no_grad():
-            latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+            autoencode = autoencoder_model.encode_stage_2_inputs
+            if isinstance(autoencoder_model, VQVAE):
+                autoencode = partial(autoencoder_model.encode_stage_2_inputs, quantized=quantized)
+            latent = autoencode(inputs) * self.scale_factor

         if self.ldm_latent_shape is not None:
             latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)
@@ -1036,6 +1055,7 @@ def get_likelihood(
         resample_latent_likelihoods: bool = False,
         resample_interpolation_mode: str = "nearest",
         seg: torch.Tensor | None = None,
+        quantized: bool = True,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Computes the log-likelihoods of the latent representations of the input.
@@ -1059,13 +1079,19 @@ def get_likelihood(
                 or 'trilinear;
             seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
                 is instance of SPADEAutoencoderKL, segmentation must be provided.
+            quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM
+                are quantized or not.
         """
         if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
             raise ValueError(
                 f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
             )

-        latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+        with torch.no_grad():
+            autoencode = autoencoder_model.encode_stage_2_inputs
+            if isinstance(autoencoder_model, VQVAE):
+                autoencode = partial(autoencoder_model.encode_stage_2_inputs, quantized=quantized)
+            latents = autoencode(inputs) * self.scale_factor

         if cn_cond.shape[2:] != latents.shape[2:]:
             cn_cond = F.interpolate(cn_cond, latents.shape[2:])
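
Taken together, a minimal usage sketch of the changed inferer call. This is hedged: `vqvae`, `unet`, and `scheduler` stand for an already-trained VQVAE, DiffusionModelUNet, and scheduler; their construction and the toy shapes below are illustrative assumptions, not part of this diff.

import torch

from generative.inferers import LatentDiffusionInferer

# Placeholders: vqvae (generative.networks.nets.VQVAE), unet (DiffusionModelUNet)
# and scheduler (e.g. a DDPM scheduler) are assumed to exist already.
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)

images = torch.randn(2, 1, 32, 32)        # toy image batch
noise = torch.randn(2, 3, 8, 8)           # must match your VQVAE's latent channels and spatial size
timesteps = torch.randint(0, 1000, (2,)).long()

# quantized=False feeds the continuous (pre-quantization) latents to the diffusion model;
# the default quantized=True keeps the previous behaviour. For a non-VQVAE autoencoder
# (e.g. AutoencoderKL) the flag is accepted but not forwarded, so it has no effect.
prediction = inferer(
    inputs=images,
    autoencoder_model=vqvae,
    diffusion_model=unet,
    noise=noise,
    timesteps=timesteps,
    quantized=False,
)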

generative/networks/nets/controlnet.py

Lines changed: 13 additions & 10 deletions
@@ -41,6 +41,7 @@

 from generative.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding

+
 class ControlNetConditioningEmbedding(nn.Module):
     """
     Network to encode the conditioning into a latent space.
@@ -120,26 +121,28 @@ def zero_module(module):
         nn.init.zeros_(p)
     return module

-def copy_weights_to_controlnet(controlnet : nn.Module,
-                               diffusion_model: nn.Module,
-                               verbose: bool = True) -> None:
-    '''
+
+def copy_weights_to_controlnet(controlnet: nn.Module, diffusion_model: nn.Module, verbose: bool = True) -> None:
+    """
     Copy the state dict from the input diffusion model to the ControlNet, printing, if user requires it, the output
     keys that have matched and those that haven't.

     Args:
         controlnet: instance of ControlNet
         diffusion_model: instance of DiffusionModelUnet or SPADEDiffusionModelUnet
         verbose: if True, the matched and unmatched keys will be printed.
-    '''
+    """

-    output = controlnet.load_state_dict(diffusion_model.state_dict(), strict = False)
+    output = controlnet.load_state_dict(diffusion_model.state_dict(), strict=False)
     if verbose:
         dm_keys = [p[0] for p in list(diffusion_model.named_parameters()) if p[0] not in output.unexpected_keys]
-        print(f"Copied weights from {len(dm_keys)} keys of the diffusion model into the ControlNet:"
-              f"\n{'; '.join(dm_keys)}\nControlNet missing keys: {len(output.missing_keys)}:"
-              f"\n{'; '.join(output.missing_keys)}\nDiffusion model incompatible keys: {len(output.unexpected_keys)}:"
-              f"\n{'; '.join(output.unexpected_keys)}")
+        print(
+            f"Copied weights from {len(dm_keys)} keys of the diffusion model into the ControlNet:"
+            f"\n{'; '.join(dm_keys)}\nControlNet missing keys: {len(output.missing_keys)}:"
+            f"\n{'; '.join(output.missing_keys)}\nDiffusion model incompatible keys: {len(output.unexpected_keys)}:"
+            f"\n{'; '.join(output.unexpected_keys)}"
+        )
+

 class ControlNet(nn.Module):
     """

tests/test_latent_diffusion_inferer.py

Lines changed: 9 additions & 0 deletions
@@ -13,6 +13,7 @@

 import unittest

+import numpy as np
 import torch
 from parameterized import parameterized

@@ -329,6 +330,7 @@ def test_prediction_shape(
                 seg=input_seg,
                 noise=noise,
                 timesteps=timesteps,
+                quantized=np.random.choice([True, False]),
             )
         else:
             prediction = inferer(
@@ -472,6 +474,7 @@ def test_get_likelihoods(
                 scheduler=scheduler,
                 save_intermediates=True,
                 seg=input_seg,
+                quantized=np.random.choice([True, False]),
             )
         else:
             sample, intermediates = inferer.get_likelihood(
@@ -480,6 +483,7 @@ def test_get_likelihoods(
                 diffusion_model=stage_2,
                 scheduler=scheduler,
                 save_intermediates=True,
+                quantized=np.random.choice([True, False]),
             )
         self.assertEqual(len(intermediates), 10)
         self.assertEqual(intermediates[0].shape, latent_shape)
@@ -525,6 +529,7 @@ def test_resample_likelihoods(
                 save_intermediates=True,
                 resample_latent_likelihoods=True,
                 seg=input_seg,
+                quantized=np.random.choice([True, False]),
             )
         else:
             sample, intermediates = inferer.get_likelihood(
@@ -534,6 +539,7 @@ def test_resample_likelihoods(
                 scheduler=scheduler,
                 save_intermediates=True,
                 resample_latent_likelihoods=True,
+                quantized=np.random.choice([True, False]),
             )
         self.assertEqual(len(intermediates), 10)
         self.assertEqual(intermediates[0].shape[2:], input_shape[2:])
@@ -590,6 +596,7 @@ def test_prediction_shape_conditioned_concat(
                 condition=conditioning,
                 mode="concat",
                 seg=input_seg,
+                quantized=np.random.choice([True, False]),
             )
         else:
             prediction = inferer(
@@ -600,6 +607,7 @@ def test_prediction_shape_conditioned_concat(
                 timesteps=timesteps,
                 condition=conditioning,
                 mode="concat",
+                quantized=np.random.choice([True, False]),
             )
         self.assertEqual(prediction.shape, latent_shape)

@@ -713,6 +721,7 @@ def test_sample_shape_different_latents(
                 noise=noise,
                 timesteps=timesteps,
                 seg=input_seg,
+                quantized=np.random.choice([True, False]),
             )
         else:
             prediction = inferer(
