Change DPM-Solver to use ancestral sampling
crowsonkb committed Oct 3, 2022
1 parent ec78888 commit 567e11f
Showing 1 changed file with 33 additions and 26 deletions.
k_diffusion/sampling.py: 59 changes (33 additions, 26 deletions)
@@ -43,6 +43,8 @@ def to_d(x, sigma, denoised):
 def get_ancestral_step(sigma_from, sigma_to, eta=1.):
     """Calculates the noise level (sigma_down) to step down to and the amount
     of noise to add (sigma_up) when doing an ancestral sampling step."""
+    if not eta:
+        return sigma_to, 0.
     sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
     sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
     return sigma_down, sigma_up
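
Note on the new early return: eta=0 is now an exact no-op (the step goes straight to sigma_to with no added noise), and for nonzero eta the two values split the target variance, since sigma_down ** 2 + sigma_up ** 2 == sigma_to ** 2 by construction. A quick standalone check, assuming the package is importable as k_diffusion; the sigma values are arbitrary examples:

    from k_diffusion.sampling import get_ancestral_step

    # eta=0 takes the new early-return path: no noise is added.
    assert get_ancestral_step(10.0, 5.0, eta=0.) == (5.0, 0.)

    # With eta=1, the step splits into a deterministic part (sigma_down)
    # and a noise part (sigma_up) that together preserve the target variance.
    sigma_down, sigma_up = get_ancestral_step(10.0, 5.0, eta=1.)
    print(sigma_down, sigma_up)  # 2.5 4.3301...
    assert abs(sigma_down ** 2 + sigma_up ** 2 - 5.0 ** 2) < 1e-9
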
@@ -317,6 +319,9 @@ def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
         return x_3, eps_cache

     def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1.):
+        if not t_end > t_start and eta:
+            raise ValueError('eta must be 0 for reverse sampling')
+
         m = math.floor(nfe / 3) + 1
         ts = torch.linspace(t_start, t_end, m + 1, device=x.device)

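The new guard relies on operator precedence: not t_end > t_start and eta parses as (not (t_end > t_start)) and eta, so it fires only when sampling in reverse (t decreasing) with nonzero eta. A hypothetical illustration with scalar stand-ins for the real arguments:

    t_start, t_end, eta = 1.0, 0.0, 1.0    # reverse sampling: t_end < t_start
    # This is the condition dpm_solver_fast now rejects with a ValueError.
    assert (not t_end > t_start) and eta
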
@@ -328,76 +333,78 @@ def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1.):
         for i in range(len(orders)):
             eps_cache = {}
             t, t_next = ts[i], ts[i + 1]
-            gamma = eta * torch.sqrt(2 * (t_next - t).abs())
-            t = torch.maximum(torch.minimum(t_start, t_end), t - gamma.log1p())
-            noise = torch.randn_like(x) * s_noise
-            if t < ts[i]:
-                x = x + noise * (self.sigma(t) ** 2 - self.sigma(ts[i]) ** 2).sqrt()
+            if eta:
+                sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
+                t_next_ = torch.minimum(t_end, self.t(sd))
+                su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
+            else:
+                t_next_, su = t_next, 0.

             eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
             denoised = x - self.sigma(t) * eps
             if self.info_callback is not None:
                 self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})

             if orders[i] == 1:
-                x, eps_cache = self.dpm_solver_1_step(x, t, t_next, eps_cache=eps_cache)
+                x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
             elif orders[i] == 2:
-                x, eps_cache = self.dpm_solver_2_step(x, t, t_next, eps_cache=eps_cache)
+                x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
             else:
-                x, eps_cache = self.dpm_solver_3_step(x, t, t_next, eps_cache=eps_cache)
+                x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)

+            x = x + su * s_noise * torch.randn_like(x)
+
         return x

     def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1.):
         if order not in {2, 3}:
             raise ValueError('order should be 2 or 3')
         forward = t_end > t_start
+        if not forward and eta:
+            raise ValueError('eta must be 0 for reverse sampling')
         h_init = abs(h_init) * (1 if forward else -1)
         atol = torch.tensor(atol)
         rtol = torch.tensor(rtol)
         s = t_start
         x_prev = x
         accept = True
-        pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, order, accept_safety)
+        pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
         info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}

         while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
             eps_cache = {}
             t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
+            if eta:
+                sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
+                t_ = torch.minimum(t_end, self.t(sd))
+                su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
+            else:
+                t_, su = t, 0.

-            gamma = eta * torch.sqrt(2 * (t - s).abs())
-            s_ = torch.maximum(torch.minimum(t_start, t_end), s - gamma.log1p())
-            x_pre_noise = x
-            if accept:
-                noise = torch.randn_like(x) * s_noise
-                if s_ < s:
-                    x = x + noise * (self.sigma(s_) ** 2 - self.sigma(s) ** 2).sqrt()

-            eps, eps_cache = self.eps(eps_cache, 'eps', x, s_)
-            denoised = x - self.sigma(s_) * eps
+            eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
+            denoised = x - self.sigma(s) * eps

             if order == 2:
-                x_low, eps_cache = self.dpm_solver_1_step(x, s_, t, eps_cache=eps_cache)
-                x_high, eps_cache = self.dpm_solver_2_step(x, s_, t, eps_cache=eps_cache)
+                x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
+                x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
             else:
-                x_low, eps_cache = self.dpm_solver_2_step(x, s_, t, r1=1 / 3, eps_cache=eps_cache)
-                x_high, eps_cache = self.dpm_solver_3_step(x, s_, t, eps_cache=eps_cache)
+                x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
+                x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
             delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
             error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
             accept = pid.propose_step(error)
             if accept:
                 x_prev = x_low
-                x = x_high
+                x = x_high + su * s_noise * torch.randn_like(x_high)
                 s = t
                 info['n_accept'] += 1
             else:
-                x = x_pre_noise
                 info['n_reject'] += 1
             info['nfe'] += order
             info['steps'] += 1

             if self.info_callback is not None:
-                self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s_, 'denoised': denoised, 'error': error, 'h': pid.h, **info})
+                self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})

         return x, info

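Both samplers now follow the standard ancestral-sampling pattern: take the deterministic DPM-Solver step only down to t(sigma_down), then re-inject the removed variance as fresh Gaussian noise scaled by sigma_up and s_noise. The previous behavior instead perturbed the step's starting t by a gamma offset before solving. A minimal sketch of the adopted pattern, substituting a plain Euler update for the higher-order DPM-Solver updates (ancestral_euler_step and its arguments are illustrative names, not part of the library; compare to_d and sample_euler_ancestral in the same file):

    import torch

    from k_diffusion.sampling import get_ancestral_step

    def ancestral_euler_step(x, denoised, sigma_from, sigma_to, eta=1., s_noise=1.):
        # Split the step: deterministic move down to sigma_down, then noise back up.
        sigma_down, sigma_up = get_ancestral_step(sigma_from, sigma_to, eta)
        d = (x - denoised) / sigma_from    # same derivative to_d computes
        x = x + d * (sigma_down - sigma_from)    # deterministic (Euler) part
        if sigma_up:
            x = x + torch.randn_like(x) * s_noise * sigma_up    # stochastic part
        return x

In the commit itself the deterministic part is a first-, second-, or third-order DPM-Solver update to t(sigma_down) rather than an Euler step, and dpm_solver_adaptive also lowers the order handed to PIDStepSizeController to 1.5 whenever eta is nonzero.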
