Commit 256e010

Improve docstrings and type hints in scheduling_deis_multistep.py (#12796)
* feat: Add `flow_prediction` to `prediction_type`, introduce `use_flow_sigmas`, `flow_shift`, `use_dynamic_shifting`, and `time_shift_type` parameters, and refine type hints for various arguments.
* style: reformat argument wrapping in `_convert_to_beta` and `index_for_timestep` method signatures.
1 parent 8430ac2 commit 256e010
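For orientation, here is a minimal sketch of how the options touched by this commit surface at construction time (the argument values are illustrative examples, not defaults):

```python
from diffusers import DEISMultistepScheduler

# Flow-matching-style configuration using the options this commit documents.
# `prediction_type="flow_prediction"` and `use_flow_sigmas=True` are the newly
# documented values; `flow_shift=3.0` is an arbitrary example value.
scheduler = DEISMultistepScheduler(
    num_train_timesteps=1000,
    prediction_type="flow_prediction",
    use_flow_sigmas=True,
    flow_shift=3.0,
    time_shift_type="exponential",
)
```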

File tree

1 file changed

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 78 additions & 39 deletions
```diff
@@ -84,33 +84,35 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     methods the library implements for all schedulers such as loading and saving.
 
     Args:
-        num_train_timesteps (`int`, defaults to 1000):
+        num_train_timesteps (`int`, defaults to `1000`):
             The number of diffusion steps to train the model.
-        beta_start (`float`, defaults to 0.0001):
+        beta_start (`float`, defaults to `0.0001`):
             The starting `beta` value of inference.
-        beta_end (`float`, defaults to 0.02):
+        beta_end (`float`, defaults to `0.02`):
             The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
+        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
             The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
             `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
-        trained_betas (`np.ndarray`, *optional*):
+        trained_betas (`np.ndarray` or `List[float]`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
-        solver_order (`int`, defaults to 2):
+        solver_order (`int`, defaults to `2`):
             The DEIS order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
             sampling, and `solver_order=3` for unconditional sampling.
-        prediction_type (`str`, defaults to `epsilon`):
+        prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
-            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
-            Video](https://huggingface.co/papers/2210.02303) paper).
+            `sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen
+            Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
         thresholding (`bool`, defaults to `False`):
             Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
             as Stable Diffusion.
-        dynamic_thresholding_ratio (`float`, defaults to 0.995):
+        dynamic_thresholding_ratio (`float`, defaults to `0.995`):
             The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
-        sample_max_value (`float`, defaults to 1.0):
+        sample_max_value (`float`, defaults to `1.0`):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
-        algorithm_type (`str`, defaults to `deis`):
+        algorithm_type (`"deis"`, defaults to `"deis"`):
             The algorithm type for the solver.
+        solver_type (`"logrho"`, defaults to `"logrho"`):
+            Solver type for DEIS.
         lower_order_final (`bool`, defaults to `True`):
             Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps.
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
```
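As a rough guide to what these `prediction_type` values mean, here is a sketch of the conventional output-to-`x0` conversions. It follows the common diffusers convention rather than being copied from this file; `alpha_t`/`sigma_t` stand for the schedule coefficients at timestep `t`:

```python
import torch

def to_x0(model_output: torch.Tensor, sample: torch.Tensor,
          alpha_t: torch.Tensor, sigma_t: torch.Tensor, prediction_type: str) -> torch.Tensor:
    # Recover the denoised sample x0 from the model output under each convention.
    if prediction_type == "epsilon":          # model predicts the added noise
        return (sample - sigma_t * model_output) / alpha_t
    if prediction_type == "sample":           # model predicts x0 directly
        return model_output
    if prediction_type == "v_prediction":     # model predicts v = alpha*eps - sigma*x0
        return alpha_t * sample - sigma_t * model_output
    if prediction_type == "flow_prediction":  # model predicts the flow velocity
        return sample - sigma_t * model_output
    raise ValueError(f"unknown prediction_type: {prediction_type}")
```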
```diff
@@ -121,11 +123,19 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         use_beta_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
             Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
-        timestep_spacing (`str`, defaults to `"linspace"`):
+        use_flow_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
+        flow_shift (`float`, *optional*, defaults to `1.0`):
+            The flow shift parameter for flow-based models.
+        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
-        steps_offset (`int`, defaults to 0):
+        steps_offset (`int`, defaults to `0`):
             An offset added to the inference steps, as required by some model families.
+        use_dynamic_shifting (`bool`, defaults to `False`):
+            Whether to use dynamic shifting for the noise schedule.
+        time_shift_type (`"exponential"`, defaults to `"exponential"`):
+            The type of time shifting to apply.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
```
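Because these are all registered config arguments, they can also be overridden when rebuilding a scheduler from an existing pipeline config; a hedged sketch (the model id is a placeholder):

```python
from diffusers import DEISMultistepScheduler, DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("org/model-id")  # placeholder model id
# Rebuild the scheduler from the pipeline's config, overriding the new options.
pipe.scheduler = DEISMultistepScheduler.from_config(
    pipe.scheduler.config,
    use_flow_sigmas=True,
    flow_shift=3.0,
)
```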
```diff
@@ -137,29 +147,38 @@ def __init__(
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
-        beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         solver_order: int = 2,
-        prediction_type: str = "epsilon",
+        prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
         thresholding: bool = False,
         dynamic_thresholding_ratio: float = 0.995,
         sample_max_value: float = 1.0,
-        algorithm_type: str = "deis",
-        solver_type: str = "logrho",
+        algorithm_type: Literal["deis"] = "deis",
+        solver_type: Literal["logrho"] = "logrho",
         lower_order_final: bool = True,
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
         use_flow_sigmas: Optional[bool] = False,
         flow_shift: Optional[float] = 1.0,
-        timestep_spacing: str = "linspace",
+        timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
         steps_offset: int = 0,
         use_dynamic_shifting: bool = False,
-        time_shift_type: str = "exponential",
-    ):
+        time_shift_type: Literal["exponential"] = "exponential",
+    ) -> None:
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
-        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+        if (
+            sum(
+                [
+                    self.config.use_beta_sigmas,
+                    self.config.use_exponential_sigmas,
+                    self.config.use_karras_sigmas,
+                ]
+            )
+            > 1
+        ):
             raise ValueError(
                 "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
             )
```
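The reflowed guard relies on Python's `bool` being an `int` subclass, so summing the flags counts how many sigma styles are enabled at once; in isolation:

```python
# True counts as 1 when summed, so more than one enabled flag trips the guard.
flags = [True, True, False]  # e.g. use_beta_sigmas, use_exponential_sigmas, use_karras_sigmas
if sum(flags) > 1:
    raise ValueError("Only one sigma style can be enabled at a time.")
```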
```diff
@@ -169,7 +188,15 @@ def __init__(
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
             # this schedule is very specific to the latent diffusion model.
-            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+            self.betas = (
+                torch.linspace(
+                    beta_start**0.5,
+                    beta_end**0.5,
+                    num_train_timesteps,
+                    dtype=torch.float32,
+                )
+                ** 2
+            )
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
```
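The wrapped expression is behaviorally identical to the original one-liner: `scaled_linear` spaces the square roots of the betas linearly, then squares the result. Standalone, with typical latent-diffusion values rather than this scheduler's defaults:

```python
import torch

beta_start, beta_end, num_train_timesteps = 0.00085, 0.012, 1000  # common LDM choices
betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
```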
```diff
@@ -211,21 +238,21 @@ def __init__(
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     @property
-    def step_index(self):
+    def step_index(self) -> Optional[int]:
         """
         The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index
 
     @property
-    def begin_index(self):
+    def begin_index(self) -> Optional[int]:
         """
         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
         """
         return self._begin_index
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
+    def set_begin_index(self, begin_index: int = 0) -> None:
         """
         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
```
```diff
@@ -236,8 +263,11 @@ def set_begin_index(self, begin_index: int = 0):
         self._begin_index = begin_index
 
     def set_timesteps(
-        self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
-    ):
+        self,
+        num_inference_steps: int,
+        device: Union[str, torch.device] = None,
+        mu: Optional[float] = None,
+    ) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
```
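`set_begin_index` matters mainly for pipelines that start denoising partway through the schedule (img2img-style strength); a sketch of the expected call order, assuming `scheduler` is an initialized `DEISMultistepScheduler`:

```python
scheduler.set_timesteps(num_inference_steps=50)
# Skip the first 10 steps of the schedule, as a pipeline would after
# noising the input to an intermediate timestep.
scheduler.set_begin_index(10)
```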
```diff
@@ -246,6 +276,9 @@ def set_timesteps(
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            mu (`float`, *optional*):
+                The mu parameter for dynamic shifting. Only used when `use_dynamic_shifting=True` and
+                `time_shift_type="exponential"`.
         """
         if mu is not None:
             assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
```
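A sketch of the call pattern the new `mu` documentation implies; the in-code `assert` means passing `mu` is only valid with dynamic shifting enabled (the value here is arbitrary):

```python
scheduler = DEISMultistepScheduler(
    use_flow_sigmas=True,
    use_dynamic_shifting=True,
    time_shift_type="exponential",
)
# `mu` feeds the exponential time shift; 1.15 is an illustrative value.
scheduler.set_timesteps(num_inference_steps=30, mu=1.15)
```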
```diff
@@ -363,7 +396,7 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         return sample
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
-    def _sigma_to_t(self, sigma, log_sigmas):
+    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
         """
         Convert sigma values to corresponding timestep values through interpolation.
```
```diff
@@ -400,7 +433,7 @@ def _sigma_to_t(self, sigma, log_sigmas):
         return t
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
-    def _sigma_to_alpha_sigma_t(self, sigma):
+    def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Convert sigma values to alpha_t and sigma_t values.
```
```diff
@@ -422,7 +455,7 @@ def _sigma_to_alpha_sigma_t(self, sigma):
         return alpha_t, sigma_t
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
         """
         Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
         Models](https://huggingface.co/papers/2206.00364).
```
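For reference, the conversion that the new `Tuple[torch.Tensor, torch.Tensor]` hint describes follows the variance-preserving parameterization in the non-flow case; a sketch from the usual convention, not copied from this file (the `use_flow_sigmas` branch behaves differently):

```python
import torch

def sigma_to_alpha_sigma_t(sigma: torch.Tensor):
    # VP convention: alpha_t^2 + sigma_t^2 = 1 with sigma = sigma_t / alpha_t.
    alpha_t = 1.0 / (sigma**2 + 1.0) ** 0.5
    sigma_t = sigma * alpha_t
    return alpha_t, sigma_t
```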
```diff
@@ -648,7 +681,10 @@ def deis_first_order_update(
             "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
         )
 
-        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+        sigma_t, sigma_s = (
+            self.sigmas[self.step_index + 1],
+            self.sigmas[self.step_index],
+        )
         alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
         alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
         lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
```
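The tuple assignment above is only reflowed; the first-order update it feeds has the usual exponential-integrator form. A self-contained sketch under that assumption (it coincides with a DDIM step for epsilon prediction):

```python
import torch

def deis_first_order(sample, model_output, alpha_t, sigma_t, alpha_s, sigma_s):
    # lambda is the log signal-to-noise ratio; h is its change over the step.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
    h = lambda_t - lambda_s
    return (alpha_t / alpha_s) * sample - sigma_t * (torch.exp(h) - 1.0) * model_output
```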
```diff
@@ -714,7 +750,11 @@ def multistep_deis_second_order_update(
 
         m0, m1 = model_output_list[-1], model_output_list[-2]
 
-        rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1
+        rho_t, rho_s0, rho_s1 = (
+            sigma_t / alpha_t,
+            sigma_s0 / alpha_s0,
+            sigma_s1 / alpha_s1,
+        )
 
         if self.config.algorithm_type == "deis":
```
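Here `rho` is the noise-to-signal ratio `sigma / alpha`; the `logrho` solver fits the buffered model outputs `m0, m1` with a polynomial in `log(rho)` and integrates it exactly. A paraphrased sketch of how the second-order coefficients can be formed, not a verbatim copy of this file (verify against the source before reuse):

```python
import numpy as np

def ind_fn(t, b, c):
    # Antiderivative of the Lagrange basis (log t - log c) / (log b - log c).
    return t * (-np.log(c) + np.log(t) - 1.0) / (np.log(b) - np.log(c))

def deis_second_order(sample, m0, m1, alpha_t, alpha_s0, rho_t, rho_s0, rho_s1):
    coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1)
    coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0)
    return alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1)
```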
```diff
@@ -854,7 +894,7 @@ def index_for_timestep(
         return step_index
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
-    def _init_step_index(self, timestep):
+    def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
         """
         Initialize the step_index counter for the scheduler.
```
```diff
@@ -884,18 +924,17 @@ def step(
         Args:
             model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
-            timestep (`int`):
+            timestep (`int` or `torch.Tensor`):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
 
         Returns:
             [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                 If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                 tuple is returned where the first element is the sample tensor.
-
         """
         if self.num_inference_steps is None:
             raise ValueError(
```
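Usage of `step` is unchanged by this commit; the usual denoising loop looks roughly like this (the denoiser call is a placeholder):

```python
import torch

scheduler.set_timesteps(num_inference_steps=25)
sample = torch.randn(1, 4, 64, 64)  # initial latents; shape is illustrative
for t in scheduler.timesteps:
    model_output = denoiser(sample, t)  # placeholder for the learned model
    sample = scheduler.step(model_output, t, sample).prev_sample
```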
```diff
@@ -1000,5 +1039,5 @@ def add_noise(
         noisy_samples = alpha_t * original_samples + sigma_t * noise
         return noisy_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps
```
