
Commit 67f931b

Improve docstrings and type hints in scheduling_ddim.py
- Add complete type hints for all function parameters
- Enhance docstrings to follow project conventions
- Add missing parameter descriptions

Fixes #9567
1 parent 0fd58c7 commit 67f931b

1 file changed (+75 additions, −15 deletions)


src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 75 additions & 15 deletions
@@ -49,18 +49,17 @@ class DDIMSchedulerOutput(BaseOutput):
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
-    num_diffusion_timesteps,
-    max_beta=0.999,
-    alpha_transform_type="cosine",
-):
+    num_diffusion_timesteps: int,
+    max_beta: float = 0.999,
+    alpha_transform_type: str = "cosine",
+) -> torch.Tensor:
     """
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
     (1-beta) over time from t = [0,1].
 
     Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
     to that part of the diffusion process.
 
-
     Args:
         num_diffusion_timesteps (`int`): the number of betas to produce.
         max_beta (`float`): the maximum beta to use; use values lower than 1 to
@@ -69,16 +68,16 @@ def betas_for_alpha_bar(
             Choose from `cosine` or `exp`
 
     Returns:
-        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
     """
     if alpha_transform_type == "cosine":
 
-        def alpha_bar_fn(t):
+        def alpha_bar_fn(t: float) -> float:
             return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
 
     elif alpha_transform_type == "exp":
 
-        def alpha_bar_fn(t):
+        def alpha_bar_fn(t: float) -> float:
             return math.exp(t * -12.0)
 
     else:
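
Not part of the diff, but for context: `betas_for_alpha_bar` builds the schedule by taking successive ratios of `alpha_bar_fn`. A minimal standalone sketch of that construction, assuming the cosine transform (the helper name `cosine_betas` is illustrative, not the library function):

import math
import torch

def cosine_betas(num_diffusion_timesteps: int, max_beta: float = 0.999) -> torch.Tensor:
    # alpha_bar(t): cumulative product of (1 - beta) as a function of t in [0, 1]
    def alpha_bar_fn(t: float) -> float:
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t2) / alpha_bar(t1), capped at max_beta
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)

print(cosine_betas(10))  # ten betas, increasing toward max_beta
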
@@ -92,11 +91,10 @@ def alpha_bar_fn(t):
     return torch.tensor(betas, dtype=torch.float32)
 
 
-def rescale_zero_terminal_snr(betas):
+def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
     """
     Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
 
-
     Args:
         betas (`torch.Tensor`):
             the betas that the scheduler is being initialized with.
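
For context (not part of this commit): Algorithm 1 of the linked paper shifts and rescales sqrt(alpha_bar) so its terminal value is exactly zero. A hedged sketch of that idea; the in-tree `rescale_zero_terminal_snr` may differ in details:

import torch

def rescale_to_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    # Work in terms of sqrt(alpha_bar), the cumulative signal scale
    alphas_bar_sqrt = torch.cumprod(1.0 - betas, dim=0).sqrt()
    first, last = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
    # Shift so the terminal value is zero, then rescale so the first value is preserved
    alphas_bar_sqrt = (alphas_bar_sqrt - last) * first / (first - last)
    # Convert the rescaled cumulative products back to per-step betas
    alphas_bar = alphas_bar_sqrt**2
    alphas = torch.cat([alphas_bar[:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1.0 - alphas

betas = torch.linspace(1e-4, 2e-2, 1000)  # toy linear schedule
rescaled = rescale_to_zero_terminal_snr(betas)
print(torch.cumprod(1 - rescaled, dim=0)[-1])  # ~0: terminal SNR is zero
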
@@ -250,7 +248,25 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
         """
         return sample
 
-    def _get_variance(self, timestep, prev_timestep):
+    def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
+        """
+        Computes the variance of the noise added at a given diffusion step.
+
+        For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM
+        literature:
+            var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+        where alpha_prod and beta_prod are cumulative products of alphas and betas, respectively.
+
+        Args:
+            timestep (`int`):
+                The current timestep in the diffusion process.
+            prev_timestep (`int`):
+                The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`.
+
+        Returns:
+            `torch.Tensor`:
+                The variance for the current timestep.
+        """
         alpha_prod_t = self.alphas_cumprod[timestep]
         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t
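
A worked example (not from the commit) of the variance formula quoted in the new docstring, using a toy linear beta schedule; the numbers are purely illustrative:

import torch

betas = torch.linspace(1e-4, 2e-2, 1000)          # toy schedule, not the scheduler's config
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

timestep, prev_timestep = 500, 480
alpha_prod_t = alphas_cumprod[timestep]
alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else torch.tensor(1.0)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

# var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
print(variance)
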
@@ -263,13 +279,21 @@ def _get_variance(self, timestep, prev_timestep):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
-        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
         s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
         pixels from saturation at each step. We find that dynamic thresholding results in significantly better
-        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+        photorealism as well as better image-text alignment, especially when using very large guidance weights.
+
+        See https://huggingface.co/papers/2205.11487
+
+        Args:
+            sample (`torch.Tensor`):
+                The sample to threshold.
 
-        https://huggingface.co/papers/2205.11487
+        Returns:
+            `torch.Tensor`:
+                The thresholded sample.
         """
         dtype = sample.dtype
         batch_size, channels, *remaining_dims = sample.shape
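
For illustration only, a self-contained sketch of the dynamic-thresholding idea the docstring describes; `dynamic_threshold`, `ratio`, and `max_value` are hypothetical names here, not the scheduler's config fields:

import torch

def dynamic_threshold(sample: torch.Tensor, ratio: float = 0.995, max_value: float = 1.0) -> torch.Tensor:
    # Hypothetical helper, not the scheduler's method.
    # Flatten everything but the batch dimension and take a per-sample percentile of |x|
    batch_size = sample.shape[0]
    flat = sample.reshape(batch_size, -1).abs()
    s = torch.quantile(flat, ratio, dim=1)
    # Never threshold below max_value, so samples already in range are left untouched
    s = torch.clamp(s, min=max_value).view(batch_size, *([1] * (sample.ndim - 1)))
    # Clamp to [-s, s] and renormalize back into [-max_value, max_value]
    return sample.clamp(-s, s) / s

x = torch.randn(2, 3, 8, 8) * 3          # a deliberately saturated sample
print(dynamic_threshold(x).abs().max())  # <= 1.0 after thresholding
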
@@ -294,13 +318,18 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
 
         return sample
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
         Args:
             num_inference_steps (`int`):
                 The number of diffusion steps used when generating samples with a pre-trained model.
+            device (`Union[str, torch.device]`, *optional*):
+                The device to use for the timesteps.
+
+        Raises:
+            ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`.
         """
 
         if num_inference_steps > self.config.num_train_timesteps:
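
Usage-wise, the documented `Raises` behaviour might look like the sketch below, assuming a default-configured `DDIMScheduler` (1000 training timesteps); this snippet is not part of the commit:

from diffusers import DDIMScheduler

scheduler = DDIMScheduler()        # num_train_timesteps defaults to 1000
scheduler.set_timesteps(50)        # fine: 50 <= 1000
print(scheduler.timesteps[:5])     # descending inference timesteps

try:
    scheduler.set_timesteps(5000)  # more steps than the model was trained with
except ValueError as err:
    print(err)
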
@@ -477,6 +506,21 @@ def add_noise(
         noise: torch.Tensor,
         timesteps: torch.IntTensor,
     ) -> torch.Tensor:
+        """
+        Adds noise to the original samples.
+
+        Args:
+            original_samples (`torch.Tensor`):
+                The original samples to add noise to.
+            noise (`torch.Tensor`):
+                The noise to add to the original samples.
+            timesteps (`torch.IntTensor`):
+                The timesteps to add noise to.
+
+        Returns:
+            `torch.Tensor`:
+                The noisy samples.
+        """
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         # for the subsequent add_noise calls
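
A short worked example (not part of the commit), assuming `add_noise` follows the standard DDPM/DDIM forward-noising relation, shown here with a toy schedule:

import torch

betas = torch.linspace(1e-4, 2e-2, 1000)          # toy schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

original = torch.randn(1, 3, 8, 8)
noise = torch.randn_like(original)
t = torch.tensor([250])

sqrt_alpha_prod = alphas_cumprod[t].sqrt().view(-1, 1, 1, 1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1)
# noisy = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
noisy = sqrt_alpha_prod * original + sqrt_one_minus_alpha_prod * noise
print(noisy.shape)
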
@@ -499,6 +543,22 @@ def add_noise(
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
     def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+        """
+        Computes the velocity of the sample. The velocity is defined as the difference between the original sample and
+        the noisy sample. See https://huggingface.co/papers/2010.02502
+
+        Args:
+            sample (`torch.Tensor`):
+                The sample to compute the velocity of.
+            noise (`torch.Tensor`):
+                The noise to compute the velocity of.
+            timesteps (`torch.IntTensor`):
+                The timesteps to compute the velocity of.
+
+        Returns:
+            `torch.Tensor`:
+                The velocity of the sample.
+        """
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
         alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
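
For reference (the method body is not shown in this hunk), the conventional v-prediction target is v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0; a sketch under the assumption that `get_velocity` follows this definition, with a toy schedule:

import torch

# Assumes the standard v-prediction definition; illustrative values only.
betas = torch.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

sample = torch.randn(1, 3, 8, 8)
noise = torch.randn_like(sample)
t = torch.tensor([250])

sqrt_alpha_prod = alphas_cumprod[t].sqrt().view(-1, 1, 1, 1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1)
# v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
print(velocity.shape)
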
