import numpy as np
import torch
- import torch.nn as nn
+ from monai.utils import StrEnum

+ from .scheduler import Scheduler

- class DDIMScheduler(nn.Module):
+
+ class DDIMPredictionType(StrEnum):
+     """
+     Set of valid prediction type names for the DDIM scheduler's `prediction_type` argument.
+
+     epsilon: predicting the noise of the diffusion process
+     sample: directly predicting the noisy sample
+     v_prediction: velocity prediction, see section 2.4 of https://imagen.research.google/video/paper.pdf
+     """
+
+     EPSILON = "epsilon"
+     SAMPLE = "sample"
+     V_PREDICTION = "v_prediction"
+
+
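Because the new enum is string-valued, call sites that pass plain strings keep working. A minimal sketch of why, using a stand-in for `monai.utils.StrEnum` (assumed here to be a `str`-valued `Enum`):

```python
from enum import Enum

class StrEnum(str, Enum):  # stand-in for monai.utils.StrEnum (assumed str-valued)
    def __str__(self) -> str:
        return self.value

class DDIMPredictionType(StrEnum):
    EPSILON = "epsilon"
    SAMPLE = "sample"
    V_PREDICTION = "v_prediction"

# str subclassing keeps old string-based configs valid:
assert DDIMPredictionType.EPSILON == "epsilon"
# and this mirrors the membership check added to __init__ below:
assert "sample" in DDIMPredictionType.__members__.values()
```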
+ class DDIMScheduler(Scheduler):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising Diffusion
    Implicit Models" https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps: number of diffusion steps used to train the model.
-         beta_start: the starting `beta` value of inference.
-         beta_end: the final `beta` value.
-         beta_schedule: {``"linear"``, ``"scaled_linear"``}
-             the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
+         schedule: member of NoiseSchedules, the name of a noise schedule function in the component store.
        clip_sample: option to clip the predicted sample between -1 and 1 for numerical stability.
        set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one.
            For the final step there is no previous alpha. When this option is `True` the previous alpha product is
            fixed to `1`, otherwise it uses the value of alpha at step 0.
        steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
-         prediction_type: {``"epsilon"``, ``"sample"``, ``"v_prediction"``}
-             prediction type of the scheduler function: one of `epsilon` (predicting the noise of the diffusion
-             process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
-             https://imagen.research.google/video/paper.pdf)
+         prediction_type: member of DDIMPredictionType.
+         schedule_args: arguments to pass to the schedule function.
+
    """

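For illustration, a hedged sketch of how a call site changes with this refactor; forwarding `beta_start`/`beta_end` through `schedule_args` is an assumption based on the removed parameters, not a confirmed `NoiseSchedules` signature:

```python
# before: DDIMScheduler(num_train_timesteps=1000, beta_start=1e-4, beta_end=2e-2, beta_schedule="linear")
# after: the schedule is looked up by name, extra kwargs go to the schedule function
scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    schedule="linear_beta",
    beta_start=1e-4,  # forwarded via **schedule_args (assumed to be accepted)
    beta_end=2e-2,
    prediction_type=DDIMPredictionType.EPSILON,
)
```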
    def __init__(
        self,
        num_train_timesteps: int = 1000,
-         beta_start: float = 1e-4,
-         beta_end: float = 2e-2,
-         beta_schedule: str = "linear",
+         schedule: str = "linear_beta",
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
-         prediction_type: str = "epsilon",
+         prediction_type: str = DDIMPredictionType.EPSILON,
+         **schedule_args,
    ) -> None:
-         super().__init__()
-         self.beta_schedule = beta_schedule
-         if beta_schedule == "linear":
-             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
-         elif beta_schedule == "scaled_linear":
-             # this schedule is very specific to the latent diffusion model.
-             self.betas = (
-                 torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
-             )
-         else:
-             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+         super().__init__(num_train_timesteps, schedule, **schedule_args)

-         if prediction_type.lower() not in ["epsilon", "sample", "v_prediction"]:
-             raise ValueError(
-                 f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`"
-             )
+         if prediction_type not in DDIMPredictionType.__members__.values():
+             raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType")

        self.prediction_type = prediction_type
-         self.num_train_timesteps = num_train_timesteps
-         self.alphas = 1.0 - self.betas
-         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
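The removed branching is replaced by a schedule looked up by name in the base class. A minimal sketch of what such a registered schedule function might compute, mirroring the deleted `"linear"` branch (the function name and registration mechanism are assumptions; only the arithmetic comes from the removed code):

```python
import torch

def linear_beta_schedule(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2) -> torch.Tensor:
    # identical to the removed "linear" branch above
    return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)

betas = linear_beta_schedule(1000)
alphas = 1.0 - betas                           # previously computed in this __init__
alphas_cumprod = torch.cumprod(alphas, dim=0)  # now assumed to live in the Scheduler base class
```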
@@ -103,13 +99,13 @@ def __init__(
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

-         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64))
+         self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64))

        self.clip_sample = clip_sample
        self.steps_offset = steps_offset

        # default the number of inference timesteps to the number of train steps
-         self.set_timesteps(num_train_timesteps)
+         self.set_timesteps(self.num_train_timesteps)

    def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
        """
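The hunk cuts off before the body of `set_timesteps`; for orientation, a hedged sketch of the evenly spaced timestep selection that diffusers-style schedulers use (an assumption about the unshown body, not code from this commit):

```python
import numpy as np

num_train_timesteps, num_inference_steps, steps_offset = 1000, 50, 1
step_ratio = num_train_timesteps // num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio)[::-1].astype(np.int64) + steps_offset
# -> [981, 961, ..., 21, 1]; combined with set_alpha_to_one=False, the final step's
#    "previous" alpha product is alphas_cumprod[0] rather than a fixed 1.0
```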
@@ -190,13 +186,13 @@ def step(

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-         if self.prediction_type == "epsilon":
-             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+         if self.prediction_type == DDIMPredictionType.EPSILON:
+             pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5)
            pred_epsilon = model_output
-         elif self.prediction_type == "sample":
+         elif self.prediction_type == DDIMPredictionType.SAMPLE:
            pred_original_sample = model_output
-             pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
-         elif self.prediction_type == "v_prediction":
+             pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5)
+         elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample

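A quick numeric check (not part of the commit) that the three branches above are mutually consistent, using the identities x_t = √ᾱ_t·x_0 + √(1−ᾱ_t)·ε and v = √ᾱ_t·ε − √(1−ᾱ_t)·x_0:

```python
import torch

alpha_prod_t = torch.tensor(0.7)
beta_prod_t = 1 - alpha_prod_t
x0, eps = torch.randn(4), torch.randn(4)

sample = alpha_prod_t**0.5 * x0 + beta_prod_t**0.5 * eps
v = alpha_prod_t**0.5 * eps - beta_prod_t**0.5 * x0

# the epsilon branch and the v_prediction branch recover the same x_0
pred_x0_eps = (sample - beta_prod_t**0.5 * eps) / alpha_prod_t**0.5
pred_x0_v = alpha_prod_t**0.5 * sample - beta_prod_t**0.5 * v
assert torch.allclose(pred_x0_eps, x0, atol=1e-5)
assert torch.allclose(pred_x0_v, x0, atol=1e-5)
```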
@@ -207,19 +203,19 @@ def step(
        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
-         std_dev_t = eta * variance ** (0.5)
+         std_dev_t = eta * variance**0.5

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
+         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon

        # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-         pred_prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+         pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction

        if eta > 0:
            # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
            device = model_output.device if torch.is_tensor(model_output) else "cpu"
            noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
-             variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
+             variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise

            pred_prev_sample = pred_prev_sample + variance

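Steps 5–7 together implement formula (12) of the DDIM paper: with eta=0 the update is fully deterministic, while eta=1 recovers DDPM-like stochasticity. A hedged usage sketch of the sampling loop (the `model` call and the tuple return of `step` are assumptions about the surrounding API):

```python
import torch

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=50)

sample = torch.randn(1, 1, 64, 64)  # start from pure noise
for t in scheduler.timesteps:
    with torch.no_grad():
        model_output = model(sample, t)  # hypothetical trained network
    sample, _ = scheduler.step(model_output, t, sample)  # assuming eta defaults to 0 (deterministic DDIM)
```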
@@ -263,13 +259,13 @@ def reversed_step(
        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf

-         if self.prediction_type == "epsilon":
+         if self.prediction_type == DDIMPredictionType.EPSILON:
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
-         elif self.prediction_type == "sample":
+         elif self.prediction_type == DDIMPredictionType.SAMPLE:
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
-         elif self.prediction_type == "v_prediction":
+         elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample

@@ -284,50 +280,3 @@ def reversed_step(
        pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        return pred_post_sample, pred_original_sample
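`reversed_step` applies the same update in the opposite direction, stepping toward the next (noisier) timestep, which supports DDIM inversion, i.e. encoding an existing image into the model's noise space. A hedged sketch reusing the loop above (the same API assumptions apply):

```python
sample = image  # hypothetical tensor to invert, same shape the model expects
for t in reversed(scheduler.timesteps):
    with torch.no_grad():
        model_output = model(sample, t)  # hypothetical trained network
    sample, _ = scheduler.reversed_step(model_output, t, sample)
# `sample` now approximates the latent noise that would regenerate `image`
```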
-
-     def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-         """
-         Add noise to the original samples.
-
-         Args:
-             original_samples: original samples
-             noise: noise to add to samples
-             timesteps: timesteps tensor indicating the timestep to be computed for each sample.
-
-         Returns:
-             noisy_samples: sample with added noise
-         """
-         # Make sure alphas_cumprod and timesteps have the same device and dtype as original_samples
-         self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
-         timesteps = timesteps.to(original_samples.device)
-
-         sqrt_alpha_cumprod = self.alphas_cumprod[timesteps] ** 0.5
-         sqrt_alpha_cumprod = sqrt_alpha_cumprod.flatten()
-         while len(sqrt_alpha_cumprod.shape) < len(original_samples.shape):
-             sqrt_alpha_cumprod = sqrt_alpha_cumprod.unsqueeze(-1)
-
-         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
-         noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise
-         return noisy_samples
-
-     def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-         # Make sure alphas_cumprod and timesteps have the same device and dtype as sample
-         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
-         timesteps = timesteps.to(sample.device)
-
-         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-         while len(sqrt_alpha_prod.shape) < len(sample.shape):
-             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
-
-         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-         while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
-             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
-         velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
-         return velocity
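`add_noise` and `get_velocity` are deleted here, presumably because the new `Scheduler` base class provides them (an assumption; the base class is not shown in this diff). Their role in training is unchanged, e.g. for v-prediction, continuing the sketches above:

```python
import torch

images = torch.randn(8, 1, 64, 64)  # stand-in training batch
noise = torch.randn_like(images)
t = torch.randint(0, scheduler.num_train_timesteps, (8,))

noisy = scheduler.add_noise(images, noise, t)      # √ᾱ_t·x_0 + √(1−ᾱ_t)·ε
target = scheduler.get_velocity(images, noise, t)  # √ᾱ_t·ε − √(1−ᾱ_t)·x_0
loss = torch.nn.functional.mse_loss(model(noisy, t), target)  # hypothetical model
```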