@@ -49,18 +49,17 @@ class DDIMSchedulerOutput(BaseOutput):
4949
5050# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
5151def betas_for_alpha_bar (
52- num_diffusion_timesteps ,
53- max_beta = 0.999 ,
54- alpha_transform_type = "cosine" ,
55- ):
52+ num_diffusion_timesteps : int ,
53+ max_beta : float = 0.999 ,
54+ alpha_transform_type : str = "cosine" ,
55+ ) -> torch . Tensor :
5656 """
5757 Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
5858 (1-beta) over time from t = [0,1].
5959
6060 Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
6161 to that part of the diffusion process.
6262
63-
6463 Args:
6564 num_diffusion_timesteps (`int`): the number of betas to produce.
6665 max_beta (`float`): the maximum beta to use; use values lower than 1 to
@@ -69,16 +68,16 @@ def betas_for_alpha_bar(
6968 Choose from `cosine` or `exp`
7069
7170 Returns:
72- betas (`np.ndarray `): the betas used by the scheduler to step the model outputs
71+ betas (`torch.Tensor `): the betas used by the scheduler to step the model outputs
7372 """
7473 if alpha_transform_type == "cosine" :
7574
76- def alpha_bar_fn (t ) :
75+ def alpha_bar_fn (t : float ) -> float :
7776 return math .cos ((t + 0.008 ) / 1.008 * math .pi / 2 ) ** 2
7877
7978 elif alpha_transform_type == "exp" :
8079
81- def alpha_bar_fn (t ) :
80+ def alpha_bar_fn (t : float ) -> float :
8281 return math .exp (t * - 12.0 )
8382
8483 else :
@@ -92,11 +91,10 @@ def alpha_bar_fn(t):
9291 return torch .tensor (betas , dtype = torch .float32 )
9392
9493
95- def rescale_zero_terminal_snr (betas ) :
94+ def rescale_zero_terminal_snr (betas : torch . Tensor ) -> torch . Tensor :
9695 """
9796 Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
9897
99-
10098 Args:
10199 betas (`torch.Tensor`):
102100 the betas that the scheduler is being initialized with.
@@ -250,7 +248,25 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
250248 """
251249 return sample
252250
253- def _get_variance (self , timestep , prev_timestep ):
251+ def _get_variance (self , timestep : int , prev_timestep : int ) -> torch .Tensor :
252+ """
253+ Computes the variance of the noise added at a given diffusion step.
254+
255+ For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM
256+ literature:
257+ var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
258+ where alpha_prod and beta_prod are cumulative products of alphas and betas, respectively.
259+
260+ Args:
261+ timestep (`int`):
262+ The current timestep in the diffusion process.
263+ prev_timestep (`int`):
264+ The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`.
265+
266+ Returns:
267+ `torch.Tensor`:
268+ The variance for the current timestep.
269+ """
254270 alpha_prod_t = self .alphas_cumprod [timestep ]
255271 alpha_prod_t_prev = self .alphas_cumprod [prev_timestep ] if prev_timestep >= 0 else self .final_alpha_cumprod
256272 beta_prod_t = 1 - alpha_prod_t
@@ -263,13 +279,21 @@ def _get_variance(self, timestep, prev_timestep):
263279 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
264280 def _threshold_sample (self , sample : torch .Tensor ) -> torch .Tensor :
265281 """
266- " Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
282+ Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
267283 prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
268284 s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
269285 pixels from saturation at each step. We find that dynamic thresholding results in significantly better
270- photorealism as well as better image-text alignment, especially when using very large guidance weights."
286+ photorealism as well as better image-text alignment, especially when using very large guidance weights.
287+
288+ See https://huggingface.co/papers/2205.11487
289+
290+ Args:
291+ sample (`torch.Tensor`):
292+ The sample to threshold.
271293
272- https://huggingface.co/papers/2205.11487
294+ Returns:
295+ `torch.Tensor`:
296+ The thresholded sample.
273297 """
274298 dtype = sample .dtype
275299 batch_size , channels , * remaining_dims = sample .shape
@@ -294,13 +318,18 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
294318
295319 return sample
296320
297- def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None ):
321+ def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None ) -> None :
298322 """
299323 Sets the discrete timesteps used for the diffusion chain (to be run before inference).
300324
301325 Args:
302326 num_inference_steps (`int`):
303327 The number of diffusion steps used when generating samples with a pre-trained model.
328+ device (`Union[str, torch.device]`, *optional*):
329+ The device to use for the timesteps.
330+
331+ Raises:
332+ ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`.
304333 """
305334
306335 if num_inference_steps > self .config .num_train_timesteps :
@@ -477,6 +506,21 @@ def add_noise(
477506 noise : torch .Tensor ,
478507 timesteps : torch .IntTensor ,
479508 ) -> torch .Tensor :
509+ """
510+ Adds noise to the original samples.
511+
512+ Args:
513+ original_samples (`torch.Tensor`):
514+ The original samples to add noise to.
515+ noise (`torch.Tensor`):
516+ The noise to add to the original samples.
517+ timesteps (`torch.IntTensor`):
518+ The timesteps to add noise to.
519+
520+ Returns:
521+ `torch.Tensor`:
522+ The noisy samples.
523+ """
480524 # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
481525 # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
482526 # for the subsequent add_noise calls
@@ -499,6 +543,22 @@ def add_noise(
499543
500544 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
501545 def get_velocity (self , sample : torch .Tensor , noise : torch .Tensor , timesteps : torch .IntTensor ) -> torch .Tensor :
546+ """
547+ Computes the velocity of the sample. The velocity is defined as the difference between the original sample and
548+ the noisy sample. See https://huggingface.co/papers/2010.02502
549+
550+ Args:
551+ sample (`torch.Tensor`):
552+ The sample to compute the velocity of.
553+ noise (`torch.Tensor`):
554+ The noise to compute the velocity of.
555+ timesteps (`torch.IntTensor`):
556+ The timesteps to compute the velocity of.
557+
558+ Returns:
559+ `torch.Tensor`:
560+ The velocity of the sample.
561+ """
502562 # Make sure alphas_cumprod and timestep have same device and dtype as sample
503563 self .alphas_cumprod = self .alphas_cumprod .to (device = sample .device )
504564 alphas_cumprod = self .alphas_cumprod .to (dtype = sample .dtype )
0 commit comments