@@ -160,7 +160,6 @@ def __init__(
160160 lower_order_final : bool = True ,
161161 euler_at_final : bool = False ,
162162 use_karras_sigmas : Optional [bool ] = False ,
163- final_sigmas_type : Optional [str ] = "default" , # "denoise_to_zero", "default"
164163 lambda_min_clipped : float = - float ("inf" ),
165164 variance_type : Optional [str ] = None ,
166165 timestep_spacing : str = "linspace" ,
@@ -203,11 +202,6 @@ def __init__(
203202 else :
204203 raise NotImplementedError (f"{ solver_type } does is not implemented for { self .__class__ } " )
205204
206- if algorithm_type not in ["dpmsolver++" , "sde-dpmsolver++" ] and final_sigmas_type == "denoise_to_zero" :
207- raise ValueError (
208- f"`final_sigmas_type` { final_sigmas_type } is not supported for `algorithm_type` { algorithm_type } ."
209- )
210-
211205 # setable values
212206 self .num_inference_steps = None
213207 timesteps = np .linspace (0 , num_train_timesteps - 1 , num_train_timesteps , dtype = np .float32 ).copy ()
@@ -270,27 +264,13 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
270264 sigmas = self ._convert_to_karras (in_sigmas = sigmas , num_inference_steps = num_inference_steps )
271265 timesteps = np .array ([self ._sigma_to_t (sigma , log_sigmas ) for sigma in sigmas ]).round ()
272266 timesteps = timesteps .copy ().astype (np .int64 )
273- if self .config .final_sigmas_type == "default" :
274- sigmas = np .concatenate ([sigmas [0 ], sigmas ]).astype (np .float32 )
275- elif self .config .final_sigmas_type == "denoise_to_zero" :
276- sigmas = np .concatenate ([np .array ([0 ]), sigmas ]).astype (np .float32 )
277- else :
278- raise ValueError (
279- f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got { self .config .final_sigmas_type } "
280- )
267+ sigmas = np .concatenate ([sigmas , sigmas [- 1 :]]).astype (np .float32 )
281268 else :
282269 sigmas = np .interp (timesteps , np .arange (0 , len (sigmas )), sigmas )
283- if self .config .final_sigmas_type == "default" :
284- sigma_last = (
285- (1 - self .alphas_cumprod [self .noisiest_timestep ]) / self .alphas_cumprod [self .noisiest_timestep ]
286- ) ** 0.5
287- elif self .config .final_sigmas_type == "denoise_to_zero" :
288- sigma_last = 0
289- else :
290- raise ValueError (
291- f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got { self .config .final_sigmas_type } "
292- )
293- sigmas = np .concatenate ([[sigma_last ], sigmas ]).astype (np .float32 )
270+ sigma_max = (
271+ (1 - self .alphas_cumprod [self .noisiest_timestep ]) / self .alphas_cumprod [self .noisiest_timestep ]
272+ ) ** 0.5
273+ sigmas = np .concatenate ([sigmas , [sigma_max ]]).astype (np .float32 )
294274
295275 self .sigmas = torch .from_numpy (sigmas )
296276 self .timesteps = torch .from_numpy (timesteps ).to (device = device , dtype = torch .int64 )
@@ -797,7 +777,6 @@ def _init_step_index(self, timestep):
797777
798778 self ._step_index = step_index
799779
800- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
801780 def step (
802781 self ,
803782 model_output : torch .FloatTensor ,
@@ -838,9 +817,7 @@ def step(
838817
839818 # Improve numerical stability for small number of steps
840819 lower_order_final = (self .step_index == len (self .timesteps ) - 1 ) and (
841- self .config .euler_at_final
842- or (self .config .lower_order_final and len (self .timesteps ) < 15 )
843- or self .config .final_sigmas_type == "denoise_to_zero"
820+ self .config .euler_at_final or (self .config .lower_order_final and len (self .timesteps ) < 15 )
844821 )
845822 lower_order_second = (
846823 (self .step_index == len (self .timesteps ) - 2 ) and self .config .lower_order_final and len (self .timesteps ) < 15
0 commit comments