@@ -84,33 +84,35 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     methods the library implements for all schedulers such as loading and saving.
 
     Args:
-        num_train_timesteps (`int`, defaults to 1000):
+        num_train_timesteps (`int`, defaults to `1000`):
             The number of diffusion steps to train the model.
-        beta_start (`float`, defaults to 0.0001):
+        beta_start (`float`, defaults to `0.0001`):
             The starting `beta` value of inference.
-        beta_end (`float`, defaults to 0.02):
+        beta_end (`float`, defaults to `0.02`):
             The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
+        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
             The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
             `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
-        trained_betas (`np.ndarray`, *optional*):
+        trained_betas (`np.ndarray` or `List[float]`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
-        solver_order (`int`, defaults to 2):
+        solver_order (`int`, defaults to `2`):
             The DEIS order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
             sampling, and `solver_order=3` for unconditional sampling.
-        prediction_type (`str`, defaults to `epsilon`):
+        prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
-            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
-            Video](https://huggingface.co/papers/2210.02303) paper).
+            `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
+            Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
         thresholding (`bool`, defaults to `False`):
             Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
             as Stable Diffusion.
-        dynamic_thresholding_ratio (`float`, defaults to 0.995):
+        dynamic_thresholding_ratio (`float`, defaults to `0.995`):
             The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
-        sample_max_value (`float`, defaults to 1.0):
+        sample_max_value (`float`, defaults to `1.0`):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
-        algorithm_type (`str`, defaults to `deis`):
+        algorithm_type (`"deis"`, defaults to `"deis"`):
             The algorithm type for the solver.
+        solver_type (`"logrho"`, defaults to `"logrho"`):
+            Solver type for DEIS.
         lower_order_final (`bool`, defaults to `True`):
             Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps.
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
@@ -121,11 +123,19 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         use_beta_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
             Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
-        timestep_spacing (`str`, defaults to `"linspace"`):
+        use_flow_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
+        flow_shift (`float`, *optional*, defaults to `1.0`):
+            The flow shift parameter for flow-based models.
+        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
-        steps_offset (`int`, defaults to 0):
+        steps_offset (`int`, defaults to `0`):
             An offset added to the inference steps, as required by some model families.
+        use_dynamic_shifting (`bool`, defaults to `False`):
+            Whether to use dynamic shifting for the noise schedule.
+        time_shift_type (`"exponential"`, defaults to `"exponential"`):
+            The type of time shifting to apply.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
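As a usage sketch for the newly documented options, constructing the scheduler with the flow-oriented settings looks like this (all argument names come from the docstring above; the values are illustrative, not recommended defaults):

    from diffusers import DEISMultistepScheduler

    # flow_shift=3.0 and timestep_spacing="trailing" are arbitrary
    # illustrative choices, not tuned settings.
    scheduler = DEISMultistepScheduler(
        prediction_type="flow_prediction",
        use_flow_sigmas=True,
        flow_shift=3.0,
        timestep_spacing="trailing",
    )
    print(scheduler.config.prediction_type)  # "flow_prediction"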
@@ -137,29 +147,38 @@ def __init__(
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
-        beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         solver_order: int = 2,
-        prediction_type: str = "epsilon",
+        prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
         thresholding: bool = False,
         dynamic_thresholding_ratio: float = 0.995,
         sample_max_value: float = 1.0,
-        algorithm_type: str = "deis",
-        solver_type: str = "logrho",
+        algorithm_type: Literal["deis"] = "deis",
+        solver_type: Literal["logrho"] = "logrho",
         lower_order_final: bool = True,
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
         use_flow_sigmas: Optional[bool] = False,
         flow_shift: Optional[float] = 1.0,
-        timestep_spacing: str = "linspace",
+        timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
         steps_offset: int = 0,
         use_dynamic_shifting: bool = False,
-        time_shift_type: str = "exponential",
-    ):
+        time_shift_type: Literal["exponential"] = "exponential",
+    ) -> None:
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
-        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+        if (
+            sum(
+                [
+                    self.config.use_beta_sigmas,
+                    self.config.use_exponential_sigmas,
+                    self.config.use_karras_sigmas,
+                ]
+            )
+            > 1
+        ):
             raise ValueError(
                 "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
             )
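The reformatted guard preserves the original semantics: the three sigma-schedule flags are mutually exclusive. A small sketch of both sides of the check (flag names come from the signature above):

    from diffusers import DEISMultistepScheduler

    DEISMultistepScheduler(use_karras_sigmas=True)  # any single flag is accepted

    try:
        DEISMultistepScheduler(use_karras_sigmas=True, use_exponential_sigmas=True)
    except ValueError as err:
        print(err)  # "Only one of `config.use_beta_sigmas`, ..."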
@@ -169,7 +188,15 @@ def __init__(
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
             # this schedule is very specific to the latent diffusion model.
-            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+            self.betas = (
+                torch.linspace(
+                    beta_start**0.5,
+                    beta_end**0.5,
+                    num_train_timesteps,
+                    dtype=torch.float32,
+                )
+                ** 2
+            )
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
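For intuition, the "linear" and "scaled_linear" branches differ only in the space the interpolation happens in; a standalone NumPy sketch of the two schedules (endpoints match the torch code above):

    import numpy as np

    beta_start, beta_end, steps = 0.0001, 0.02, 1000

    # "linear": interpolate the betas directly.
    linear = np.linspace(beta_start, beta_end, steps)

    # "scaled_linear": interpolate in sqrt-space, then square (the latent
    # diffusion schedule). Endpoints agree because (sqrt(b))**2 == b.
    scaled_linear = np.linspace(beta_start**0.5, beta_end**0.5, steps) ** 2

    assert np.isclose(linear[0], scaled_linear[0])
    assert np.isclose(linear[-1], scaled_linear[-1])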
@@ -211,21 +238,21 @@ def __init__(
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     @property
-    def step_index(self):
+    def step_index(self) -> Optional[int]:
         """
         The index counter for the current timestep. It will increase by 1 after each scheduler step.
         """
         return self._step_index
 
     @property
-    def begin_index(self):
+    def begin_index(self) -> Optional[int]:
         """
         The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
         """
         return self._begin_index
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
+    def set_begin_index(self, begin_index: int = 0) -> None:
         """
         Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
 
@@ -236,8 +263,11 @@ def set_begin_index(self, begin_index: int = 0):
         self._begin_index = begin_index
 
     def set_timesteps(
-        self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
-    ):
+        self,
+        num_inference_steps: int,
+        device: Union[str, torch.device] = None,
+        mu: Optional[float] = None,
+    ) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
@@ -246,6 +276,9 @@ def set_timesteps(
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+            mu (`float`, *optional*):
+                The mu parameter for dynamic shifting. Only used when `use_dynamic_shifting=True` and
+                `time_shift_type="exponential"`.
         """
         if mu is not None:
             assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
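The new `mu` documentation matches the assert directly below it: passing `mu` to a default-configured scheduler is rejected. A quick sketch of both call patterns:

    from diffusers import DEISMultistepScheduler

    scheduler = DEISMultistepScheduler()
    scheduler.set_timesteps(num_inference_steps=25)  # fine: mu omitted

    try:
        scheduler.set_timesteps(num_inference_steps=25, mu=0.7)
    except AssertionError:
        # mu requires use_dynamic_shifting=True and
        # time_shift_type="exponential", per the guard above.
        print("mu rejected without dynamic shifting")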
@@ -363,7 +396,7 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         return sample
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
-    def _sigma_to_t(self, sigma, log_sigmas):
+    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
         """
         Convert sigma values to corresponding timestep values through interpolation.
 
@@ -400,7 +433,7 @@ def _sigma_to_t(self, sigma, log_sigmas):
         return t
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
-    def _sigma_to_alpha_sigma_t(self, sigma):
+    def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Convert sigma values to alpha_t and sigma_t values.
@@ -422,7 +455,7 @@ def _sigma_to_alpha_sigma_t(self, sigma):
         return alpha_t, sigma_t
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
         """
         Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
         Models](https://huggingface.co/papers/2206.00364).
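For reference, the conversion performed by the annotated `_sigma_to_alpha_sigma_t` follows the variance-preserving parameterization used across the DPM-solver family; a scalar sketch (reconstructed from that convention, so treat it as illustrative rather than the verbatim library code):

    import torch

    def sigma_to_alpha_sigma_t(sigma: torch.Tensor):
        # alpha_t = 1 / sqrt(sigma^2 + 1), sigma_t = sigma * alpha_t,
        # so alpha_t**2 + sigma_t**2 == 1.
        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
        return alpha_t, sigma * alpha_t

    alpha_t, sigma_t = sigma_to_alpha_sigma_t(torch.tensor(2.0))
    print(alpha_t**2 + sigma_t**2)  # tensor(1.)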
@@ -648,7 +681,10 @@ def deis_first_order_update(
648681 "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`" ,
649682 )
650683
651- sigma_t , sigma_s = self .sigmas [self .step_index + 1 ], self .sigmas [self .step_index ]
684+ sigma_t , sigma_s = (
685+ self .sigmas [self .step_index + 1 ],
686+ self .sigmas [self .step_index ],
687+ )
652688 alpha_t , sigma_t = self ._sigma_to_alpha_sigma_t (sigma_t )
653689 alpha_s , sigma_s = self ._sigma_to_alpha_sigma_t (sigma_s )
654690 lambda_t = torch .log (alpha_t ) - torch .log (sigma_t )
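The `lambda_t` line is the log signal-to-noise ratio that drives the exponential-integrator update. A scalar sketch of the first-order step built from it (the matching `lambda_s`/`h` lines sit just after this hunk; this is a sketch of the formula, not the verbatim library code):

    import torch

    def first_order_step(x_s, m0, alpha_s, sigma_s, alpha_t, sigma_t):
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
        h = lambda_t - lambda_s
        # Exponential integrator for a noise-prediction output m0.
        return (alpha_t / alpha_s) * x_s - sigma_t * (torch.exp(h) - 1.0) * m0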
@@ -714,7 +750,11 @@ def multistep_deis_second_order_update(
 
         m0, m1 = model_output_list[-1], model_output_list[-2]
 
-        rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1
+        rho_t, rho_s0, rho_s1 = (
+            sigma_t / alpha_t,
+            sigma_s0 / alpha_s0,
+            sigma_s1 / alpha_s1,
+        )
 
         if self.config.algorithm_type == "deis":
 
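Here `rho = sigma / alpha` is the noise-to-signal ratio that the "logrho" DEIS variant integrates in: the buffered outputs `m0` and `m1` are weighted by integrals of Lagrange basis functions in log-rho space. A scalar sketch of that coefficient idea (reconstructed from the solver's structure; the exact closed form is an assumption, not a quote of the library code):

    import numpy as np

    def ind_fn(t, b, c):
        # Antiderivative of the linear Lagrange basis
        # (log(t) - log(c)) / (log(b) - log(c)) with respect to t.
        return t * (np.log(t) - np.log(c) - 1.0) / (np.log(b) - np.log(c))

    rho_t, rho_s0, rho_s1 = 0.3, 0.5, 0.8  # illustrative decreasing ratios
    coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1)
    coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0)
    print(coef1, coef2)  # weights applied to m0 and m1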
@@ -854,7 +894,7 @@ def index_for_timestep(
         return step_index
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
-    def _init_step_index(self, timestep):
+    def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
         """
         Initialize the step_index counter for the scheduler.
 
@@ -884,18 +924,17 @@ def step(
         Args:
             model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
-            timestep (`int`):
+            timestep (`int` or `torch.Tensor`):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
 
         Returns:
             [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                 If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                 tuple is returned where the first element is the sample tensor.
-
         """
         if self.num_inference_steps is None:
             raise ValueError(
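For orientation, the loop that exercises `step` with the annotated `timestep` types is the standard denoising loop (the random `model_output` is a stand-in for a real noise-prediction network):

    import torch
    from diffusers import DEISMultistepScheduler

    scheduler = DEISMultistepScheduler()
    scheduler.set_timesteps(num_inference_steps=20)

    sample = torch.randn(1, 4, 64, 64)  # stand-in for initial latents
    for t in scheduler.timesteps:
        model_output = torch.randn_like(sample)  # placeholder for unet(sample, t)
        sample = scheduler.step(model_output, t, sample, return_dict=True).prev_sample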
@@ -1000,5 +1039,5 @@ def add_noise(
         noisy_samples = alpha_t * original_samples + sigma_t * noise
         return noisy_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps