Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 5a8f75f

Browse files
Refactoring schedulers (#285)
* Adding components and refactoring of schedulers (DDPM only) Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk> * Fix Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk> * Adding tests I forgot to add Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk> * Updates from comments * Updates to other schedulers Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk> * Update Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk> * Tutorials updates Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk> * Update Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk> * Update Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk> * Fixes Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk> * Updates from comments Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk> * Update generative/networks/schedulers/ddpm.py Co-authored-by: Mark Graham <markgraham539@gmail.com> Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> * Autofixin' Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk> * Fixes Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk> --------- Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk> Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: Mark Graham <markgraham539@gmail.com>
1 parent 798a2ef commit 5a8f75f

26 files changed

+649
-304
lines changed

generative/networks/schedulers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
from .ddim import DDIMScheduler
1515
from .ddpm import DDPMScheduler
1616
from .pndm import PNDMScheduler
17+
from .scheduler import NoiseSchedules, Scheduler

generative/networks/schedulers/ddim.py

Lines changed: 42 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -33,66 +33,62 @@
3333

3434
import numpy as np
3535
import torch
36-
import torch.nn as nn
36+
from monai.utils import StrEnum
3737

38+
from .scheduler import Scheduler
3839

39-
class DDIMScheduler(nn.Module):
40+
41+
class DDIMPredictionType(StrEnum):
42+
"""
43+
Set of valid prediction type names for the DDIM scheduler's `prediction_type` argument.
44+
45+
epsilon: predicting the noise of the diffusion process
46+
sample: directly predicting the original (denoised) sample, i.e. the model output is used as the predicted x_0
47+
v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
48+
"""
49+
50+
EPSILON = "epsilon"
51+
SAMPLE = "sample"
52+
V_PREDICTION = "v_prediction"
53+
54+
55+
class DDIMScheduler(Scheduler):
4056
"""
4157
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
4258
diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising Diffusion
4359
Implicit Models" https://arxiv.org/abs/2010.02502
4460
4561
Args:
4662
num_train_timesteps: number of diffusion steps used to train the model.
47-
beta_start: the starting `beta` value of inference.
48-
beta_end: the final `beta` value.
49-
beta_schedule: {``"linear"``, ``"scaled_linear"``}
50-
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
63+
schedule: member of NoiseSchedules, name of noise schedule function in component store
5164
clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.
5265
set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one.
5366
For the final step there is no previous alpha. When this option is `True` the previous alpha product is
5467
fixed to `1`, otherwise it uses the value of alpha at step 0.
5568
steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
5669
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
5770
stable diffusion.
58-
prediction_type: {``"epsilon"``, ``"sample"``, ``"v_prediction"``}
59-
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
60-
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
61-
https://imagen.research.google/video/paper.pdf)
71+
prediction_type: member of DDIMPredictionType
72+
schedule_args: arguments to pass to the schedule function
73+
6274
"""
6375

6476
def __init__(
6577
self,
6678
num_train_timesteps: int = 1000,
67-
beta_start: float = 1e-4,
68-
beta_end: float = 2e-2,
69-
beta_schedule: str = "linear",
79+
schedule: str = "linear_beta",
7080
clip_sample: bool = True,
7181
set_alpha_to_one: bool = True,
7282
steps_offset: int = 0,
73-
prediction_type: str = "epsilon",
83+
prediction_type: str = DDIMPredictionType.EPSILON,
84+
**schedule_args,
7485
) -> None:
75-
super().__init__()
76-
self.beta_schedule = beta_schedule
77-
if beta_schedule == "linear":
78-
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
79-
elif beta_schedule == "scaled_linear":
80-
# this schedule is very specific to the latent diffusion model.
81-
self.betas = (
82-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
83-
)
84-
else:
85-
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
86+
super().__init__(num_train_timesteps, schedule, **schedule_args)
8687

87-
if prediction_type.lower() not in ["epsilon", "sample", "v_prediction"]:
88-
raise ValueError(
89-
f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`"
90-
)
88+
if prediction_type not in DDIMPredictionType.__members__.values():
89+
raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType")
9190

9291
self.prediction_type = prediction_type
93-
self.num_train_timesteps = num_train_timesteps
94-
self.alphas = 1.0 - self.betas
95-
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
9692

9793
# At every step in ddim, we are looking into the previous alphas_cumprod
9894
# For the final step, there is no previous alphas_cumprod because we are already at 0
@@ -103,13 +99,13 @@ def __init__(
10399
# standard deviation of the initial noise distribution
104100
self.init_noise_sigma = 1.0
105101

106-
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64))
102+
self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64))
107103

108104
self.clip_sample = clip_sample
109105
self.steps_offset = steps_offset
110106

111107
# default the number of inference timesteps to the number of train steps
112-
self.set_timesteps(num_train_timesteps)
108+
self.set_timesteps(self.num_train_timesteps)
113109

114110
def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
115111
"""
@@ -190,13 +186,13 @@ def step(
190186

191187
# 3. compute predicted original sample from predicted noise also called
192188
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
193-
if self.prediction_type == "epsilon":
194-
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
189+
if self.prediction_type == DDIMPredictionType.EPSILON:
190+
pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5)
195191
pred_epsilon = model_output
196-
elif self.prediction_type == "sample":
192+
elif self.prediction_type == DDIMPredictionType.SAMPLE:
197193
pred_original_sample = model_output
198-
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
199-
elif self.prediction_type == "v_prediction":
194+
pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5)
195+
elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
200196
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
201197
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
202198

@@ -207,19 +203,19 @@ def step(
207203
# 5. compute variance: "sigma_t(η)" -> see formula (16)
208204
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
209205
variance = self._get_variance(timestep, prev_timestep)
210-
std_dev_t = eta * variance ** (0.5)
206+
std_dev_t = eta * variance**0.5
211207

212208
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
213-
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
209+
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon
214210

215211
# 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
216-
pred_prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
212+
pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction
217213

218214
if eta > 0:
219215
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
220216
device = model_output.device if torch.is_tensor(model_output) else "cpu"
221217
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
222-
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
218+
variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise
223219

224220
pred_prev_sample = pred_prev_sample + variance
225221

@@ -263,13 +259,13 @@ def reversed_step(
263259
# 3. compute predicted original sample from predicted noise also called
264260
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
265261

266-
if self.prediction_type == "epsilon":
262+
if self.prediction_type == DDIMPredictionType.EPSILON:
267263
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
268264
pred_epsilon = model_output
269-
elif self.prediction_type == "sample":
265+
elif self.prediction_type == DDIMPredictionType.SAMPLE:
270266
pred_original_sample = model_output
271267
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
272-
elif self.prediction_type == "v_prediction":
268+
elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
273269
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
274270
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
275271

@@ -284,50 +280,3 @@ def reversed_step(
284280
pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
285281

286282
return pred_post_sample, pred_original_sample
287-
288-
def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
289-
"""
290-
Add noise to the original samples.
291-
292-
Args:
293-
original_samples: original samples
294-
noise: noise to add to samples
295-
timesteps: timesteps tensor indicating the timestep to be computed for each sample.
296-
297-
Returns:
298-
noisy_samples: sample with added noise
299-
"""
300-
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
301-
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
302-
timesteps = timesteps.to(original_samples.device)
303-
304-
sqrt_alpha_cumprod = self.alphas_cumprod[timesteps] ** 0.5
305-
sqrt_alpha_cumprod = sqrt_alpha_cumprod.flatten()
306-
while len(sqrt_alpha_cumprod.shape) < len(original_samples.shape):
307-
sqrt_alpha_cumprod = sqrt_alpha_cumprod.unsqueeze(-1)
308-
309-
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
310-
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
311-
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
312-
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
313-
314-
noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise
315-
return noisy_samples
316-
317-
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
318-
# Make sure alphas_cumprod and timestep have same device and dtype as sample
319-
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
320-
timesteps = timesteps.to(sample.device)
321-
322-
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
323-
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
324-
while len(sqrt_alpha_prod.shape) < len(sample.shape):
325-
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
326-
327-
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
328-
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
329-
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
330-
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
331-
332-
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
333-
return velocity

0 commit comments

Comments
 (0)