Skip to content

Commit

Permalink
[refactor]Scheduler.set_begin_index (#6728)
Browse files Browse the repository at this point in the history
  • Loading branch information
yiyixuxu authored Feb 1, 2024
1 parent ec9840a commit adcbe67
Show file tree
Hide file tree
Showing 28 changed files with 620 additions and 279 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,8 @@ def get_timesteps(self, num_inference_steps, strength, device):

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,8 @@ def get_timesteps(self, num_inference_steps, strength, device):

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,8 @@ def get_timesteps(self, num_inference_steps, strength, device):

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,8 @@ def get_timesteps(self, num_inference_steps, strength, device):

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,8 @@ def get_timesteps(self, num_inference_steps, strength, device):

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,8 @@ def get_timesteps(self, num_inference_steps, strength, device):

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start

Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/pia/pipeline_pia.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,8 @@ def get_timesteps(self, num_inference_steps, strength, device):

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,8 @@ def get_timesteps(self, num_inference_steps, strength, device):

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,8 @@ def get_timesteps(self, num_inference_steps, strength, device):

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,8 @@ def get_timesteps(self, num_inference_steps, strength, device):

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,8 @@ def get_timesteps(self, num_inference_steps, strength, device):

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,8 @@ def get_timesteps(self, num_inference_steps, strength, device):

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start

Expand Down
59 changes: 41 additions & 18 deletions src/diffusers/schedulers/scheduling_consistency_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,34 @@ def __init__(
self.custom_timesteps = False
self.is_scale_input_called = False
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps

indices = (schedule_timesteps == timestep).nonzero()
return indices.item()

@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index

@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index

def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
Expand Down Expand Up @@ -231,6 +243,7 @@ def set_timesteps(
self.timesteps = torch.from_numpy(timesteps).to(device=device)

self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

# Modified _convert_to_karras implementation that takes in ramp as argument
Expand Down Expand Up @@ -280,23 +293,29 @@ def get_scalings_for_boundary_condition(self, sigma):
c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
return c_skip, c_out

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps

index_candidates = (self.timesteps == timestep).nonzero()
indices = (schedule_timesteps == timestep).nonzero()

# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]
pos = 1 if len(indices) > 1 else 0

return indices[pos].item()

self._step_index = step_index.item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index

def step(
self,
Expand Down Expand Up @@ -412,7 +431,11 @@ def add_noise(
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
Expand Down
59 changes: 44 additions & 15 deletions src/diffusers/schedulers/scheduling_deis_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def __init__(
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

@property
Expand All @@ -196,6 +197,24 @@ def step_index(self):
"""
return self._step_index

@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index

def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Expand Down Expand Up @@ -255,6 +274,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
Expand Down Expand Up @@ -620,11 +640,12 @@ def ind_fn(t, b, c, d):
else:
raise NotImplementedError("only support log-rho multistep deis now")

def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps

index_candidates = (self.timesteps == timestep).nonzero()
index_candidates = (schedule_timesteps == timestep).nonzero()

if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
Expand All @@ -637,7 +658,20 @@ def _init_step_index(self, timestep):
else:
step_index = index_candidates[0].item()

self._step_index = step_index
return step_index

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
"""

if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index

def step(
self,
Expand Down Expand Up @@ -736,16 +770,11 @@ def add_noise(
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
# begin_index is None when the scheduler is used for training
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
Expand Down
Loading

0 comments on commit adcbe67

Please sign in to comment.