# File: test/modules/dml/hijack/kdiffusion.py
# Uploaded by: bilegentile
# Commit message: Upload folder using huggingface_hub
# Commit: c19ca42 (verified)
import torch
from tqdm.auto import tqdm
from k_diffusion import sampling
import modules.devices as devices
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
    """Adaptive-step DPM-Solver loop with tolerance tensors on ``devices.device``.

    Every iteration proposes a step from the current time towards ``t_end``,
    computes a lower-order and a higher-order estimate of the next sample, and
    asks a ``sampling.PIDStepSizeController`` whether the scaled difference
    between the two is small enough to accept the step.

    Args:
        self: Solver object providing ``sigma``, ``t``, ``eps``,
            ``dpm_solver_1_step``/``_2_step``/``_3_step`` and ``info_callback``.
        x: Sample tensor to integrate.
        t_start: Time at which integration starts.
        t_end: Time at which integration stops (passed to
            ``torch.minimum``/``torch.maximum``).
        order: Order of the higher-order estimate; must be 2 or 3.
        rtol: Relative tolerance used when scaling the error.
        atol: Absolute tolerance used when scaling the error.
        h_init: Initial step size; its sign is forced to match the direction
            of integration.
        pcoeff, icoeff, dcoeff, accept_safety: Forwarded unchanged to
            ``sampling.PIDStepSizeController``.
        eta: When truthy, an ancestral step is taken and noise is scaled by
            the resulting "up" sigma. Must be 0 when ``t_end > t_start`` is
            false.
        s_noise: Multiplier applied to the injected noise.
        noise_sampler: Callable invoked as ``noise_sampler(sigma, sigma_next)``;
            defaults to ``sampling.default_noise_sampler(x)``.

    Returns:
        A ``(x, info)`` tuple, where ``info`` is a dict with the counters
        ``'steps'``, ``'nfe'``, ``'n_accept'`` and ``'n_reject'``.

    Raises:
        ValueError: If ``order`` is not 2 or 3, or if ``eta`` is non-zero
            while ``t_end > t_start`` is false.
    """
    if noise_sampler is None:
        noise_sampler = sampling.default_noise_sampler(x)
    if order not in {2, 3}:
        raise ValueError('order should be 2 or 3')

    forward = t_end > t_start
    if eta and not forward:
        raise ValueError('eta must be 0 for reverse sampling')

    direction = 1 if forward else -1
    h_init = abs(h_init) * direction

    # Tolerances are materialised on the configured device rather than the
    # torch default one. NOTE(review): presumably so they can be combined with
    # `x` on that backend - confirm against the caller.
    atol = torch.tensor(atol, device=devices.device)
    rtol = torch.tensor(rtol, device=devices.device)

    controller_order = 1.5 if eta else order
    controller = sampling.PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, controller_order, accept_safety)
    stats = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}

    t_cur = t_start
    x_prev = x

    # Stop once the current time is within 1e-5 of t_end in the travel direction.
    while (t_cur < t_end - 1e-5) if forward else (t_cur > t_end + 1e-5):
        eps_cache = {}

        # Proposed end of this step, clamped so it never overshoots t_end.
        t_proposed = t_cur + controller.h
        t_next = torch.minimum(t_end, t_proposed) if forward else torch.maximum(t_end, t_proposed)

        if eta:
            sigma_down, _ = sampling.get_ancestral_step(self.sigma(t_cur), self.sigma(t_next), eta)
            t_target = torch.minimum(t_end, self.t(sigma_down))
            # Recomputed from the (possibly clamped) target time.
            sigma_up = (self.sigma(t_next) ** 2 - self.sigma(t_target) ** 2) ** 0.5
        else:
            t_target = t_next
            sigma_up = 0.

        eps, eps_cache = self.eps(eps_cache, 'eps', x, t_cur)
        denoised = x - self.sigma(t_cur) * eps

        # Embedded pair: a lower-order and a higher-order estimate over the same step.
        if order == 2:
            x_low, eps_cache = self.dpm_solver_1_step(x, t_cur, t_target, eps_cache=eps_cache)
            x_high, eps_cache = self.dpm_solver_2_step(x, t_cur, t_target, eps_cache=eps_cache)
        else:
            x_low, eps_cache = self.dpm_solver_2_step(x, t_cur, t_target, r1=1 / 3, eps_cache=eps_cache)
            x_high, eps_cache = self.dpm_solver_3_step(x, t_cur, t_target, eps_cache=eps_cache)

        tolerance = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
        scaled_diff = (x_low - x_high) / tolerance
        error = torch.linalg.norm(scaled_diff) / x.numel() ** 0.5

        if controller.propose_step(error):
            x_prev = x_low
            x = x_high + sigma_up * s_noise * noise_sampler(self.sigma(t_cur), self.sigma(t_next))
            t_cur = t_next
            stats['n_accept'] += 1
        else:
            # Rejected: keep x and t_cur; the controller has already adjusted its step size.
            stats['n_reject'] += 1

        stats['nfe'] += order
        stats['steps'] += 1

        if self.info_callback is not None:
            self.info_callback({'x': x, 'i': stats['steps'] - 1, 't': t_cur, 't_up': t_cur, 'denoised': denoised, 'error': error, 'h': controller.h, **stats})

    return x, stats
