From d1f764aa5ca95c944024b6a53637cb3e662a644f Mon Sep 17 00:00:00 2001
From: Anand Kumar <63339285+AnandK27@users.noreply.github.com>
Date: Fri, 6 Sep 2024 23:59:14 -0700
Subject: [PATCH 1/2] Update scheduling_ddpm.py

---
 src/diffusers/schedulers/scheduling_ddpm.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 81a770edf635..b332f8b0e80f 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -545,16 +545,12 @@ def __len__(self):
         return self.config.num_train_timesteps
 
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t

From 0fb5fb8578c5b735caa461904418707046f9e14a Mon Sep 17 00:00:00 2001
From: hlky
Date: Tue, 3 Dec 2024 08:53:19 +0000
Subject: [PATCH 2/2] fix copies

---
 src/diffusers/schedulers/scheduling_ddpm_parallel.py | 8 ++------
 src/diffusers/schedulers/scheduling_lcm.py           | 8 ++------
 src/diffusers/schedulers/scheduling_tcd.py           | 8 ++------
 3 files changed, 6 insertions(+), 18 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
index f377ee6e8c93..20ad7a4c927d 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
@@ -639,16 +639,12 @@ def __len__(self):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t

diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py
index f1aa09ab1723..686b686f6870 100644
--- a/src/diffusers/schedulers/scheduling_lcm.py
+++ b/src/diffusers/schedulers/scheduling_lcm.py
@@ -643,16 +643,12 @@ def __len__(self):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
    def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t

diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py
index 580224404c54..5d60383142a4 100644
--- a/src/diffusers/schedulers/scheduling_tcd.py
+++ b/src/diffusers/schedulers/scheduling_tcd.py
@@ -680,16 +680,12 @@ def __len__(self):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t
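
Note (not part of the patch): a minimal sketch of what the updated previous_timestep does in practice, assuming diffusers' public DDPMScheduler with its default configuration (num_train_timesteps=1000, "leading" timestep spacing); the printed values below are illustrative for that assumed config.

# Illustrative only -- not part of the patch. Assumes DDPMScheduler's
# default config (num_train_timesteps=1000, timestep_spacing="leading").
from diffusers import DDPMScheduler

scheduler = DDPMScheduler()

# With num_inference_steps set, previous_timestep now returns the next entry
# of scheduler.timesteps instead of deriving it from the step ratio.
scheduler.set_timesteps(num_inference_steps=50)
t = scheduler.timesteps[0]                                   # tensor(980) with the assumed config
print(scheduler.previous_timestep(t))                        # tensor(960), the next entry
print(scheduler.previous_timestep(scheduler.timesteps[-1]))  # tensor(-1), end of the chain

# Without custom timesteps or num_inference_steps (e.g. training-style use),
# the previous timestep is simply timestep - 1.
fresh = DDPMScheduler()
print(fresh.previous_timestep(500))                          # 499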