File size: 4,770 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
from tqdm.auto import tqdm
from k_diffusion import sampling
import modules.devices as devices


def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
    """Adaptive-step DPM-Solver loop; installed as ``sampling.DPMSolver.dpm_solver_adaptive``.

    Each iteration attempts a step from ``s`` towards ``t_end`` with an embedded pair of
    solvers (order 2: DPM-Solver-1 vs DPM-Solver-2; order 3: DPM-Solver-2 with r1=1/3
    vs DPM-Solver-3). It estimates the local error from the difference of the two results
    and asks a PID step-size controller whether to accept the step. The controller's
    ``h`` is used to size the next attempt.

    Args:
        x: current sample.
        t_start: start time, in the solver's ``t`` parameterization (see ``self.t`` /
            ``self.sigma``).
        t_end: end time, in the same parameterization. Integration runs forward when
            ``t_end > t_start`` and in reverse otherwise.
        order: 2 or 3 (which solver pair to use, see above).
        rtol: relative error tolerance used to scale the error estimate.
        atol: absolute error tolerance used to scale the error estimate.
        h_init: initial step size. Only its magnitude is used; the sign follows the
            direction.
        pcoeff, icoeff, dcoeff, accept_safety: parameters forwarded to
            ``sampling.PIDStepSizeController``.
        eta: amount of ancestral noise. Must be 0 for reverse sampling.
        s_noise: scale applied to the injected noise.
        noise_sampler: callable ``(sigma_from, sigma_to) -> noise``. Defaults to
            ``sampling.default_noise_sampler(x)``.

    Returns:
        ``(x, info)``, where ``info`` holds the counters ``steps``, ``nfe``,
        ``n_accept`` and ``n_reject``.

    Raises:
        ValueError: if ``order`` is not 2 or 3, or if ``eta`` is non-zero for reverse
            sampling.
    """
    noise_sampler = sampling.default_noise_sampler(x) if noise_sampler is None else noise_sampler
    if order not in {2, 3}:
        raise ValueError('order should be 2 or 3')
    forward = t_end > t_start
    if not forward and eta:
        raise ValueError('eta must be 0 for reverse sampling')
    h_init = abs(h_init) * (1 if forward else -1)
    # The tolerances are combined element-wise with x_low / x_prev below, so they must
    # live on the same device as the sample itself rather than on a global default device.
    atol = torch.tensor(atol, device=x.device)
    rtol = torch.tensor(rtol, device=x.device)
    s = t_start
    x_prev = x
    pid = sampling.PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
    info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}

    # Keep stepping until we are within 1e-5 of t_end in the direction of integration.
    while (s < t_end - 1e-5) if forward else (s > t_end + 1e-5):
        eps_cache = {}
        # Proposed step end, clamped so the step never overshoots t_end.
        t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
        if eta:
            # Ancestral variant: integrate deterministically to t_ (derived from the
            # ancestral "down" sigma), then add noise of std su to land on sigma(t).
            sd, su = sampling.get_ancestral_step(self.sigma(s), self.sigma(t), eta)
            t_ = torch.minimum(t_end, self.t(sd))
            su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
        else:
            t_, su = t, 0.

        eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
        denoised = x - self.sigma(s) * eps

        # The two solvers share eps_cache, so the evaluation at s is reused.
        if order == 2:
            x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
            x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
        else:
            x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
            x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
        # Mixed absolute/relative tolerance per element, then an RMS-normalized error.
        delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
        error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
        accept = pid.propose_step(error)
        if accept:
            x_prev = x_low
            x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
            s = t
            info['n_accept'] += 1
        else:
            info['n_reject'] += 1
        # Rejected attempts still cost model evaluations, so nfe counts every attempt.
        info['nfe'] += order
        info['steps'] += 1

        if self.info_callback is not None:
            self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})

    return x, info


