# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
import math

import torch


def beta_schedule(schedule='cosine',
                  num_timesteps=1000,
                  zero_terminal_snr=False,
                  **kwargs):
    # compute betas
    betas = {
        # 'logsnr_cosine_interp': logsnr_cosine_interp_schedule,
        'linear': linear_schedule,
        'linear_sd': linear_sd_schedule,
        'quadratic': quadratic_schedule,
        'cosine': cosine_schedule
    }[schedule](num_timesteps, **kwargs)
    if zero_terminal_snr and abs(betas.max() - 1.0) > 0.0001:
        betas = rescale_zero_terminal_snr(betas)
    return betas
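

# Illustrative usage sketch (hypothetical helper, not part of the original
# file): the default cosine schedule yields `num_timesteps` betas in
# (0, 0.999] that grow toward the end of the trajectory.
def _check_beta_schedule():
    betas = beta_schedule('cosine', num_timesteps=1000)
    assert betas.shape == (1000,)
    assert betas.min() > 0 and betas.max() <= 0.999
    assert betas[0] < betas[-1]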


def sigma_schedule(schedule='cosine',
                   num_timesteps=1000,
                   zero_terminal_snr=False,
                   **kwargs):
    # compute betas (the logsnr_cosine_interp entry yields sigmas directly)
    betas = {
        'logsnr_cosine_interp': logsnr_cosine_interp_schedule,
        'linear': linear_schedule,
        'linear_sd': linear_sd_schedule,
        'quadratic': quadratic_schedule,
        'cosine': cosine_schedule
    }[schedule](num_timesteps, **kwargs)
    if schedule == 'logsnr_cosine_interp':
        sigma = betas
    else:
        sigma = betas_to_sigmas(betas)
    if zero_terminal_snr and abs(sigma.max() - 1.0) > 0.0001:
        sigma = rescale_zero_terminal_snr(sigma)
    return sigma
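

# Illustrative sketch (hypothetical helper, not from the original file):
# sigmas are cumulative noise levels in (0, 1], non-decreasing over timesteps.
def _check_sigma_schedule():
    sigmas = sigma_schedule('cosine', num_timesteps=1000)
    assert torch.all((sigmas > 0) & (sigmas <= 1))
    assert torch.all(sigmas[1:] >= sigmas[:-1])
    # the logsnr_cosine_interp branch returns sigmas directly
    sigmas = sigma_schedule('logsnr_cosine_interp', num_timesteps=1000)
    assert torch.all(sigmas[1:] >= sigmas[:-1])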


def linear_schedule(num_timesteps, init_beta=None, last_beta=None, **kwargs):
    # DDPM-style linear schedule; default endpoints (0.0001, 0.02) are
    # rescaled to the number of timesteps
    scale = 1000.0 / num_timesteps
    init_beta = init_beta or scale * 0.0001
    last_beta = last_beta or scale * 0.02
    return torch.linspace(init_beta, last_beta, num_timesteps, dtype=torch.float64)


def logsnr_cosine_interp_schedule(num_timesteps,
                                  scale_min=2,
                                  scale_max=4,
                                  logsnr_min=-15,
                                  logsnr_max=15,
                                  **kwargs):
    return logsnrs_to_sigmas(
        _logsnr_cosine_interp(num_timesteps, logsnr_min, logsnr_max,
                              scale_min, scale_max))


def linear_sd_schedule(num_timesteps, init_beta, last_beta, **kwargs):
    # Stable Diffusion's "scaled linear" schedule: linear in sqrt(beta) space
    return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps,
                          dtype=torch.float64) ** 2


def quadratic_schedule(num_timesteps, init_beta=None, last_beta=None, **kwargs):
    init_beta = init_beta or 0.0015
    last_beta = last_beta or 0.0195
    return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps,
                          dtype=torch.float64) ** 2


def cosine_schedule(num_timesteps, cosine_s=0.008, **kwargs):
    # cosine schedule of Nichol & Dhariwal (2021):
    # beta_t = 1 - alpha_bar(t2) / alpha_bar(t1), clipped at 0.999
    fn = lambda u: math.cos((u + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2
    betas = []
    for step in range(num_timesteps):
        t1 = step / num_timesteps
        t2 = (step + 1) / num_timesteps
        betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
    return torch.tensor(betas, dtype=torch.float64)
# def cosine_schedule(n, cosine_s=0.008, **kwargs):
#     ramp = torch.linspace(0, 1, n + 1)
#     square_alphas = torch.cos((ramp + cosine_s) / (1 + cosine_s) * torch.pi / 2) ** 2
#     betas = (1 - square_alphas[1:] / square_alphas[:-1]).clamp(max=0.999)
#     return betas_to_sigmas(betas)
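

# Sanity sketch (hypothetical helper, not from the original file): away from
# the 0.999 clipping region, cumprod(1 - betas) should reproduce
# alpha_bar(t) = f(t) / f(0) with f as in Nichol & Dhariwal (2021).
def _check_cosine_alphas_bar(n=1000, s=0.008):
    betas = cosine_schedule(n, s)
    alphas_bar = torch.cumprod(1 - betas, dim=0)
    f = lambda u: math.cos((u + s) / (1 + s) * math.pi / 2) ** 2
    target = torch.tensor([f((k + 1) / n) / f(0) for k in range(n)],
                          dtype=torch.float64)
    assert torch.allclose(alphas_bar[:900], target[:900], atol=1e-9)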


def betas_to_sigmas(betas):
    # sigma_t = sqrt(1 - alpha_bar_t), where alpha_bar_t = prod(1 - beta_i)
    return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))


def sigmas_to_betas(sigmas):
    square_alphas = 1 - sigmas**2
    betas = 1 - torch.cat(
        [square_alphas[:1], square_alphas[1:] / square_alphas[:-1]])
    return betas
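

# Roundtrip sketch (hypothetical helper, not from the original file):
# sigmas_to_betas should invert betas_to_sigmas for a well-behaved schedule.
def _check_beta_sigma_roundtrip():
    betas = linear_schedule(1000)
    recovered = sigmas_to_betas(betas_to_sigmas(betas))
    assert torch.allclose(recovered, betas, atol=1e-10)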


def sigmas_to_logsnrs(sigmas):
    # logSNR = log(alpha^2 / sigma^2) with alpha^2 = 1 - sigma^2,
    # matching the convention of logsnrs_to_sigmas below
    square_sigmas = sigmas**2
    return torch.log((1 - square_sigmas) / square_sigmas)


def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
    # cosine logSNR ramp, decreasing from logsnr_max to logsnr_min
    t_min = math.atan(math.exp(-0.5 * logsnr_min))
    t_max = math.atan(math.exp(-0.5 * logsnr_max))
    t = torch.linspace(1, 0, n)
    logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
    return logsnrs
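

# Endpoint sketch (hypothetical helper, not from the original file): the ramp
# runs from logsnr_max at the clean end down to logsnr_min at the noisy end.
def _check_logsnr_cosine_endpoints():
    logsnrs = _logsnr_cosine(1000)
    assert abs(logsnrs[0].item() - 15.0) < 1e-2
    assert abs(logsnrs[-1].item() + 15.0) < 1e-2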


def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
    # shift the whole logSNR curve down by 2 * log(scale)
    logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max)
    logsnrs += 2 * math.log(1 / scale)
    return logsnrs


def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Karras et al. (2022) noise levels, increasing from sigma_min to
    # sigma_max, then mapped into (0, 1) via sigma / sqrt(1 + sigma^2)
    ramp = torch.linspace(1, 0, n)
    min_inv_rho = sigma_min**(1 / rho)
    max_inv_rho = sigma_max**(1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
    sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2))
    return sigmas
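

# Usage sketch (hypothetical helper, not from the original file): the result
# is a monotone curve strictly inside (0, 1).
def _check_karras_schedule():
    sigmas = karras_schedule(50)
    assert torch.all(sigmas[1:] >= sigmas[:-1])
    assert 0 < sigmas[0] < sigmas[-1] < 1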


def _logsnr_cosine_interp(n,
                          logsnr_min=-15,
                          logsnr_max=15,
                          scale_min=2,
                          scale_max=4):
    # interpolate between two shifted cosine logSNR curves along the ramp
    t = torch.linspace(1, 0, n)
    logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
    logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
    logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max
    return logsnrs


def logsnrs_to_sigmas(logsnrs):
    # sigma = sqrt(sigmoid(-logSNR)), since sigma^2 = 1 / (1 + SNR)
    return torch.sqrt(torch.sigmoid(-logsnrs))
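

# Roundtrip sketch (hypothetical helper, not from the original file): with
# logSNR = log(alpha^2 / sigma^2) and alpha^2 + sigma^2 = 1, the two
# conversions invert each other.
def _check_logsnr_roundtrip():
    sigmas = torch.linspace(0.05, 0.95, 19, dtype=torch.float64)
    recovered = logsnrs_to_sigmas(sigmas_to_logsnrs(sigmas))
    assert torch.allclose(recovered, sigmas, atol=1e-12)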


def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule to enforce zero terminal SNR
    (Lin et al., 2023, "Common Diffusion Noise Schedules and
    Sample Steps Are Flawed").
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1 - betas
    alphas_bar = alphas.cumprod(0)
    alphas_bar_sqrt = alphas_bar.sqrt()

    # Store old values
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep keeps its old value
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt back to betas
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = alphas_bar[1:] / alphas_bar[:-1]
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas
    return betas
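

# Verification sketch (hypothetical helper, not from the original file):
# after rescaling, alpha_bar at the final timestep is exactly zero (pure
# noise, i.e. zero terminal SNR), while the first beta keeps its value.
def _check_zero_terminal_snr():
    betas = linear_schedule(1000)
    rescaled = rescale_zero_terminal_snr(betas)
    alphas_bar = torch.cumprod(1 - rescaled, dim=0)
    assert alphas_bar[-1].item() == 0.0
    assert torch.isclose(rescaled[0], betas[0])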