@devices.inference_context()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
    """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.

    Builds a ``sampling.DPMSolver`` around ``model`` and runs its
    ``dpm_solver_fast`` from ``sigma_max`` down to ``sigma_min``. Both sigmas
    are wrapped in tensors on ``devices.device`` before being converted to
    solver time with ``DPMSolver.t``.

    Args:
        model: Passed to ``sampling.DPMSolver``.
        x: Initial sample, passed to ``dpm_solver_fast``.
        sigma_min: Final sigma; must be greater than 0.
        sigma_max: Initial sigma; must be greater than 0.
        n: Progress-bar total, also passed to ``dpm_solver_fast``.
        extra_args: Passed to ``sampling.DPMSolver``.
        callback: Optional callable that receives the solver's info dict
            extended with ``'sigma'`` and ``'sigma_hat'`` entries.
        disable: Passed to ``tqdm`` as ``disable``.
        eta, s_noise, noise_sampler: Passed to ``dpm_solver_fast``.

    Returns:
        Whatever ``DPMSolver.dpm_solver_fast`` returns.

    Raises:
        ValueError: If ``sigma_min`` or ``sigma_max`` is less than or equal
            to 0.
    """
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')

    with tqdm(total=n, disable=disable) as progress:
        solver = sampling.DPMSolver(model, extra_args, eps_callback=progress.update)

        if callback is not None:
            def forward_info(info):
                # Add sigma-space views of the solver's times for the caller.
                return callback({'sigma': solver.sigma(info['t']), 'sigma_hat': solver.sigma(info['t_up']), **info})

            solver.info_callback = forward_info

        t_begin = solver.t(torch.tensor(sigma_max, device=devices.device))
        t_finish = solver.t(torch.tensor(sigma_min, device=devices.device))
        return solver.dpm_solver_fast(x, t_begin, t_finish, n, eta, s_noise, noise_sampler)
@devices.inference_context()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
    """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.

    Builds a ``sampling.DPMSolver`` around ``model`` and runs its
    ``dpm_solver_adaptive`` from ``sigma_max`` down to ``sigma_min``. Both
    sigmas are wrapped in tensors on ``devices.device`` before being converted
    to solver time with ``DPMSolver.t``.

    Args:
        model: Passed to ``sampling.DPMSolver``.
        x: Initial sample, passed to ``dpm_solver_adaptive``.
        sigma_min: Final sigma; must be greater than 0.
        sigma_max: Initial sigma; must be greater than 0.
        extra_args: Passed to ``sampling.DPMSolver``.
        callback: Optional callable that receives the solver's info dict
            extended with ``'sigma'`` and ``'sigma_hat'`` entries.
        disable: Passed to ``tqdm`` as ``disable``.
        order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta,
            s_noise, noise_sampler: Passed positionally, in this order, to
            ``dpm_solver_adaptive``.
        return_info: When truthy, also return the solver's info dict.

    Returns:
        The final sample, or ``(sample, info)`` when ``return_info`` is truthy.

    Raises:
        ValueError: If ``sigma_min`` or ``sigma_max`` is less than or equal
            to 0.
    """
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')

    # No total is given: the number of model evaluations is not known up front.
    with tqdm(disable=disable) as progress:
        solver = sampling.DPMSolver(model, extra_args, eps_callback=progress.update)

        if callback is not None:
            def forward_info(info):
                # Add sigma-space views of the solver's times for the caller.
                return callback({'sigma': solver.sigma(info['t']), 'sigma_hat': solver.sigma(info['t_up']), **info})

            solver.info_callback = forward_info

        t_begin = solver.t(torch.tensor(sigma_max, device=devices.device))
        t_finish = solver.t(torch.tensor(sigma_min, device=devices.device))
        x, info = solver.dpm_solver_adaptive(x, t_begin, t_finish, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)

    return (x, info) if return_info else x
# Monkey-patch k_diffusion at import time: replace the library's adaptive
# solver method and the two DPM sampler entry points with the versions defined
# in this module, which create their scalar tensors on `devices.device`.
sampling.DPMSolver.dpm_solver_adaptive = dpm_solver_adaptive
sampling.sample_dpm_fast = sample_dpm_fast
sampling.sample_dpm_adaptive = sample_dpm_adaptive