2323from monai .transforms import CenterSpatialCrop , SpatialPad
2424from monai .utils import optional_import
2525
26- from generative .networks .nets import SPADEAutoencoderKL , SPADEDiffusionModelUNet
26+ from generative .networks .nets import VQVAE , SPADEAutoencoderKL , SPADEDiffusionModelUNet
2727
2828tqdm , has_tqdm = optional_import ("tqdm" , name = "tqdm" )
2929
@@ -362,6 +362,7 @@ def __call__(
362362 condition : torch .Tensor | None = None ,
363363 mode : str = "crossattn" ,
364364 seg : torch .Tensor | None = None ,
365+ quantized : bool = True ,
365366 ) -> torch .Tensor :
366367 """
367368 Implements the forward pass for a supervised training iteration.
@@ -375,9 +376,14 @@ def __call__(
375376 condition: conditioning for network input.
376377 mode: Conditioning mode for the network.
377378 seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
379+ quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM
380+ are quantized or not.
378381 """
379382 with torch .no_grad ():
380- latent = autoencoder_model .encode_stage_2_inputs (inputs ) * self .scale_factor
383+ autoencode = autoencoder_model .encode_stage_2_inputs
384+ if isinstance (autoencoder_model , VQVAE ):
385+ autoencode = partial (autoencoder_model .encode_stage_2_inputs , quantized = quantized )
386+ latent = autoencode (inputs ) * self .scale_factor
381387
382388 if self .ldm_latent_shape is not None :
383389 latent = torch .stack ([self .ldm_resizer (i ) for i in decollate_batch (latent )], 0 )
@@ -496,6 +502,7 @@ def get_likelihood(
496502 resample_latent_likelihoods : bool = False ,
497503 resample_interpolation_mode : str = "nearest" ,
498504 seg : torch .Tensor | None = None ,
505+ quantized : bool = True ,
499506 ) -> torch .Tensor | tuple [torch .Tensor , list [torch .Tensor ]]:
500507 """
501508 Computes the log-likelihoods of the latent representations of the input.
@@ -517,12 +524,18 @@ def get_likelihood(
517524 or 'trilinear'.
518525 seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
519526 is instance of SPADEAutoencoderKL, segmentation must be provided.
527+ quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM
528+ are quantized or not.
520529 """
521530 if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest" , "bilinear" , "trilinear" ):
522531 raise ValueError (
523532 f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got { resample_interpolation_mode } "
524533 )
525- latents = autoencoder_model .encode_stage_2_inputs (inputs ) * self .scale_factor
534+
535+ autoencode = autoencoder_model .encode_stage_2_inputs
536+ if isinstance (autoencoder_model , VQVAE ):
537+ autoencode = partial (autoencoder_model .encode_stage_2_inputs , quantized = quantized )
538+ latents = autoencode (inputs ) * self .scale_factor
526539
527540 if self .ldm_latent_shape is not None :
528541 latents = torch .stack ([self .ldm_resizer (i ) for i in decollate_batch (latents )], 0 )
@@ -882,6 +895,7 @@ def __call__(
882895 condition : torch .Tensor | None = None ,
883896 mode : str = "crossattn" ,
884897 seg : torch .Tensor | None = None ,
898+ quantized : bool = True ,
885899 ) -> torch .Tensor :
886900 """
887901 Implements the forward pass for a supervised training iteration.
@@ -897,9 +911,14 @@ def __call__(
897911 condition: conditioning for network input.
898912 mode: Conditioning mode for the network.
899913 seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
914+ quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM
915+ are quantized or not.
900916 """
901917 with torch .no_grad ():
902- latent = autoencoder_model .encode_stage_2_inputs (inputs ) * self .scale_factor
918+ autoencode = autoencoder_model .encode_stage_2_inputs
919+ if isinstance (autoencoder_model , VQVAE ):
920+ autoencode = partial (autoencoder_model .encode_stage_2_inputs , quantized = quantized )
921+ latent = autoencode (inputs ) * self .scale_factor
903922
904923 if self .ldm_latent_shape is not None :
905924 latent = torch .stack ([self .ldm_resizer (i ) for i in decollate_batch (latent )], 0 )
@@ -1036,6 +1055,7 @@ def get_likelihood(
10361055 resample_latent_likelihoods : bool = False ,
10371056 resample_interpolation_mode : str = "nearest" ,
10381057 seg : torch .Tensor | None = None ,
1058+ quantized : bool = True ,
10391059 ) -> torch .Tensor | tuple [torch .Tensor , list [torch .Tensor ]]:
10401060 """
10411061 Computes the log-likelihoods of the latent representations of the input.
@@ -1059,13 +1079,19 @@ def get_likelihood(
10591079 or 'trilinear'.
10601080 seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
10611081 is instance of SPADEAutoencoderKL, segmentation must be provided.
1082+ quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM
1083+ are quantized or not.
10621084 """
10631085 if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest" , "bilinear" , "trilinear" ):
10641086 raise ValueError (
10651087 f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got { resample_interpolation_mode } "
10661088 )
10671089
1068- latents = autoencoder_model .encode_stage_2_inputs (inputs ) * self .scale_factor
1090+ with torch .no_grad ():
1091+ autoencode = autoencoder_model .encode_stage_2_inputs
1092+ if isinstance (autoencoder_model , VQVAE ):
1093+ autoencode = partial (autoencoder_model .encode_stage_2_inputs , quantized = quantized )
1094+ latents = autoencode (inputs ) * self .scale_factor
10691095
10701096 if cn_cond .shape [2 :] != latents .shape [2 :]:
10711097 cn_cond = F .interpolate (cn_cond , latents .shape [2 :])
0 commit comments