Spaces:
Running
on
L40S
Running
on
L40S
import numpy as np | |
import torch | |
def append_dims(x, target_dims):
    """Pad ``x`` with trailing singleton axes until it has ``target_dims`` dims.

    Args:
        x: Object exposing ``ndim`` and supporting ``...`` / ``None``
            indexing (for example a tensor).
        target_dims: Desired total number of dimensions.

    Returns:
        ``x`` indexed with one ``None`` per missing trailing axis. When no
        axes are missing, the index is just ``...``.

    Raises:
        ValueError: If ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    # Ellipsis keeps every existing axis; each None adds a size-1 axis at the end.
    index = (Ellipsis, *([None] * missing))
    return x[index]
# From LCMScheduler.get_scalings_for_boundary_condition_discrete | |
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    """Compute the ``(c_skip, c_out)`` scaling pair for a timestep.

    With ``s = timestep_scaling * timestep``::

        c_skip = sigma_data**2 / (s**2 + sigma_data**2)
        c_out  = s / (s**2 + sigma_data**2) ** 0.5

    At ``timestep == 0`` this gives ``c_skip == 1`` and ``c_out == 0``.

    Args:
        timestep: Timestep value(s); anything supporting the arithmetic above.
        sigma_data: Constant used in both denominators.
        timestep_scaling: Multiplier applied to ``timestep`` before use.

    Returns:
        Tuple ``(c_skip, c_out)``.
    """
    s = timestep_scaling * timestep
    # Both coefficients share this denominator term.
    denom = s**2 + sigma_data**2
    skip_coeff = sigma_data**2 / denom
    out_coeff = s / denom ** 0.5
    return skip_coeff, out_coeff
def extract_into_tensor(a, t, x_shape):
    """Gather entries of ``a`` at indices ``t`` and shape them for broadcasting.

    Args:
        a: Source of values; must provide ``gather(dim, index)``.
        t: Index tensor. Its first dimension is taken as the batch size.
        x_shape: Shape of the tensor the result will be combined with; only
            its length is used.

    Returns:
        The gathered values reshaped to ``(batch, 1, ..., 1)`` with
        ``len(x_shape)`` dimensions in total.
    """
    batch, *_rest = t.shape
    picked = a.gather(-1, t)
    # One leading batch axis followed by singleton axes for the rest of x_shape.
    target_shape = (batch,) + (1,) * (len(x_shape) - 1)
    return picked.reshape(target_shape)
class DDIMSolver:
    """Holds a subsampled timestep schedule and performs a DDIM-style step.

    The constructor picks ``ddim_timesteps`` evenly spaced indices out of
    ``timesteps`` and stores them, with the matching cumulative-alpha values,
    as torch tensors:

    * ``ddim_timesteps``: selected indices (int64).
    * ``ddim_alpha_cumprods``: ``alpha_cumprods`` at those indices.
    * ``ddim_alpha_cumprods_prev``: the same values shifted by one position,
      with ``alpha_cumprods[0]`` placed first.
    """

    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
        """Build the subsampled schedule.

        Args:
            alpha_cumprods: NumPy array of cumulative alpha products. It is
                indexed with an integer array and the result is passed to
                ``torch.from_numpy``.
            timesteps: Length of the full schedule.
            ddim_timesteps: Number of steps to keep.
        """
        # Spacing between consecutive kept timesteps.
        stride = timesteps // ddim_timesteps
        multiples = np.arange(1, ddim_timesteps + 1) * stride
        # Shift by one so the indices are zero-based.
        step_indices = multiples.round().astype(np.int64) - 1

        alphas = alpha_cumprods[step_indices]
        # "Previous" value for step i is the value at step i-1; the first
        # step falls back to alpha_cumprods[0].
        alphas_prev = np.asarray(
            [alpha_cumprods[0], *alpha_cumprods[step_indices[:-1]].tolist()]
        )

        # Store everything as torch tensors.
        self.ddim_timesteps = torch.from_numpy(step_indices).long()
        self.ddim_alpha_cumprods = torch.from_numpy(alphas)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(alphas_prev)

    def to(self, device):
        """Move all stored tensors to ``device`` and return ``self``."""
        for attr in ("ddim_timesteps", "ddim_alpha_cumprods", "ddim_alpha_cumprods_prev"):
            setattr(self, attr, getattr(self, attr).to(device))
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        """Combine predicted clean sample and noise into the previous sample.

        Computes ``sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev) * pred_noise``,
        where ``a_prev`` is looked up in ``ddim_alpha_cumprods_prev`` at
        ``timestep_index``. No additional random noise term is added.

        Args:
            pred_x0: Predicted clean sample; its shape sets the broadcast
                shape of ``a_prev``.
            pred_noise: Predicted noise, combined elementwise with ``pred_x0``.
            timestep_index: Index tensor into the subsampled schedule.

        Returns:
            The computed previous-step sample.
        """
        alpha_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
        signal_part = alpha_prev.sqrt() * pred_x0
        noise_part = (1.0 - alpha_prev).sqrt() * pred_noise
        return signal_part + noise_part