# Spaces: Running on Zero
import math | |
from collections.abc import Callable | |
import torch | |
from torch import Tensor | |
def get_noise(num_samples: int, height: int, width: int, device: torch.device, dtype: torch.dtype, seed: int):
    """Draw reproducible standard-normal noise for a batch of latents.

    Each spatial extent is ``2 * ceil(size / 16)``, so it is always even; the
    original code notes this is to allow for packing.

    Args:
        num_samples: Batch size (first dimension of the result).
        height: Height of the target image (presumably in pixels); rounded up
            to a multiple of 16 before scaling.
        width: Width of the target image (presumably in pixels); rounded up
            to a multiple of 16 before scaling.
        device: Device used for both the random generator and the output.
        dtype: dtype of the returned tensor.
        seed: Seed for a fresh ``torch.Generator``; the same seed, shape,
            device and dtype give the same tensor.

    Returns:
        Tensor of shape
        ``(num_samples, 16, 2 * ceil(height / 16), 2 * ceil(width / 16))``.
    """
    # Even spatial extents: allow for packing.
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)
    rng = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(
        num_samples,
        16,
        latent_h,
        latent_w,
        device=device,
        dtype=dtype,
        generator=rng,
    )
def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp timesteps ``t`` elementwise with a shift controlled by ``mu``.

    Computes ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.

    ``t == 1`` maps to 1. For a floating tensor, ``t == 0`` maps to 0, because
    ``1 / 0`` evaluates to ``inf``. For ``0 < t < 1`` and ``sigma > 0``, a
    larger ``mu`` moves the result closer to 1.

    Args:
        mu: Shift strength; enters the formula as ``exp(mu)``.
        sigma: Exponent applied to the ``(1 / t - 1)`` term.
        t: Timesteps to warp.

    Returns:
        Tensor of warped timesteps, same shape as ``t``.
    """
    scale = math.exp(mu)
    odds = (1 / t - 1) ** sigma
    return scale / (scale + odds)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    """Return the straight line through the points ``(x1, y1)`` and ``(x2, y2)``.

    The returned callable evaluates ``slope * x + intercept``. Inputs outside
    ``[x1, x2]`` are extrapolated along the same line.

    Raises:
        ZeroDivisionError: if ``x1 == x2``, because the slope is undefined.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x: float) -> float:
        return slope * x + intercept

    return line
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Build a descending timestep schedule running from 1 down to 0.

    The result has ``num_steps + 1`` entries; the extra trailing entry is the
    final 0.

    When ``shift`` is true, ``mu`` is a linear function of ``image_seq_len``
    built by ``get_lin_function(y1=base_shift, y2=max_shift)``. That line runs
    through ``(x1, base_shift)`` and ``(x2, max_shift)`` at the helper's
    default x points. The grid is then warped by ``time_shift(mu, 1.0, ...)``
    to favor high timesteps for higher-signal images.

    Args:
        num_steps: Number of sampling steps.
        image_seq_len: Sequence length the shift ``mu`` is interpolated from.
        base_shift: ``mu`` at the first interpolation point.
        max_shift: ``mu`` at the second interpolation point.
        shift: Whether to apply the resolution-dependent shift.

    Returns:
        Plain Python list of floats.
    """
    # num_steps + 1 points so the schedule ends exactly at zero.
    grid = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return grid.tolist()
    mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
    return time_shift(mu, 1.0, grid).tolist()