Spaces:
Running
Running
File size: 5,476 Bytes
2ba4412 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import math
import torch
def beta_schedule(schedule='cosine',
num_timesteps=1000,
zero_terminal_snr=False,
**kwargs):
# compute betas
betas = {
# 'logsnr_cosine_interp': logsnr_cosine_interp_schedule,
'linear': linear_schedule,
'linear_sd': linear_sd_schedule,
'quadratic': quadratic_schedule,
'cosine': cosine_schedule
}[schedule](num_timesteps, **kwargs)
if zero_terminal_snr and abs(betas.max() - 1.0) > 0.0001:
betas = rescale_zero_terminal_snr(betas)
return betas
def sigma_schedule(schedule='cosine',
num_timesteps=1000,
zero_terminal_snr=False,
**kwargs):
# compute betas
betas = {
'logsnr_cosine_interp': logsnr_cosine_interp_schedule,
'linear': linear_schedule,
'linear_sd': linear_sd_schedule,
'quadratic': quadratic_schedule,
'cosine': cosine_schedule
}[schedule](num_timesteps, **kwargs)
if schedule == 'logsnr_cosine_interp':
sigma = betas
else:
sigma = betas_to_sigmas(betas)
if zero_terminal_snr and abs(sigma.max() - 1.0) > 0.0001:
sigma = rescale_zero_terminal_snr(sigma)
return sigma
def linear_schedule(num_timesteps, init_beta, last_beta, **kwargs):
scale = 1000.0 / num_timesteps
init_beta = init_beta or scale * 0.0001
ast_beta = last_beta or scale * 0.02
return torch.linspace(init_beta, last_beta, num_timesteps, dtype=torch.float64)
def logsnr_cosine_interp_schedule(
num_timesteps,
scale_min=2,
scale_max=4,
logsnr_min=-15,
logsnr_max=15,
**kwargs):
return logsnrs_to_sigmas(
_logsnr_cosine_interp(num_timesteps, logsnr_min, logsnr_max, scale_min, scale_max))
def linear_sd_schedule(num_timesteps, init_beta, last_beta, **kwargs):
return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2
def quadratic_schedule(num_timesteps, init_beta, last_beta, **kwargs):
init_beta = init_beta or 0.0015
last_beta = last_beta or 0.0195
return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2
def cosine_schedule(num_timesteps, cosine_s=0.008, **kwargs):
betas = []
for step in range(num_timesteps):
t1 = step / num_timesteps
t2 = (step + 1) / num_timesteps
fn = lambda u: math.cos((u + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2
betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
return torch.tensor(betas, dtype=torch.float64)
# def cosine_schedule(n, cosine_s=0.008, **kwargs):
# ramp = torch.linspace(0, 1, n + 1)
# square_alphas = torch.cos((ramp + cosine_s) / (1 + cosine_s) * torch.pi / 2) ** 2
# betas = (1 - square_alphas[1:] / square_alphas[:-1]).clamp(max=0.999)
# return betas_to_sigmas(betas)
def betas_to_sigmas(betas):
return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
def sigmas_to_betas(sigmas):
square_alphas = 1 - sigmas**2
betas = 1 - torch.cat(
[square_alphas[:1], square_alphas[1:] / square_alphas[:-1]])
return betas
def sigmas_to_logsnrs(sigmas):
square_sigmas = sigmas**2
return torch.log(square_sigmas / (1 - square_sigmas))
def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
t_min = math.atan(math.exp(-0.5 * logsnr_min))
t_max = math.atan(math.exp(-0.5 * logsnr_max))
t = torch.linspace(1, 0, n)
logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
return logsnrs
def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max)
logsnrs += 2 * math.log(1 / scale)
return logsnrs
def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
ramp = torch.linspace(1, 0, n)
min_inv_rho = sigma_min**(1 / rho)
max_inv_rho = sigma_max**(1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2))
return sigmas
def _logsnr_cosine_interp(n,
logsnr_min=-15,
logsnr_max=15,
scale_min=2,
scale_max=4):
t = torch.linspace(1, 0, n)
logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max
return logsnrs
def logsnrs_to_sigmas(logsnrs):
return torch.sqrt(torch.sigmoid(-logsnrs))
def rescale_zero_terminal_snr(betas):
"""
Rescale Schedule to Zero Terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1 - betas
alphas_bar = alphas.cumprod(0)
alphas_bar_sqrt = alphas_bar.sqrt()
# Store old values. 8 alphas_bar_sqrt_0 = a
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so first timestep is back to old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt ** 2
alphas = alphas_bar[1:] / alphas_bar[:-1]
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas
return betas
|