LoRa_Streamlit / ai-toolkit /toolkit /samplers /custom_flowmatch_sampler.py
ramimu's picture
Upload 586 files
1c72248 verified
raw
history blame
7.99 kB
import math
from typing import Union
from torch.distributions import LogNormal
from diffusers import FlowMatchEulerDiscreteScheduler
import torch
import numpy as np
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Map an image sequence length to a shift value (``mu``) on a straight line.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``. Lengths outside that interval are
    extrapolated along the same line; nothing is clamped.

    Args:
        image_seq_len: Sequence length to evaluate the line at.
        base_seq_len: Sequence length that maps to ``base_shift``.
        max_seq_len: Sequence length that maps to ``max_shift``.
        base_shift: Shift value at ``base_seq_len``.
        max_shift: Shift value at ``max_seq_len``.

    Returns:
        The interpolated (or extrapolated) shift value.

    Raises:
        ZeroDivisionError: If ``max_seq_len`` equals ``base_seq_len`` (for
            plain Python numbers).
    """
    # Slope/intercept form of the line through the two anchor points.
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    """Flow-match Euler scheduler with extra helpers for training.

    On top of the parent scheduler this class provides:

    * precomputed per-timestep loss weights (a full bell curve and a
      "half bell" variant) over 1000 linear timesteps,
    * lookups from timestep values to weights / sigmas,
    * a flow-matching ``add_noise`` (linear blend of sample and noise),
    * ``set_train_timesteps`` which builds the training schedule using one
      of several sampling strategies.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the parent scheduler and precompute the weight tables.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.init_noise_sigma = 1.0
        self.timestep_type = "linear"

        with torch.no_grad():
            # create weights for timesteps
            num_timesteps = 1000

            # Bell-Shaped Mean-Normalized Timestep Weighting
            # bsmntw? need a better name
            x = torch.arange(num_timesteps, dtype=torch.float32)
            y = torch.exp(-2 * ((x - num_timesteps / 2) / num_timesteps) ** 2)

            # Shift minimum to 0
            y_shifted = y - y.min()

            # Scale so the weights sum to num_timesteps, i.e. their mean is 1
            bsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())

            # Half-bell variant: same curve for the first half of the
            # indices, second half flattened to that half's maximum.
            hbsmntw_weighing = bsmntw_weighing.clone()
            half = num_timesteps // 2
            hbsmntw_weighing[half:] = hbsmntw_weighing[half:].max()

            # Create linear timesteps from 1000 to 0
            timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')

            self.linear_timesteps = timesteps
            self.linear_timesteps_weights = bsmntw_weighing
            self.linear_timesteps_weights2 = hbsmntw_weighing

    @staticmethod
    def _lookup_step_indices(schedule_timesteps, timesteps):
        """Return, for each entry of ``timesteps``, its index in the schedule.

        When a value occurs more than once in ``schedule_timesteps`` the
        first occurrence is used, so schedules containing duplicates (for
        example after an integer cast) do not break the lookup.

        Raises:
            ValueError: If a timestep is not present in the schedule.
        """
        step_indices = []
        for t in timesteps:
            matches = (schedule_timesteps == t).nonzero()
            if matches.numel() == 0:
                raise ValueError(
                    f"Timestep {float(t)} is not in the current schedule")
            step_indices.append(int(matches[0].item()))
        return step_indices

    def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
        """Look up the precomputed loss weight for each timestep.

        Args:
            timesteps: Timestep values that must be present in
                ``self.timesteps``.
            v2: When true, use the half-bell table
                (``linear_timesteps_weights2``) instead of the full bell.

        Returns:
            A flat tensor with one weight per entry of ``timesteps``.

        Raises:
            ValueError: If a timestep is not present in ``self.timesteps``.
        """
        # Get the indices of the timesteps
        step_indices = self._lookup_step_indices(self.timesteps, timesteps)

        # Get the weights for the timesteps
        if v2:
            table = self.linear_timesteps_weights2
        else:
            table = self.linear_timesteps_weights
        return table[step_indices].flatten()

    def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
        """Return the sigma for each timestep, padded to ``n_dim`` dimensions.

        Args:
            timesteps: Timestep values that must be present in
                ``self.timesteps``.
            n_dim: Minimum number of dimensions of the result; trailing
                singleton dimensions are appended until it is reached.
            dtype: dtype the sigmas are converted to.
            device: Device on which the lookup is performed.

        Returns:
            Tensor of sigmas whose first dimension matches ``timesteps``.

        Raises:
            ValueError: If a timestep is not present in ``self.timesteps``.
        """
        sigmas = self.sigmas.to(device=device, dtype=dtype)
        schedule_timesteps = self.timesteps.to(device)
        timesteps = timesteps.to(device)

        step_indices = self._lookup_step_indices(schedule_timesteps, timesteps)

        sigma = sigmas[step_indices].flatten()
        # Append trailing singleton dims so sigma broadcasts against samples.
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Blend samples with noise: ``(1 - t) * sample + t * noise``.

        ``t`` is ``timesteps / 1000``. A 1-D ``timesteps`` tensor whose
        length equals the batch dimension of ``original_samples`` is applied
        per sample; any other shape follows ordinary broadcasting.

        Args:
            original_samples: Clean samples.
            noise: Noise with a shape broadcastable to ``original_samples``.
            timesteps: Timestep values on the 0..1000 scale.

        Returns:
            The noised samples.
        """
        t_01 = (timesteps / 1000).to(original_samples.device)

        # A (B,) tensor would otherwise broadcast against the LAST dimension
        # of (B, ...) samples; reshape to (B, 1, ..., 1) for per-sample use.
        if (
            t_01.ndim == 1
            and original_samples.ndim > 1
            and t_01.shape[0] == original_samples.shape[0]
        ):
            t_01 = t_01.reshape(-1, *([1] * (original_samples.ndim - 1)))

        # forward ODE
        noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise

        # reverse ODE
        # noisy_model_input = (1 - t_01) * noise + t_01 * original_samples
        return noisy_model_input

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """Return ``sample`` unchanged; ``timestep`` is ignored."""
        return sample

    def set_train_timesteps(
        self,
        num_timesteps,
        device,
        timestep_type='linear',
        latents=None,
        patch_size=1
    ):
        """Build and store the training timestep schedule.

        Sets ``self.timestep_type`` and ``self.timesteps``; the shift
        variants additionally set ``self.sigmas``.

        Args:
            num_timesteps: Number of timesteps to generate.
            device: Device for the resulting tensors.
            timestep_type: One of ``'linear'``, ``'sigmoid'``,
                ``'flux_shift'`` / ``'lumina2_shift'`` / ``'shift'`` or
                ``'lognorm_blend'``.
            latents: Required by the shift variants when the config enables
                dynamic shifting; dimensions 2 and 3 are read as height and
                width to derive the image sequence length.
            patch_size: Divisor (squared) applied to ``height * width`` when
                computing the image sequence length.

        Returns:
            The timesteps tensor, in descending order for the sampled types.

        Raises:
            ValueError: If ``timestep_type`` is unknown, or dynamic shifting
                is enabled and ``latents`` is ``None``.
        """
        self.timestep_type = timestep_type

        if timestep_type == 'linear':
            timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
            self.timesteps = timesteps
            return timesteps

        elif timestep_type == 'sigmoid':
            # Distribute timesteps closer to the center; inference biases
            # them toward the start instead.

            # Generate values from 0 to 1
            t = torch.sigmoid(torch.randn((num_timesteps,), device=device))

            # Scale and reverse the values to go from 1000 to 0
            timesteps = ((1 - t) * 1000)

            # Sort the timesteps in descending order
            timesteps, _ = torch.sort(timesteps, descending=True)

            self.timesteps = timesteps.to(device=device)
            return timesteps

        elif timestep_type in ['flux_shift', 'lumina2_shift', 'shift']:
            # matches inference dynamic shifting
            timesteps = np.linspace(
                self._sigma_to_t(self.sigma_max),
                self._sigma_to_t(self.sigma_min),
                num_timesteps,
            )

            sigmas = timesteps / self.config.num_train_timesteps

            if self.config.use_dynamic_shifting:
                if latents is None:
                    raise ValueError('latents is None')

                # Per the original note: for Flux the caller doubles the
                # patch size to simulate the latent reduction.
                h = latents.shape[2]
                w = latents.shape[3]
                image_seq_len = h * w // (patch_size**2)

                mu = calculate_shift(
                    image_seq_len,
                    self.config.get("base_image_seq_len", 256),
                    self.config.get("max_image_seq_len", 4096),
                    self.config.get("base_shift", 0.5),
                    self.config.get("max_shift", 1.16),
                )
                sigmas = self.time_shift(mu, 1.0, sigmas)
            else:
                sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)

            if self.config.shift_terminal:
                sigmas = self.stretch_shift_to_terminal(sigmas)

            if self.config.use_karras_sigmas:
                sigmas = self._convert_to_karras(
                    in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
            elif self.config.use_exponential_sigmas:
                sigmas = self._convert_to_exponential(
                    in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
            elif self.config.use_beta_sigmas:
                sigmas = self._convert_to_beta(
                    in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)

            sigmas = torch.from_numpy(sigmas).to(
                dtype=torch.float32, device=device)
            timesteps = sigmas * self.config.num_train_timesteps

            # Append a terminal sigma: 1 when sigmas are inverted, else 0.
            if self.config.invert_sigmas:
                sigmas = 1.0 - sigmas
                timesteps = sigmas * self.config.num_train_timesteps
                sigmas = torch.cat(
                    [sigmas, torch.ones(1, device=sigmas.device)])
            else:
                sigmas = torch.cat(
                    [sigmas, torch.zeros(1, device=sigmas.device)])

            self.sigmas = sigmas
            self.timesteps = timesteps.to(device=device)
            return timesteps

        elif timestep_type == 'lognorm_blend':
            # Distribute timesteps toward the center/early region and blend
            # in a linear schedule.
            alpha = 0.75

            # Split the budget so both parts always add up to num_timesteps
            # (truncating each part separately could drop an entry).
            num_lognorm = int(num_timesteps * alpha)
            num_linear = num_timesteps - num_lognorm

            lognormal = LogNormal(loc=0, scale=0.333)

            # Sample from the distribution
            t1 = lognormal.sample((num_lognorm,)).to(device)

            # Scale and reverse the values to go from 1000 to 0
            # (skipped for an empty sample, where max() is undefined).
            if num_lognorm > 0:
                t1 = ((1 - t1 / t1.max()) * 1000)

            # Fill the remaining (1 - alpha) share with linear timesteps
            t2 = torch.linspace(1000, 0, num_linear, device=device)

            timesteps = torch.cat((t1, t2))

            # Sort the timesteps in descending order
            timesteps, _ = torch.sort(timesteps, descending=True)

            timesteps = timesteps.to(torch.int)
            self.timesteps = timesteps.to(device=device)
            return timesteps

        else:
            raise ValueError(f"Invalid timestep type: {timestep_type}")