# Diffusion noise-schedule utilities (betas / sigmas / logSNR curves).
# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
import math | |
import torch | |
def beta_schedule(schedule='cosine',
                  num_timesteps=1000,
                  zero_terminal_snr=False,
                  **kwargs):
    """Build a beta schedule of length ``num_timesteps``.

    Args:
        schedule: one of 'linear', 'linear_sd', 'quadratic', 'cosine'.
        num_timesteps: number of diffusion steps.
        zero_terminal_snr: if True, rescale so the terminal SNR is zero
            (skipped when the schedule already ends near 1.0).
        **kwargs: forwarded to the chosen schedule builder.

    Returns:
        A float64 tensor of per-step betas.

    Raises:
        KeyError: if ``schedule`` is not a known name.
    """
    builders = {
        # 'logsnr_cosine_interp': logsnr_cosine_interp_schedule,
        'linear': linear_schedule,
        'linear_sd': linear_sd_schedule,
        'quadratic': quadratic_schedule,
        'cosine': cosine_schedule,
    }
    betas = builders[schedule](num_timesteps, **kwargs)
    # Only rescale when the last beta is not already ~1.0 (tolerance 1e-4).
    if zero_terminal_snr and abs(betas.max() - 1.0) > 0.0001:
        betas = rescale_zero_terminal_snr(betas)
    return betas
def sigma_schedule(schedule='cosine',
                   num_timesteps=1000,
                   zero_terminal_snr=False,
                   **kwargs):
    """Build a sigma schedule of length ``num_timesteps``.

    For 'logsnr_cosine_interp' the builder already returns sigmas; every
    other builder returns betas, which are converted via betas_to_sigmas.

    Args:
        schedule: one of 'logsnr_cosine_interp', 'linear', 'linear_sd',
            'quadratic', 'cosine'.
        num_timesteps: number of diffusion steps.
        zero_terminal_snr: if True, rescale so the terminal SNR is zero
            (skipped when the schedule already ends near 1.0).
        **kwargs: forwarded to the chosen schedule builder.

    Returns:
        A tensor of per-step sigmas.

    Raises:
        KeyError: if ``schedule`` is not a known name.
    """
    builders = {
        'logsnr_cosine_interp': logsnr_cosine_interp_schedule,
        'linear': linear_schedule,
        'linear_sd': linear_sd_schedule,
        'quadratic': quadratic_schedule,
        'cosine': cosine_schedule,
    }
    raw = builders[schedule](num_timesteps, **kwargs)
    # The interp builder produces sigmas directly; the rest produce betas.
    sigma = raw if schedule == 'logsnr_cosine_interp' else betas_to_sigmas(raw)
    if zero_terminal_snr and abs(sigma.max() - 1.0) > 0.0001:
        sigma = rescale_zero_terminal_snr(sigma)
    return sigma
def linear_schedule(num_timesteps, init_beta, last_beta, **kwargs):
    """Linear beta schedule, scaled relative to the reference 1000-step setup.

    Args:
        num_timesteps: number of diffusion steps.
        init_beta: first beta; when falsy, defaults to ``scale * 1e-4``.
        last_beta: final beta; when falsy, defaults to ``scale * 0.02``.

    Returns:
        A float64 tensor of shape ``(num_timesteps,)`` linearly spaced
        from ``init_beta`` to ``last_beta``.
    """
    scale = 1000.0 / num_timesteps
    init_beta = init_beta or scale * 0.0001
    # Bug fix: the original assigned the default to a typo'd name
    # (`ast_beta`), leaving `last_beta` as None and crashing
    # torch.linspace whenever the caller omitted it.
    last_beta = last_beta or scale * 0.02
    return torch.linspace(init_beta, last_beta, num_timesteps, dtype=torch.float64)
def logsnr_cosine_interp_schedule(
        num_timesteps,
        scale_min=2,
        scale_max=4,
        logsnr_min=-15,
        logsnr_max=15,
        **kwargs):
    """Sigma schedule from an interpolated, scale-shifted cosine logSNR curve.

    Builds a logSNR curve interpolated between two resolution scales and
    converts it to sigmas.
    """
    logsnrs = _logsnr_cosine_interp(
        num_timesteps, logsnr_min, logsnr_max, scale_min, scale_max)
    return logsnrs_to_sigmas(logsnrs)
def linear_sd_schedule(num_timesteps, init_beta, last_beta, **kwargs):
    """Stable-Diffusion-style schedule: betas linear in sqrt(beta) space."""
    sqrt_betas = torch.linspace(
        init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64)
    return sqrt_betas ** 2
