Spaces:
Runtime error
Runtime error
import torch | |
from tqdm.auto import tqdm | |
from k_diffusion import sampling | |
import modules.devices as devices | |
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
    """Integrate ``x`` from ``t_start`` to ``t_end`` with an adaptive step size.

    Each iteration computes a lower-order and a higher-order solver estimate
    of the next state, derives a scalar error from their difference, and lets
    a ``sampling.PIDStepSizeController`` decide whether the step is accepted.
    The ``atol``/``rtol`` tensors are created on ``devices.device``.

    Intended to be installed as a method of ``sampling.DPMSolver``; it relies
    on ``self.sigma``, ``self.t``, ``self.eps``, ``self.dpm_solver_{1,2,3}_step``
    and ``self.info_callback``.

    Args:
        x: State tensor to integrate.
        t_start: Starting time (compared and clamped with torch ops, so
            presumably a tensor -- confirm against callers).
        t_end: Final time; same type as ``t_start``.
        order: 2 (1-step vs 2-step estimates) or 3 (2-step vs 3-step).
        rtol: Relative tolerance used in the error scale.
        atol: Absolute tolerance used in the error scale.
        h_init: Initial step size; only its magnitude is used, the sign
            follows the integration direction.
        pcoeff, icoeff, dcoeff, accept_safety: Forwarded unchanged to
            ``sampling.PIDStepSizeController``.
        eta: When truthy, steps are split with ``sampling.get_ancestral_step``
            and noise is added after each accepted step. Must be falsy when
            ``t_end`` is not greater than ``t_start``.
        s_noise: Multiplier applied to the added noise.
        noise_sampler: Callable invoked as ``noise_sampler(sigma_s, sigma_t)``;
            defaults to ``sampling.default_noise_sampler(x)``.

    Returns:
        Tuple ``(x, info)`` where ``info`` holds the counters ``'steps'``,
        ``'nfe'``, ``'n_accept'`` and ``'n_reject'``.

    Raises:
        ValueError: If ``order`` is not 2 or 3, or if ``eta`` is set for a
            non-forward integration.
    """
    if noise_sampler is None:
        noise_sampler = sampling.default_noise_sampler(x)
    if order not in {2, 3}:
        raise ValueError('order should be 2 or 3')

    forward = t_end > t_start
    if not forward and eta:
        raise ValueError('eta must be 0 for reverse sampling')

    # The step size always points from t_start towards t_end.
    direction = 1 if forward else -1
    h_init = abs(h_init) * direction

    atol = torch.tensor(atol, device=devices.device)
    rtol = torch.tensor(rtol, device=devices.device)

    s = t_start
    x_prev = x
    controller_order = 1.5 if eta else order
    pid = sampling.PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, controller_order, accept_safety)
    info = {
        'steps': 0,
        'nfe': 0,
        'n_accept': 0,
        'n_reject': 0,
    }

    while True:
        # Stop once s has come within 1e-5 of t_end (in the travel direction).
        if forward:
            unfinished = s < t_end - 1e-5
        else:
            unfinished = s > t_end + 1e-5
        if not unfinished:
            break

        eps_cache = {}

        # Propose the next time, never stepping past t_end.
        clamp = torch.minimum if forward else torch.maximum
        t = clamp(t_end, s + pid.h)

        if eta:
            sigma_down, _ = sampling.get_ancestral_step(self.sigma(s), self.sigma(t), eta)
            t_ = torch.minimum(t_end, self.t(sigma_down))
            # Noise scale is recomputed from the (possibly clamped) t_.
            su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
        else:
            t_ = t
            su = 0.

        eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
        denoised = x - self.sigma(s) * eps

        # Two estimates of different order; eps_cache is threaded through both.
        if order == 2:
            x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
            x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
        else:
            x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
            x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)

        magnitude = torch.maximum(x_low.abs(), x_prev.abs())
        delta = torch.maximum(atol, rtol * magnitude)
        error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5

        if pid.propose_step(error):
            x_prev = x_low
            # sigma(s) must be read before s is advanced to t.
            x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
            s = t
            info['n_accept'] += 1
        else:
            info['n_reject'] += 1

        info['nfe'] += order
        info['steps'] += 1

        report = self.info_callback
        if report is not None:
            report({
                'x': x,
                'i': info['steps'] - 1,
                't': s,
                't_up': s,
                'denoised': denoised,
                'error': error,
                'h': pid.h,
                **info,
            })

    return x, info
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
    """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.

    Builds a ``sampling.DPMSolver`` whose ``eps_callback`` advances a tqdm
    progress bar of length ``n``, then runs ``dpm_solver_fast`` from the time
    for ``sigma_max`` to the time for ``sigma_min``. Both sigmas are wrapped
    in tensors created on ``devices.device`` before conversion.

    Args:
        model: Passed to ``sampling.DPMSolver``.
        x: Initial state handed to the solver.
        sigma_min: End sigma; must be greater than 0.
        sigma_max: Start sigma; must be greater than 0.
        n: Progress-bar total; also forwarded to ``dpm_solver_fast``.
        extra_args: Passed to ``sampling.DPMSolver``.
        callback: Optional callable; receives the solver's info dict with
            ``'sigma'`` and ``'sigma_hat'`` entries added.
        disable: Forwarded to tqdm's ``disable`` argument.
        eta, s_noise, noise_sampler: Forwarded to ``dpm_solver_fast``.

    Returns:
        Whatever ``dpm_solver_fast`` returns.

    Raises:
        ValueError: If ``sigma_min`` or ``sigma_max`` is less than or equal
            to 0.
    """
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')

    with tqdm(total=n, disable=disable) as pbar:
        dpm_solver = sampling.DPMSolver(model, extra_args, eps_callback=pbar.update)

        if callback is not None:
            def _relay(info):
                # Add sigma values derived from the solver's time entries.
                return callback({
                    'sigma': dpm_solver.sigma(info['t']),
                    'sigma_hat': dpm_solver.sigma(info['t_up']),
                    **info,
                })

            dpm_solver.info_callback = _relay

        t_start = dpm_solver.t(torch.tensor(sigma_max, device=devices.device))
        t_end = dpm_solver.t(torch.tensor(sigma_min, device=devices.device))
        return dpm_solver.dpm_solver_fast(x, t_start, t_end, n, eta, s_noise, noise_sampler)
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
    """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.

    Builds a ``sampling.DPMSolver`` whose ``eps_callback`` advances a tqdm
    progress bar (no total), then runs ``dpm_solver_adaptive`` from the time
    for ``sigma_max`` to the time for ``sigma_min``. Both sigmas are wrapped
    in tensors created on ``devices.device`` before conversion.

    Args:
        model: Passed to ``sampling.DPMSolver``.
        x: Initial state handed to the solver.
        sigma_min: End sigma; must be greater than 0.
        sigma_max: Start sigma; must be greater than 0.
        extra_args: Passed to ``sampling.DPMSolver``.
        callback: Optional callable; receives the solver's info dict with
            ``'sigma'`` and ``'sigma_hat'`` entries added.
        disable: Forwarded to tqdm's ``disable`` argument.
        order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety,
        eta, s_noise, noise_sampler: Forwarded positionally, in this order,
            to ``dpm_solver_adaptive``.
        return_info: When truthy, also return the solver's info object.

    Returns:
        The final ``x``, or ``(x, info)`` when ``return_info`` is truthy.

    Raises:
        ValueError: If ``sigma_min`` or ``sigma_max`` is less than or equal
            to 0.
    """
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')

    with tqdm(disable=disable) as pbar:
        dpm_solver = sampling.DPMSolver(model, extra_args, eps_callback=pbar.update)

        if callback is not None:
            def _relay(info):
                # Add sigma values derived from the solver's time entries.
                return callback({
                    'sigma': dpm_solver.sigma(info['t']),
                    'sigma_hat': dpm_solver.sigma(info['t_up']),
                    **info,
                })

            dpm_solver.info_callback = _relay

        t_start = dpm_solver.t(torch.tensor(sigma_max, device=devices.device))
        t_end = dpm_solver.t(torch.tensor(sigma_min, device=devices.device))
        x, info = dpm_solver.dpm_solver_adaptive(
            x, t_start, t_end,
            order, rtol, atol, h_init,
            pcoeff, icoeff, dcoeff, accept_safety,
            eta, s_noise, noise_sampler,
        )

    return (x, info) if return_info else x
# Install the implementations defined in this file over the corresponding
# attributes of k_diffusion.sampling, so code that looks them up through the
# `sampling` module (or through DPMSolver instances) gets these versions,
# which create their tensors on devices.device.
sampling.sample_dpm_fast = sample_dpm_fast
sampling.sample_dpm_adaptive = sample_dpm_adaptive
sampling.DPMSolver.dpm_solver_adaptive = dpm_solver_adaptive