|
|
|
import logging
import random
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torchdiffeq
from torch.nn import functional as F
from torchvision import transforms

from diffusion.base_solver import Solver
from sde import multi_scale_targets

# NOTE: the distributed helpers referenced below (gather_features,
# neighbour_exchange_with_grad, neighbour_exchange_bidir_with_grad) are assumed
# to be provided elsewhere, e.g. OpenCLIP's loss utilities:
#   from open_clip.loss import (gather_features, neighbour_exchange_with_grad,
#                               neighbour_exchange_bidir_with_grad)

logger = logging.getLogger(__name__)
|
|
|
|
|
def check_zip(*args):
    """zip() that asserts all iterables have the same length."""
    args = [list(arg) for arg in args]
    length = len(args[0])
    for arg in args:
        assert len(arg) == length, f"length mismatch: {len(arg)} vs. {length}"
    return zip(*args)
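
# Illustrative sketch (not called by the training code): check_zip behaves like
# the built-in zip() but fails loudly on ragged inputs instead of truncating.
def _demo_check_zip():
    pairs = list(check_zip([1, 2, 3], ['a', 'b', 'c']))
    assert pairs == [(1, 'a'), (2, 'b'), (3, 'c')]
    try:
        check_zip([1, 2], [1, 2, 3])  # mismatched lengths
    except AssertionError:
        pass  # raised before any pairing happens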
|
|
|
|
|
def kl_divergence(source, target):
    """KL(p || q), where p = softmax(target) and q = softmax(source), both flattened."""
    q_raw = source.view(-1)
    p_raw = target.view(-1)

    p = F.softmax(p_raw, dim=0)
    q = F.softmax(q_raw, dim=0)

    # F.kl_div expects log-probabilities as input and probabilities as target.
    q_log = torch.log(q)
    return F.kl_div(q_log, p, reduction='sum')
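
# Illustrative sketch: the helper agrees with the textbook definition
# KL(p || q) = sum_i p_i * log(p_i / q_i) on softmax-normalized, flattened logits.
def _demo_kl_divergence():
    source = torch.randn(4, 8)  # arbitrary logits
    target = torch.randn(4, 8)
    p = F.softmax(target.view(-1), dim=0)
    q = F.softmax(source.view(-1), dim=0)
    manual = torch.sum(p * (p / q).log())
    assert torch.allclose(kl_divergence(source, target), manual, atol=1e-5)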
|
|
|
|
|
|
|
class TimeStepSampler:
    """
    Abstract base class for sampling timesteps in flow matching.
    """

    def sample_time(self, x_start):
        raise NotImplementedError
|
|
|
class ClipLoss(nn.Module): |
|
|
|
def __init__( |
|
self, |
|
local_loss=False, |
|
gather_with_grad=False, |
|
cache_labels=False, |
|
rank=0, |
|
world_size=1, |
|
use_horovod=False, |
|
): |
|
super().__init__() |
|
self.local_loss = local_loss |
|
self.gather_with_grad = gather_with_grad |
|
self.cache_labels = cache_labels |
|
self.rank = rank |
|
self.world_size = world_size |
|
self.use_horovod = use_horovod |
|
|
|
|
|
self.prev_num_logits = 0 |
|
self.labels = {} |
|
|
|
def get_ground_truth(self, device, num_logits) -> torch.Tensor: |
|
|
|
if self.prev_num_logits != num_logits or device not in self.labels: |
|
labels = torch.arange(num_logits, device=device, dtype=torch.long) |
|
if self.world_size > 1 and self.local_loss: |
|
labels = labels + num_logits * self.rank |
|
if self.cache_labels: |
|
self.labels[device] = labels |
|
self.prev_num_logits = num_logits |
|
else: |
|
labels = self.labels[device] |
|
return labels |
|
|
|
def get_logits(self, image_features, text_features, logit_scale): |
|
if self.world_size > 1: |
|
all_image_features, all_text_features = gather_features( |
|
image_features, text_features, |
|
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) |
|
|
|
if self.local_loss: |
|
logits_per_image = logit_scale * image_features @ all_text_features.T |
|
logits_per_text = logit_scale * text_features @ all_image_features.T |
|
else: |
|
logits_per_image = logit_scale * all_image_features @ all_text_features.T |
|
logits_per_text = logits_per_image.T |
|
else: |
|
logits_per_image = logit_scale * image_features @ text_features.T |
|
logits_per_text = logit_scale * text_features @ image_features.T |
|
|
|
return logits_per_image, logits_per_text |
|
|
|
def forward(self, image_features, text_features, logit_scale, output_dict=False): |
|
device = image_features.device |
|
logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) |
|
|
|
labels = self.get_ground_truth(device, logits_per_image.shape[0]) |
|
|
|
total_loss = ( |
|
F.cross_entropy(logits_per_image, labels) + |
|
F.cross_entropy(logits_per_text, labels) |
|
) / 2 |
|
|
|
return {"contrastive_loss": total_loss} if output_dict else total_loss |
|
|
|
|
|
class SigLipLoss(nn.Module): |
|
""" Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343 |
|
|
|
@article{zhai2023sigmoid, |
|
title={Sigmoid loss for language image pre-training}, |
|
author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas}, |
|
journal={arXiv preprint arXiv:2303.15343}, |
|
year={2023} |
|
} |
|
""" |
|
def __init__( |
|
self, |
|
cache_labels=False, |
|
rank=0, |
|
world_size=1, |
|
bidir=True, |
|
use_horovod=False, |
|
): |
|
super().__init__() |
|
self.cache_labels = cache_labels |
|
self.rank = rank |
|
self.world_size = world_size |
|
assert not use_horovod |
|
self.use_horovod = use_horovod |
|
self.bidir = bidir |
|
|
|
|
|
self.prev_num_logits = 0 |
|
self.labels = {} |
|
|
|
    def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
        # -1 everywhere (negative pairs); +1 on the diagonal for matching pairs.
        labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
        if not negative_only:
            labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
        return labels
|
|
|
def get_logits(self, image_features, text_features, logit_scale, logit_bias=None): |
|
logits = logit_scale * image_features @ text_features.T |
|
if logit_bias is not None: |
|
logits += logit_bias |
|
return logits |
|
|
|
def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False): |
|
logits = self.get_logits(image_features, text_features, logit_scale, logit_bias) |
|
labels = self.get_ground_truth( |
|
image_features.device, |
|
image_features.dtype, |
|
image_features.shape[0], |
|
negative_only=negative_only, |
|
) |
|
loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0] |
|
return loss |
|
|
|
def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False): |
|
loss = self._loss(image_features, text_features, logit_scale, logit_bias) |
|
|
|
        if self.world_size > 1:
            # Ring exchange: text features are passed around the process ring so each
            # rank also scores its images against other ranks' texts as extra negatives.
            right_rank = (self.rank + 1) % self.world_size
            left_rank = (self.rank - 1 + self.world_size) % self.world_size
|
if self.bidir: |
|
text_features_to_right = text_features_to_left = text_features |
|
num_bidir, remainder = divmod(self.world_size - 1, 2) |
|
for i in range(num_bidir): |
|
text_features_recv = neighbour_exchange_bidir_with_grad( |
|
left_rank, |
|
right_rank, |
|
text_features_to_left, |
|
text_features_to_right, |
|
) |
|
|
|
for f in text_features_recv: |
|
loss += self._loss( |
|
image_features, |
|
f, |
|
logit_scale, |
|
logit_bias, |
|
negative_only=True, |
|
) |
|
text_features_to_left, text_features_to_right = text_features_recv |
|
|
|
if remainder: |
|
text_features_recv = neighbour_exchange_with_grad( |
|
left_rank, right_rank, text_features_to_right) |
|
|
|
loss += self._loss( |
|
image_features, |
|
text_features_recv, |
|
logit_scale, |
|
logit_bias, |
|
negative_only=True, |
|
) |
|
else: |
|
text_features_to_right = text_features |
|
for i in range(self.world_size - 1): |
|
text_features_from_left = neighbour_exchange_with_grad( |
|
left_rank, right_rank, text_features_to_right) |
|
|
|
loss += self._loss( |
|
image_features, |
|
text_features_from_left, |
|
logit_scale, |
|
logit_bias, |
|
negative_only=True, |
|
) |
|
text_features_to_right = text_features_from_left |
|
|
|
return {"contrastive_loss": loss} if output_dict else loss |
|
|
|
|
|
class ResolutionScaledTimeStepSampler(TimeStepSampler):
    """
    Wraps a base sampler and applies the monotone rescaling
    t' = s * t / (1 + (s - 1) * t), which maps [0, 1] onto itself.
    """

    def __init__(self, scale: float, base_time_step_sampler: TimeStepSampler):
        self.scale = scale
        self.base_time_step_sampler = base_time_step_sampler

    @torch.no_grad()
    def sample_time(self, x_start):
        base_time = self.base_time_step_sampler.sample_time(x_start)
        # scale == 1 is the identity; scale > 1 shifts times toward 1.
        scaled_time = (base_time * self.scale) / (1 + (self.scale - 1) * base_time)
        return scaled_time
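
# Illustrative sketch: the rescaling is the identity for scale == 1 and pushes
# times toward 1 for scale > 1. _FixedSampler is a toy base sampler invented
# here to make the effect visible.
def _demo_resolution_scaled_sampler():
    class _FixedSampler(TimeStepSampler):
        def sample_time(self, x_start):
            return torch.full((x_start.shape[0],), 0.5)

    x = torch.zeros(4, 3, 32, 32)
    assert torch.allclose(
        ResolutionScaledTimeStepSampler(1.0, _FixedSampler()).sample_time(x),
        torch.full((4,), 0.5))   # identity
    assert torch.allclose(
        ResolutionScaledTimeStepSampler(3.0, _FixedSampler()).sample_time(x),
        torch.full((4,), 0.75))  # 3 * 0.5 / (1 + 2 * 0.5)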
|
|
|
|
|
class LogitNormalSampler(TimeStepSampler):
    def __init__(self, normal_mean: float = 0, normal_std: float = 1):
        # Samples t = sigmoid(z) with z ~ N(normal_mean, normal_std), i.e. a
        # logit-normal distribution over (0, 1).
        self.normal_mean = float(normal_mean)
        self.normal_std = float(normal_std)

    @torch.no_grad()
    def sample_time(self, x_start):
        x_normal = torch.normal(
            mean=self.normal_mean,
            std=self.normal_std,
            size=(x_start.shape[0],),
            device=x_start.device,
        )
        x_logistic = torch.sigmoid(x_normal)
        return x_logistic
|
|
|
|
|
class UniformTimeSampler(TimeStepSampler):
    @torch.no_grad()
    def sample_time(self, x_start):
        # One time per batch element, uniform on [0, 1).
        return torch.rand(x_start.shape[0], device=x_start.device)
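
# Illustrative sketch: every sampler returns one time per batch element in
# [0, 1]; the logit-normal concentrates mass around sigmoid(normal_mean).
def _demo_time_samplers():
    x = torch.zeros(16, 3, 32, 32)
    for sampler in (UniformTimeSampler(), LogitNormalSampler(0.0, 1.0)):
        t = sampler.sample_time(x)
        assert t.shape == (16,)
        assert torch.all((t >= 0) & (t <= 1))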
|
|
|
|
|
class FlowMatching(nn.Module): |
|
def __init__( |
|
self, |
|
sigma_min: float = 1e-5, |
|
sigma_max: float = 1.0, |
|
timescale: float = 1.0, |
|
**kwargs, |
|
): |
|
|
|
        super().__init__(**kwargs)
        self.time_step_sampler = LogitNormalSampler()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.timescale = timescale

        self.clip_loss = ClipLoss()

        # Downscales 512px images to 256px for the CLIP image branch.
        self.resizer = transforms.Resize(256)
|
|
|
    def sample_noise(self, x_start):
        # randn_like already matches x_start's device and dtype; scale by sigma_max.
        return torch.randn_like(x_start) * self.sigma_max
|
|
|
    def mos(self, err, start_dim=1, con_mask=None):
        """Mean of squares; with con_mask, a mask-weighted mean over the last dimension."""
        if con_mask is not None:
            return (err.pow(2).mean(dim=-1) * con_mask).sum(dim=-1) / con_mask.sum(dim=-1)
        else:
            return err.pow(2).flatten(start_dim=start_dim).mean(dim=-1)
|
|
|
|
|
    def Xentropy(self, pred, tar, con_mask=None):
        if con_mask is not None:
            return (F.cross_entropy(pred, tar, reduction='none') * con_mask).sum(dim=-1) / con_mask.sum(dim=-1)
        else:
            return F.cross_entropy(pred, tar, reduction='none').mean(dim=-1)

    def l2_reg(self, pred, lam=1e-4):
        return lam * torch.norm(pred, p=2, dim=(1, 2, 3)) ** 2
|
|
|
|
|
def forward( |
|
self, |
|
x, |
|
nnet, |
|
loss_coeffs, |
|
cond, |
|
con_mask, |
|
nnet_style, |
|
training_step, |
|
cond_ori=None, |
|
con_mask_ori=None, |
|
batch_img_clip=None, |
|
model_config=None, |
|
all_config=None, |
|
text_token=None, |
|
return_raw_loss=False, |
|
additional_embeddings=None, |
|
timesteps: Optional[Tuple[int, int]] = None, |
|
*args, |
|
**kwargs, |
|
): |
|
        # Timesteps are always sampled internally; callers may not supply them.
        assert timesteps is None, "timesteps must be None"

        timesteps = self.time_step_sampler.sample_time(x)
|
|
|
        if nnet_style == 'dimr':
            standard_diffusion = bool(getattr(model_config, "standard_diffusion", False))
            return self.p_losses_textVAE(
                x, cond, con_mask, timesteps, nnet, *args,
                batch_img_clip=batch_img_clip, cond_ori=cond_ori, con_mask_ori=con_mask_ori,
                text_token=text_token, loss_coeffs=loss_coeffs,
                return_raw_loss=return_raw_loss, nnet_style=nnet_style,
                standard_diffusion=standard_diffusion, all_config=all_config,
                training_step=training_step, **kwargs,
            )
        elif nnet_style == 'dit':
            if bool(getattr(model_config, "standard_diffusion", False)):
                raise NotImplementedError("standard diffusion for DiT needs an update")
            return self.p_losses_textVAE_dit(
                x, cond, con_mask, timesteps, nnet, *args,
                batch_img_clip=batch_img_clip, cond_ori=cond_ori, con_mask_ori=con_mask_ori,
                text_token=text_token, loss_coeffs=loss_coeffs,
                return_raw_loss=return_raw_loss, nnet_style=nnet_style,
                standard_diffusion=False, all_config=all_config,
                training_step=training_step, **kwargs,
            )
        else:
            raise NotImplementedError
|
|
|
|
|
|
|
def p_losses_textVAE( |
|
self, |
|
x_start, |
|
cond, |
|
con_mask, |
|
t, |
|
nnet, |
|
loss_coeffs, |
|
training_step, |
|
text_token=None, |
|
nnet_style=None, |
|
all_config=None, |
|
batch_img_clip=None, |
|
cond_ori=None, |
|
con_mask_ori=None, |
|
return_raw_loss=False, |
|
additional_embeddings=None, |
|
standard_diffusion=False, |
|
noise=None, |
|
): |
|
""" |
|
CrossFlow training for DiMR |
|
""" |
|
|
|
assert noise is None |
|
|
|
        x0, mu, log_var = nnet(cond, text_encoder=True, shape=x_start.shape, mask=con_mask)

        # Contrastive (CLIP) reconstruction loss between the text-VAE output and
        # the image embedding; 512px images are downscaled to 256px first.
        if batch_img_clip.shape[-1] == 512:
            recon_gt = self.resizer(batch_img_clip)
        else:
            recon_gt = batch_img_clip
        recon_gt_clip, logit_scale = nnet(recon_gt, image_clip=True)
        image_features = recon_gt_clip / recon_gt_clip.norm(dim=-1, keepdim=True)
        text_features = x0 / x0.norm(dim=-1, keepdim=True)
        recons_loss = self.clip_loss(image_features, text_features, logit_scale)
|
|
|
|
|
        # KL-style regularizer on the text-VAE posterior; note the implementation
        # uses a (0.3 * mu) ** 6 penalty in place of the usual mu ** 2 term.
        kld_loss = -0.5 * torch.sum(1 + log_var - (0.3 * mu) ** 6 - log_var.exp(), dim=1)
        kld_loss_weight = 1e-2

        loss_mlp = recons_loss + kld_loss * kld_loss_weight
|
|
|
|
|
|
|
        # The text-VAE output serves as the flow's source sample, in place of Gaussian noise.
        noise = x0.reshape(x_start.shape)
|
|
|
        if hasattr(all_config.nnet.model_args, "cfg_indicator"):
            # Randomly mark a fraction of the batch as "null" for classifier-free guidance.
            null_indicator = torch.from_numpy(np.array(
                [random.random() < all_config.nnet.model_args.cfg_indicator
                 for _ in range(x_start.shape[0])]
            )).to(x_start.device)
            if null_indicator.sum() <= 1:
                # A lone null sample has no partner to swap with; disable CFG for this batch.
                null_indicator = torch.zeros_like(null_indicator)
            else:
                # Cyclically shuffle the null targets so each is paired with a mismatched text.
                target_null = x_start[null_indicator]
                target_null = torch.cat((target_null[1:], target_null[:1]))
                x_start[null_indicator] = target_null
        else:
            null_indicator = None
|
|
|
|
|
        x_noisy = self.psi(t, x=noise, x1=x_start)
        target_velocity = self.Dt_psi(t, x=noise, x1=x_start)
        # Map t in [0, 1] linearly to a log-SNR in [4, -4] for DiMR's time conditioning.
        log_snr = 4 - t * 8

        prediction = nnet(x_noisy, log_snr=log_snr, null_indicator=null_indicator)
|
|
|
        target = multi_scale_targets(target_velocity, levels=len(prediction), scale_correction=True)

        # Coefficient-weighted MSE summed over DiMR's resolution levels, keyed by spatial size.
        loss_diff = 0
        for pred, coeff in check_zip(prediction, loss_coeffs):
            loss_diff = loss_diff + coeff * self.mos(pred - target[pred.shape[-1]])
|
|
|
|
|
|
|
loss = loss_diff + loss_mlp |
|
|
|
return loss, {'loss_diff': loss_diff, 'clip_loss': recons_loss, 'kld_loss': kld_loss, 'kld_loss_weight': torch.tensor(kld_loss_weight, device=kld_loss.device), 'clip_logit_scale': logit_scale} |
|
|
|
|
|
def p_losses_textVAE_dit( |
|
self, |
|
x_start, |
|
cond, |
|
con_mask, |
|
t, |
|
nnet, |
|
loss_coeffs, |
|
training_step, |
|
text_token=None, |
|
nnet_style=None, |
|
all_config=None, |
|
batch_img_clip=None, |
|
cond_ori=None, |
|
con_mask_ori=None, |
|
return_raw_loss=False, |
|
additional_embeddings=None, |
|
standard_diffusion=False, |
|
noise=None, |
|
): |
|
""" |
|
        CrossFlow training for DiT
|
""" |
|
|
|
assert noise is None |
|
|
|
        x0, mu, log_var = nnet(cond, text_encoder=True, shape=x_start.shape, mask=con_mask)

        # Contrastive (CLIP) reconstruction loss between the text-VAE output and
        # the image embedding; 512px images are downscaled to 256px first.
        if batch_img_clip.shape[-1] == 512:
            recon_gt = self.resizer(batch_img_clip)
        else:
            recon_gt = batch_img_clip
        recon_gt_clip, logit_scale = nnet(recon_gt, image_clip=True)
        image_features = recon_gt_clip / recon_gt_clip.norm(dim=-1, keepdim=True)
        text_features = x0 / x0.norm(dim=-1, keepdim=True)
        recons_loss = self.clip_loss(image_features, text_features, logit_scale)
|
|
|
|
|
        # KL-style regularizer on the text-VAE posterior; note the implementation
        # uses a (0.3 * mu) ** 6 penalty in place of the usual mu ** 2 term.
        kld_loss = -0.5 * torch.sum(1 + log_var - (0.3 * mu) ** 6 - log_var.exp(), dim=1)
        kld_loss_weight = 1e-2

        loss_mlp = recons_loss + kld_loss * kld_loss_weight
|
|
|
|
|
        # The text-VAE output serves as the flow's source sample, in place of Gaussian noise.
        noise = x0.reshape(x_start.shape)
|
|
|
        if hasattr(all_config.nnet.model_args, "cfg_indicator"):
            # Randomly mark a fraction of the batch as "null" for classifier-free guidance.
            null_indicator = torch.from_numpy(np.array(
                [random.random() < all_config.nnet.model_args.cfg_indicator
                 for _ in range(x_start.shape[0])]
            )).to(x_start.device)
            if null_indicator.sum() <= 1:
                # A lone null sample has no partner to swap with; disable CFG for this batch.
                null_indicator = torch.zeros_like(null_indicator)
            else:
                # Cyclically shuffle the null targets so each is paired with a mismatched text.
                target_null = x_start[null_indicator]
                target_null = torch.cat((target_null[1:], target_null[:1]))
                x_start[null_indicator] = target_null
        else:
            null_indicator = None
|
|
|
        x_noisy = self.psi(t, x=noise, x1=x_start)
        target_velocity = self.Dt_psi(t, x=noise, x1=x_start)

        # DiT is conditioned directly on t; its first output is the predicted velocity.
        prediction = nnet(x_noisy, t=t, null_indicator=null_indicator)[0]

        loss_diff = self.mos(prediction - target_velocity)
|
|
|
|
|
|
|
loss = loss_diff + loss_mlp |
|
|
|
return loss, {'loss_diff': loss_diff, 'clip_loss': recons_loss, 'kld_loss': kld_loss, 'kld_loss_weight': torch.tensor(kld_loss_weight, device=kld_loss.device), 'clip_logit_scale': logit_scale} |
|
|
|
|
|
|
|
def psi(self, t, x, x1): |
|
assert ( |
|
t.shape[0] == x.shape[0] |
|
), f"Batch size of t and x does not agree {t.shape[0]} vs. {x.shape[0]}" |
|
assert ( |
|
t.shape[0] == x1.shape[0] |
|
), f"Batch size of t and x1 does not agree {t.shape[0]} vs. {x1.shape[0]}" |
|
assert t.ndim == 1 |
|
        t = self.expand_t(t, x)
        # Linear path: psi(t) = (1 + (sigma_min / sigma_max - 1) * t) * x + t * x1.
        return (t * (self.sigma_min / self.sigma_max - 1) + 1) * x + t * x1

    def Dt_psi(self, t: torch.Tensor, x: torch.Tensor, x1: torch.Tensor):
        # Exact time derivative of psi above (constant in t).
        assert x.shape[0] == x1.shape[0]
        return (self.sigma_min / self.sigma_max - 1) * x + x1
|
|
|
def expand_t(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor: |
|
t_expanded = t |
|
while t_expanded.ndim < x.ndim: |
|
t_expanded = t_expanded.unsqueeze(-1) |
|
return t_expanded.expand_as(x) |
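
# Illustrative sketch: Dt_psi is the exact time derivative of psi, which a
# central finite difference confirms (psi is linear in t, so this is exact up
# to floating-point error).
def _demo_psi_derivative():
    fm = FlowMatching()
    x = torch.randn(2, 3, 8, 8)   # source ("noise") sample
    x1 = torch.randn(2, 3, 8, 8)  # data sample
    t = torch.rand(2) * 0.8 + 0.1
    eps = 1e-3
    numeric = (fm.psi(t + eps, x, x1) - fm.psi(t - eps, x, x1)) / (2 * eps)
    assert torch.allclose(fm.Dt_psi(t, x, x1), numeric, atol=1e-4)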
|
|
|
|
|
|
|
|
|
class ODEEulerFlowMatchingSolver(Solver): |
|
""" |
|
ODE Solver for Flow matching that uses an Euler discretization |
|
Supports number of time steps at inference |
|
""" |
|
|
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.step_size_type = kwargs.get("step_size_type", "step_in_dsigma") |
|
assert self.step_size_type in ["step_in_dsigma", "step_in_dt"] |
|
self.sample_timescale = 1.0 - 1e-5 |
|
|
|
@torch.no_grad() |
|
def sample_euler( |
|
self, |
|
x_T, |
|
unconditional_guidance_scale, |
|
has_null_indicator, |
|
t=[0, 1.0], |
|
**kwargs, |
|
): |
|
""" |
|
Euler solver for flow matching. |
|
Based on https://github.com/VinAIResearch/LFM/blob/main/sampler/karras_sample.py |
|
""" |
|
t = torch.tensor(t) |
|
t = t * self.sample_timescale |
|
sigma_min = 1e-5 |
|
sigma_max = 1.0 |
|
sigma_steps = torch.linspace( |
|
sigma_min, sigma_max, self.num_time_steps + 1, device=x_T.device |
|
) |
|
discrete_time_steps_for_step = torch.linspace( |
|
t[0], t[1], self.num_time_steps + 1, device=x_T.device |
|
) |
|
discrete_time_steps_to_eval_model_at = torch.linspace( |
|
t[0], t[1], self.num_time_steps, device=x_T.device |
|
) |
|
|
|
print("num_time_steps : " + str(self.num_time_steps)) |
|
|
|
for i in range(self.num_time_steps): |
|
t_i = discrete_time_steps_to_eval_model_at[i] |
|
velocity = self.get_model_output_dimr( |
|
x_T, |
|
has_null_indicator = has_null_indicator, |
|
t_continuous = t_i.repeat(x_T.shape[0]), |
|
unconditional_guidance_scale = unconditional_guidance_scale, |
|
) |
|
if self.step_size_type == "step_in_dsigma": |
|
step_size = sigma_steps[i + 1] - sigma_steps[i] |
|
elif self.step_size_type == "step_in_dt": |
|
step_size = ( |
|
discrete_time_steps_for_step[i + 1] |
|
- discrete_time_steps_for_step[i] |
|
) |
|
x_T = x_T + velocity * step_size |
|
|
|
intermediates = None |
|
return x_T, intermediates |
|
|
|
@torch.no_grad() |
|
def sample( |
|
self, |
|
*args, |
|
**kwargs, |
|
): |
|
assert kwargs.get("ucg_schedule", None) is None |
|
assert kwargs.get("skip_type", None) is None |
|
assert kwargs.get("dynamic_threshold", None) is None |
|
assert kwargs.get("x0", None) is None |
|
assert kwargs.get("x_T") is not None |
|
assert kwargs.get("score_corrector", None) is None |
|
assert kwargs.get("normals_sequence", None) is None |
|
assert kwargs.get("callback", None) is None |
|
assert kwargs.get("quantize_x0", False) is False |
|
assert kwargs.get("eta", 0.0) == 0.0 |
|
assert kwargs.get("mask", None) is None |
|
assert kwargs.get("noise_dropout", 0.0) == 0.0 |
|
|
|
self.num_time_steps = kwargs.get("sample_steps") |
|
self.x_T_uncon = kwargs.get("x_T_uncon") |
|
|
|
samples, intermediates = super().sample( |
|
*args, |
|
sampling_method=self.sample_euler, |
|
do_make_schedule=False, |
|
**kwargs, |
|
) |
|
return samples, intermediates |
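
# Illustrative sketch of the Euler update used by sample_euler above, on the toy
# velocity field v(x, t) = -x whose exact flow is x * exp(-t).
def _demo_euler_update(num_time_steps: int = 100):
    x = torch.ones(4)
    ts = torch.linspace(0.0, 1.0, num_time_steps + 1)
    for i in range(num_time_steps):
        velocity = -x  # stand-in for the network's predicted velocity
        x = x + velocity * (ts[i + 1] - ts[i])
    assert torch.allclose(x, torch.exp(torch.tensor(-1.0)).expand(4), atol=1e-2)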
|
|
|
|
|
class ODEFlowMatchingSolver(Solver): |
|
""" |
|
ODE Solver for Flow matching that uses `dopri5` |
|
Does not support number of time steps based control |
|
""" |
|
|
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.sample_timescale = 1.0 - 1e-5 |
|
|
|
|
|
@torch.no_grad() |
|
def sample_transport( |
|
self, |
|
x_T, |
|
unconditional_guidance_scale, |
|
has_null_indicator, |
|
t=[0, 1.0], |
|
ode_opts={}, |
|
**kwargs, |
|
): |
|
num_evals = 0 |
|
t = torch.tensor(t, device=x_T.device) |
|
if "options" not in ode_opts: |
|
ode_opts["options"] = {} |
|
ode_opts["options"]["step_t"] = [self.sample_timescale + 1e-6] |
|
|
|
def ode_func(t, x_T): |
|
nonlocal num_evals |
|
num_evals += 1 |
|
model_output = self.get_model_output_dimr( |
|
x_T, |
|
has_null_indicator = has_null_indicator, |
|
t_continuous = t.repeat(x_T.shape[0]), |
|
unconditional_guidance_scale = unconditional_guidance_scale, |
|
) |
|
return model_output |
|
|
|
z = torchdiffeq.odeint( |
|
ode_func, |
|
x_T, |
|
t * self.sample_timescale, |
|
**{"atol": 1e-5, "rtol": 1e-5, "method": "dopri5", **ode_opts}, |
|
) |
|
|
|
|
|
z = z[-1] |
|
intermediates = None |
|
return z, intermediates |
|
|
|
@torch.no_grad() |
|
def sample( |
|
self, |
|
*args, |
|
**kwargs, |
|
): |
|
assert kwargs.get("ucg_schedule", None) is None |
|
assert kwargs.get("skip_type", None) is None |
|
assert kwargs.get("dynamic_threshold", None) is None |
|
assert kwargs.get("x0", None) is None |
|
assert kwargs.get("x_T") is not None |
|
assert kwargs.get("score_corrector", None) is None |
|
assert kwargs.get("normals_sequence", None) is None |
|
assert kwargs.get("callback", None) is None |
|
assert kwargs.get("quantize_x0", False) is False |
|
assert kwargs.get("eta", 0.0) == 0.0 |
|
assert kwargs.get("mask", None) is None |
|
assert kwargs.get("noise_dropout", 0.0) == 0.0 |
|
samples, intermediates = super().sample( |
|
*args, |
|
sampling_method=self.sample_transport, |
|
do_make_schedule=False, |
|
**kwargs, |
|
) |
|
return samples, intermediates |
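
# Illustrative sketch of the torchdiffeq call pattern used in sample_transport,
# on the same toy ODE dx/dt = -x.
def _demo_dopri5():
    def ode_func(t, x):
        return -x

    z = torchdiffeq.odeint(ode_func, torch.ones(4), torch.tensor([0.0, 1.0]),
                           atol=1e-5, rtol=1e-5, method="dopri5")
    assert torch.allclose(z[-1], torch.exp(torch.tensor(-1.0)).expand(4), atol=1e-4)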