def quadratic_schedule(num_timesteps, init_beta, last_beta, **kwargs):
    """Quadratic beta schedule (linear in sqrt(beta) space).

    Falsy endpoints default to 0.0015 and 0.0195 respectively.
    """
    first = init_beta or 0.0015
    last = last_beta or 0.0195
    sqrt_betas = torch.linspace(
        first ** 0.5, last ** 0.5, num_timesteps, dtype=torch.float64)
    return sqrt_betas ** 2
def cosine_schedule(num_timesteps, cosine_s=0.008, **kwargs):
    """Cosine beta schedule, with each beta clamped at 0.999.

    Betas are derived from the ratio of a squared-cosine alpha-bar curve
    at consecutive normalized timesteps.
    """
    def alpha_bar(u):
        # Squared-cosine cumulative-alpha curve with small offset cosine_s.
        return math.cos((u + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2

    betas = [
        min(1.0 - alpha_bar((step + 1) / num_timesteps)
            / alpha_bar(step / num_timesteps), 0.999)
        for step in range(num_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float64)
# def cosine_schedule(n, cosine_s=0.008, **kwargs): | |
# ramp = torch.linspace(0, 1, n + 1) | |
# square_alphas = torch.cos((ramp + cosine_s) / (1 + cosine_s) * torch.pi / 2) ** 2 | |
# betas = (1 - square_alphas[1:] / square_alphas[:-1]).clamp(max=0.999) | |
# return betas_to_sigmas(betas) | |
def betas_to_sigmas(betas): | |
return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0)) | |
def sigmas_to_betas(sigmas):
    """Invert betas_to_sigmas: recover per-step betas from cumulative sigmas.

    alpha_bar_t = 1 - sigma_t^2; beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}.
    """
    square_alphas = 1 - sigmas ** 2
    step_ratios = square_alphas[1:] / square_alphas[:-1]
    return 1 - torch.cat([square_alphas[:1], step_ratios])
def sigmas_to_logsnrs(sigmas): | |
square_sigmas = sigmas**2 | |
return torch.log(square_sigmas / (1 - square_sigmas)) | |
def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15): | |
t_min = math.atan(math.exp(-0.5 * logsnr_min)) | |
t_max = math.atan(math.exp(-0.5 * logsnr_max)) | |
t = torch.linspace(1, 0, n) | |
logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min))) | |
return logsnrs | |
def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
    """Cosine logSNR curve shifted down by 2*log(scale) for a resolution scale."""
    shift = 2 * math.log(1 / scale)
    return _logsnr_cosine(n, logsnr_min, logsnr_max) + shift
def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0): | |
ramp = torch.linspace(1, 0, n) | |
min_inv_rho = sigma_min**(1 / rho) | |
max_inv_rho = sigma_max**(1 / rho) | |
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho | |
sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2)) | |
return sigmas | |
def _logsnr_cosine_interp(n,
                          logsnr_min=-15,
                          logsnr_max=15,
                          scale_min=2,
                          scale_max=4):
    """Interpolate between two scale-shifted cosine logSNR curves.

    The blend weight runs from 1 (favoring the scale_min curve) down to 0
    (favoring the scale_max curve) across the n steps.
    """
    weight = torch.linspace(1, 0, n)
    curve_lo = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
    curve_hi = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
    return weight * curve_lo + (1 - weight) * curve_hi
def logsnrs_to_sigmas(logsnrs): | |
return torch.sqrt(torch.sigmoid(-logsnrs)) | |
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so the final timestep has zero terminal SNR.

    Shifts and scales sqrt(alpha_bar) so its last entry becomes exactly
    zero while its first entry keeps the original value, then converts
    the adjusted cumulative products back to per-step betas.
    """
    # Cumulative sqrt(alpha_bar) from the input betas.
    alpha_bar_sqrt = (1 - betas).cumprod(0).sqrt()

    # Remember the original endpoints before modifying in place.
    first_value = alpha_bar_sqrt[0].clone()
    last_value = alpha_bar_sqrt[-1].clone()

    # Shift the whole curve so the terminal value is exactly zero.
    alpha_bar_sqrt -= last_value
    # Rescale so the initial value is restored to its original magnitude.
    alpha_bar_sqrt *= first_value / (first_value - last_value)

    # Recover per-step alphas from consecutive cumulative-product ratios,
    # then convert back to betas.
    alpha_bar = alpha_bar_sqrt ** 2
    step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - step_alphas