@devices.inference_context()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
    """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.

    Args:
        model: denoiser model wrapped by ``sampling.DPMSolver``.
        x: initial sample.
        sigma_min: final noise level. Must be strictly positive.
        sigma_max: starting noise level. Must be strictly positive.
        n: passed to ``dpm_solver_fast`` as its step budget. It is also used as the
            progress-bar total.
        extra_args: extra keyword arguments forwarded to the model by ``DPMSolver``.
        callback: optional callable receiving the solver's info dict, extended with
            ``sigma`` and ``sigma_hat``.
        disable: disables the tqdm progress bar when truthy.
        eta: ancestral noise amount, forwarded to ``dpm_solver_fast``.
        s_noise: noise scale, forwarded to ``dpm_solver_fast``.
        noise_sampler: noise sampler, forwarded to ``dpm_solver_fast``.

    Returns:
        Whatever ``DPMSolver.dpm_solver_fast`` returns (the final sample).

    Raises:
        ValueError: if ``sigma_min`` or ``sigma_max`` is not strictly positive.
    """
    # Both bounds go through DPMSolver.t (a log-domain transform upstream), so they must be > 0.
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must be positive')
    with tqdm(total=n, disable=disable) as pbar:
        dpm_solver = sampling.DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        # Create the boundary sigmas on the sample's device, since they feed model calls alongside x.
        t_start = dpm_solver.t(torch.tensor(sigma_max, device=x.device))
        t_end = dpm_solver.t(torch.tensor(sigma_min, device=x.device))
        return dpm_solver.dpm_solver_fast(x, t_start, t_end, n, eta, s_noise, noise_sampler)


@devices.inference_context()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
    """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.

    Runs ``DPMSolver.dpm_solver_adaptive`` from ``t(sigma_max)`` to ``t(sigma_min)``.

    Args:
        model: denoiser model wrapped by ``sampling.DPMSolver``.
        x: initial sample.
        sigma_min: final noise level. Must be strictly positive.
        sigma_max: starting noise level. Must be strictly positive.
        extra_args: extra keyword arguments forwarded to the model by ``DPMSolver``.
        callback: optional callable receiving the solver's info dict, extended with
            ``sigma`` and ``sigma_hat``.
        disable: disables the tqdm progress bar when truthy.
        order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise,
        noise_sampler: forwarded unchanged to ``dpm_solver_adaptive``.
        return_info: when true, also return the solver's step-statistics dict.

    Returns:
        The final sample, or ``(sample, info)`` when ``return_info`` is true.

    Raises:
        ValueError: if ``sigma_min`` or ``sigma_max`` is not strictly positive.
    """
    # Both bounds go through DPMSolver.t (a log-domain transform upstream), so they must be > 0.
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must be positive')
    # The number of attempted steps depends on accept/reject decisions, so the bar has no total.
    with tqdm(disable=disable) as pbar:
        dpm_solver = sampling.DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        # Create the boundary sigmas on the sample's device, since they feed model calls alongside x.
        t_start = dpm_solver.t(torch.tensor(sigma_max, device=x.device))
        t_end = dpm_solver.t(torch.tensor(sigma_min, device=x.device))
        # Keyword arguments guard against silently mis-binding the many same-typed float parameters.
        x, info = dpm_solver.dpm_solver_adaptive(
            x, t_start, t_end,
            order=order, rtol=rtol, atol=atol, h_init=h_init,
            pcoeff=pcoeff, icoeff=icoeff, dcoeff=dcoeff, accept_safety=accept_safety,
            eta=eta, s_noise=s_noise, noise_sampler=noise_sampler,
        )
    if return_info:
        return x, info
    return x

# Replace k-diffusion's module-level sampler entry points with the versions above.
sampling.sample_dpm_adaptive = sample_dpm_adaptive
sampling.sample_dpm_fast = sample_dpm_fast
# Patch the solver class itself, so any DPMSolver instance (including ones created
# by the samplers above) uses the patched adaptive loop.
sampling.DPMSolver.dpm_solver_adaptive = dpm_solver_adaptive