import logging
import random
from typing import Callable, Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torchdiffeq
from torch.nn import functional as F
from torchvision import transforms

from diffusion.base_solver import Solver
from sde import multi_scale_targets


def check_zip(*args):
    """zip() that asserts all iterables have the same length."""
    args = [list(arg) for arg in args]
    length = len(args[0])
    for arg in args:
        assert len(arg) == length
    return zip(*args)


def kl_divergence(source, target):
    """KL(softmax(target) || softmax(source)), both flattened to 1D."""
    q_raw = source.view(-1)
    p_raw = target.view(-1)
    p = F.softmax(p_raw, dim=0)
    q = F.softmax(q_raw, dim=0)
    q_log = torch.log(q)
    kl_div_1 = F.kl_div(q_log, p, reduction='sum')
    return kl_div_1


class TimeStepSampler:
    """
    Abstract class to sample timesteps for flow matching.
    """

    def sample_time(self, x_start):
        # In flow matching, time is in range [0, 1]; 1 indicates the original
        # image and 0 is pure noise. This convention is the *REVERSE* of diffusion.
        raise NotImplementedError


class ClipLoss(nn.Module):
    def __init__(
        self,
        local_loss=False,
        gather_with_grad=False,
        cache_labels=False,
        rank=0,
        world_size=1,
        use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # calculate ground truth and cache if enabled
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        if self.world_size > 1:
            # NOTE: `gather_features` is not defined in this module; the
            # distributed path expects the open_clip implementation.
            all_image_features, all_text_features = gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad,
                self.rank, self.world_size, self.use_horovod)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T
        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss
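
# Illustrative sanity check (not part of the original training code): with
# perfectly aligned image/text features, the in-batch contrastive loss of
# ClipLoss should be near its minimum. `_demo_clip_loss` is a hypothetical
# helper, safe to remove.
def _demo_clip_loss():
    feats = F.normalize(torch.randn(4, 16), dim=-1)
    loss = ClipLoss()(feats, feats, logit_scale=100.0)
    print(f"ClipLoss on aligned features: {loss.item():.4f}")
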
class SigLipLoss(nn.Module):
    """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343

    @article{zhai2023sigmoid,
      title={Sigmoid loss for language image pre-training},
      author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
      journal={arXiv preprint arXiv:2303.15343},
      year={2023}
    }
    """

    def __init__(
        self,
        cache_labels=False,
        rank=0,
        world_size=1,
        bidir=True,
        use_horovod=False,
    ):
        super().__init__()
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        assert not use_horovod  # FIXME need to look at hvd ops for ring transfers
        self.use_horovod = use_horovod
        self.bidir = bidir

        # cache state FIXME cache not currently used, worthwhile?
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
        labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
        if not negative_only:
            labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
        return labels

    def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
        logits = logit_scale * image_features @ text_features.T
        if logit_bias is not None:
            logits += logit_bias
        return logits

    def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
        logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
        labels = self.get_ground_truth(
            image_features.device,
            image_features.dtype,
            image_features.shape[0],
            negative_only=negative_only,
        )
        loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]
        return loss

    def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
        loss = self._loss(image_features, text_features, logit_scale, logit_bias)

        if self.world_size > 1:
            # exchange text features w/ neighbour world_size - 1 times
            # NOTE: `neighbour_exchange_with_grad` and
            # `neighbour_exchange_bidir_with_grad` are not defined in this
            # module; the distributed path expects the open_clip implementation.
            right_rank = (self.rank + 1) % self.world_size
            left_rank = (self.rank - 1 + self.world_size) % self.world_size
            if self.bidir:
                text_features_to_right = text_features_to_left = text_features
                num_bidir, remainder = divmod(self.world_size - 1, 2)
                for i in range(num_bidir):
                    text_features_recv = neighbour_exchange_bidir_with_grad(
                        left_rank,
                        right_rank,
                        text_features_to_left,
                        text_features_to_right,
                    )

                    for f in text_features_recv:
                        loss += self._loss(
                            image_features,
                            f,
                            logit_scale,
                            logit_bias,
                            negative_only=True,
                        )
                    text_features_to_left, text_features_to_right = text_features_recv

                if remainder:
                    text_features_recv = neighbour_exchange_with_grad(
                        left_rank, right_rank, text_features_to_right)

                    loss += self._loss(
                        image_features,
                        text_features_recv,
                        logit_scale,
                        logit_bias,
                        negative_only=True,
                    )
            else:
                text_features_to_right = text_features
                for i in range(self.world_size - 1):
                    text_features_from_left = neighbour_exchange_with_grad(
                        left_rank, right_rank, text_features_to_right)

                    loss += self._loss(
                        image_features,
                        text_features_from_left,
                        logit_scale,
                        logit_bias,
                        negative_only=True,
                    )
                    text_features_to_right = text_features_from_left

        return {"contrastive_loss": loss} if output_dict else loss


class ResolutionScaledTimeStepSampler(TimeStepSampler):
    def __init__(self, scale: float, base_time_step_sampler: TimeStepSampler):
        self.scale = scale
        self.base_time_step_sampler = base_time_step_sampler

    @torch.no_grad()
    def sample_time(self, x_start):
        base_time = self.base_time_step_sampler.sample_time(x_start)
        # based on eq. (23) of https://arxiv.org/abs/2403.03206
        scaled_time = (base_time * self.scale) / (1 + (self.scale - 1) * base_time)
        return scaled_time


class LogitNormalSampler(TimeStepSampler):
    def __init__(self, normal_mean: float = 0, normal_std: float = 1):
        # follows https://arxiv.org/pdf/2403.03206.pdf:
        # sample from a normal distribution, then pass the output through
        # the standard logistic function, i.e., sigmoid
        self.normal_mean = float(normal_mean)
        self.normal_std = float(normal_std)

    @torch.no_grad()
    def sample_time(self, x_start):
        x_normal = torch.normal(
            mean=self.normal_mean,
            std=self.normal_std,
            size=(x_start.shape[0],),
            device=x_start.device,
        )
        x_logistic = torch.sigmoid(x_normal)
        return x_logistic
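
# Illustrative comparison (not part of the original module) of the time
# samplers defined above: for scale > 1, the resolution-scaled sampler warps
# times toward 1, per eq. (23) of https://arxiv.org/abs/2403.03206.
# `_demo_time_samplers` is a hypothetical helper, safe to remove.
def _demo_time_samplers():
    x = torch.zeros(1024, 3, 32, 32)
    base = LogitNormalSampler()
    scaled = ResolutionScaledTimeStepSampler(scale=2.0, base_time_step_sampler=base)
    print("logit-normal mean t:", base.sample_time(x).mean().item())
    print("resolution-scaled mean t:", scaled.sample_time(x).mean().item())
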
class UniformTimeSampler(TimeStepSampler):
    @torch.no_grad()
    def sample_time(self, x_start):
        # [0, 1]; 1 indicates the original image and 0 is pure noise
        return torch.rand(x_start.shape[0], device=x_start.device)


class FlowMatching(nn.Module):
    def __init__(
        self,
        sigma_min: float = 1e-5,
        sigma_max: float = 1.0,
        timescale: float = 1.0,
        **kwargs,
    ):
        # LatentDiffusion/DDPM would create too many class variables we do not need
        super().__init__(**kwargs)
        self.time_step_sampler = LogitNormalSampler()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.timescale = timescale

        self.clip_loss = ClipLoss()
        # self.SigLipLoss = SigLipLoss()
        self.resizer = transforms.Resize(256)  # for CLIP

    def sample_noise(self, x_start):
        # simple IID noise
        return torch.randn_like(x_start, device=x_start.device) * self.sigma_max

    def mos(self, err, start_dim=1, con_mask=None):  # mean of square
        if con_mask is not None:
            return (err.pow(2).mean(dim=-1) * con_mask).sum(dim=-1) / con_mask.sum(dim=-1)
        else:
            return err.pow(2).flatten(start_dim=start_dim).mean(dim=-1)

    def Xentropy(self, pred, tar, con_mask=None):
        if con_mask is not None:
            return (nn.functional.cross_entropy(pred, tar, reduction='none') * con_mask).sum(dim=-1) / con_mask.sum(dim=-1)
        else:
            return nn.functional.cross_entropy(pred, tar, reduction='none').mean(dim=-1)

    def l2_reg(self, pred, lam=0.0001):
        return lam * torch.norm(pred, p=2, dim=(1, 2, 3)) ** 2

    # model forward and prediction
    def forward(
        self,
        x,
        nnet,
        loss_coeffs,
        cond,
        con_mask,
        nnet_style,
        training_step,
        cond_ori=None,      # not used
        con_mask_ori=None,  # not used
        batch_img_clip=None,
        model_config=None,
        all_config=None,
        text_token=None,
        return_raw_loss=False,
        additional_embeddings=None,
        timesteps: Optional[Tuple[int, int]] = None,
        *args,
        **kwargs,
    ):
        assert timesteps is None, "timesteps must be None"
        timesteps = self.time_step_sampler.sample_time(x)

        if nnet_style == 'dimr':
            standard_diffusion = bool(getattr(model_config, "standard_diffusion", False))
            return self.p_losses_textVAE(
                x, cond, con_mask, timesteps, nnet,
                *args,
                batch_img_clip=batch_img_clip,
                cond_ori=cond_ori,
                con_mask_ori=con_mask_ori,
                text_token=text_token,
                loss_coeffs=loss_coeffs,
                return_raw_loss=return_raw_loss,
                nnet_style=nnet_style,
                standard_diffusion=standard_diffusion,
                all_config=all_config,
                training_step=training_step,
                **kwargs,
            )
        elif nnet_style == 'dit':
            standard_diffusion = bool(getattr(model_config, "standard_diffusion", False))
            if standard_diffusion:
                raise NotImplementedError("need update")
            return self.p_losses_textVAE_dit(
                x, cond, con_mask, timesteps, nnet,
                *args,
                batch_img_clip=batch_img_clip,
                cond_ori=cond_ori,
                con_mask_ori=con_mask_ori,
                text_token=text_token,
                loss_coeffs=loss_coeffs,
                return_raw_loss=return_raw_loss,
                nnet_style=nnet_style,
                standard_diffusion=standard_diffusion,
                all_config=all_config,
                training_step=training_step,
                **kwargs,
            )
        else:
            raise NotImplementedError

    def p_losses_textVAE(
        self,
        x_start,
        cond,
        con_mask,
        t,
        nnet,
        loss_coeffs,
        training_step,
        text_token=None,
        nnet_style=None,
        all_config=None,
        batch_img_clip=None,
        cond_ori=None,      # not used
        con_mask_ori=None,  # not used
        return_raw_loss=False,
        additional_embeddings=None,
        standard_diffusion=False,
        noise=None,
    ):
        """
        CrossFlow training for DiMR
        """
        assert noise is None

        x0, mu, log_var = nnet(cond, text_encoder=True, shape=x_start.shape, mask=con_mask)

        # ---- loss for the text variational encoder (VE)
        if batch_img_clip.shape[-1] == 512:
            recon_gt = self.resizer(batch_img_clip)
        else:
            recon_gt = batch_img_clip

        recon_gt_clip, logit_scale = nnet(recon_gt, image_clip=True)

        image_features = recon_gt_clip / recon_gt_clip.norm(dim=-1, keepdim=True)
        text_features = x0 / x0.norm(dim=-1, keepdim=True)
        recons_loss = self.clip_loss(image_features, text_features, logit_scale)

        # kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1)
        # slightly different KL loss: pushes mu -> 0 via (0.3 * mu) ** 6 and var -> 1
        kld_loss = -0.5 * torch.sum(1 + log_var - (0.3 * mu) ** 6 - log_var.exp(), dim=1)
        kld_loss_weight = 1e-2  # 0.0005

        loss_mlp = recons_loss + kld_loss * kld_loss_weight

        # ---- loss for flow matching (FM)
        noise = x0.reshape(x_start.shape)

        if hasattr(all_config.nnet.model_args, "cfg_indicator"):
            null_indicator = torch.from_numpy(np.array(
                [random.random() < all_config.nnet.model_args.cfg_indicator
                 for _ in range(x_start.shape[0])])).to(x_start.device)
            if null_indicator.sum() <= 1:
                # too few null samples to roll; drop them entirely
                null_indicator[null_indicator == True] = False
                assert null_indicator.sum() == 0
            else:
                # roll the null targets by one so each null sample is paired
                # with another sample's image
                target_null = x_start[null_indicator]
                target_null = torch.cat((target_null[1:], target_null[:1]))
                x_start[null_indicator] = target_null
        else:
            null_indicator = None

        x_noisy = self.psi(t, x=noise, x1=x_start)
        target_velocity = self.Dt_psi(t, x=noise, x1=x_start)

        log_snr = 4 - t * 8  # computed from the timestep (inverted convention)
        prediction = nnet(x_noisy, log_snr=log_snr, null_indicator=null_indicator)

        target = multi_scale_targets(target_velocity, levels=len(prediction), scale_correction=True)
        loss_diff = 0
        for pred, coeff in check_zip(prediction, loss_coeffs):
            loss_diff = loss_diff + coeff * self.mos(pred - target[pred.shape[-1]])

        loss = loss_diff + loss_mlp

        return loss, {'loss_diff': loss_diff,
                      'clip_loss': recons_loss,
                      'kld_loss': kld_loss,
                      'kld_loss_weight': torch.tensor(kld_loss_weight, device=kld_loss.device),
                      'clip_logit_scale': logit_scale}

    def p_losses_textVAE_dit(
        self,
        x_start,
        cond,
        con_mask,
        t,
        nnet,
        loss_coeffs,
        training_step,
        text_token=None,
        nnet_style=None,
        all_config=None,
        batch_img_clip=None,
        cond_ori=None,      # not used
        con_mask_ori=None,  # not used
        return_raw_loss=False,
        additional_embeddings=None,
        standard_diffusion=False,
        noise=None,
    ):
        """
        CrossFlow training for DiT
        """
        assert noise is None

        x0, mu, log_var = nnet(cond, text_encoder=True, shape=x_start.shape, mask=con_mask)

        # ---- loss for the text variational encoder (VE)
        if batch_img_clip.shape[-1] == 512:
            recon_gt = self.resizer(batch_img_clip)
        else:
            recon_gt = batch_img_clip

        recon_gt_clip, logit_scale = nnet(recon_gt, image_clip=True)

        image_features = recon_gt_clip / recon_gt_clip.norm(dim=-1, keepdim=True)
        text_features = x0 / x0.norm(dim=-1, keepdim=True)
        recons_loss = self.clip_loss(image_features, text_features, logit_scale)

        # kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1)
        # slightly different KL loss: pushes mu -> 0 via (0.3 * mu) ** 6 and var -> 1
        kld_loss = -0.5 * torch.sum(1 + log_var - (0.3 * mu) ** 6 - log_var.exp(), dim=1)
        kld_loss_weight = 1e-2  # 0.0005

        loss_mlp = recons_loss + kld_loss * kld_loss_weight

        # ---- loss for flow matching (FM)
        noise = x0.reshape(x_start.shape)

        if hasattr(all_config.nnet.model_args, "cfg_indicator"):
            null_indicator = torch.from_numpy(np.array(
                [random.random() < all_config.nnet.model_args.cfg_indicator
                 for _ in range(x_start.shape[0])])).to(x_start.device)
            if null_indicator.sum() <= 1:
                # too few null samples to roll; drop them entirely
                null_indicator[null_indicator == True] = False
                assert null_indicator.sum() == 0
            else:
                # roll the null targets by one so each null sample is paired
                # with another sample's image
                target_null = x_start[null_indicator]
                target_null = torch.cat((target_null[1:], target_null[:1]))
                x_start[null_indicator] = target_null
        else:
            null_indicator = None

        x_noisy = self.psi(t, x=noise, x1=x_start)
        target_velocity = self.Dt_psi(t, x=noise, x1=x_start)

        prediction = nnet(x_noisy, t=t, null_indicator=null_indicator)[0]
        loss_diff = self.mos(prediction - target_velocity)

        loss = loss_diff + loss_mlp

        return loss, {'loss_diff': loss_diff,
                      'clip_loss': recons_loss,
                      'kld_loss': kld_loss,
                      'kld_loss_weight': torch.tensor(kld_loss_weight, device=kld_loss.device),
                      'clip_logit_scale': logit_scale}
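
    # The two helpers below define the probability path used for training.
    # From the code: psi_t(x | x1) = (1 + (sigma_min / sigma_max - 1) * t) * x + t * x1,
    # hence d/dt psi_t(x | x1) = (sigma_min / sigma_max - 1) * x + x1,
    # which is exactly the velocity the network is regressed onto above.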
    ## flow matching specific functions
    def psi(self, t, x, x1):
        assert (
            t.shape[0] == x.shape[0]
        ), f"Batch size of t and x does not agree {t.shape[0]} vs. {x.shape[0]}"
        assert (
            t.shape[0] == x1.shape[0]
        ), f"Batch size of t and x1 does not agree {t.shape[0]} vs. {x1.shape[0]}"
        assert t.ndim == 1
        t = self.expand_t(t, x)
        return (t * (self.sigma_min / self.sigma_max - 1) + 1) * x + t * x1

    def Dt_psi(self, t: torch.Tensor, x: torch.Tensor, x1: torch.Tensor):
        assert x.shape[0] == x1.shape[0]
        return (self.sigma_min / self.sigma_max - 1) * x + x1

    def expand_t(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        t_expanded = t
        while t_expanded.ndim < x.ndim:
            t_expanded = t_expanded.unsqueeze(-1)
        return t_expanded.expand_as(x)
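
# Quick finite-difference check (illustrative only, not part of the original
# module) that Dt_psi matches the time derivative of psi; since psi is linear
# in t, the two should agree up to float error. `_demo_check_path` is a
# hypothetical helper, safe to remove.
def _demo_check_path():
    fm = FlowMatching()
    x = torch.randn(4, 3, 8, 8)   # "noise" endpoint of the path
    x1 = torch.randn(4, 3, 8, 8)  # data endpoint of the path
    t = torch.full((4,), 0.5)
    eps = 1e-3
    fd = (fm.psi(t + eps, x, x1) - fm.psi(t - eps, x, x1)) / (2 * eps)
    print("max |finite diff - Dt_psi|:",
          (fd - fm.Dt_psi(t, x, x1)).abs().max().item())
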

class ODEEulerFlowMatchingSolver(Solver):
    """
    ODE solver for flow matching that uses an Euler discretization.
    Supports choosing the number of time steps at inference.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.step_size_type = kwargs.get("step_size_type", "step_in_dsigma")
        assert self.step_size_type in ["step_in_dsigma", "step_in_dt"]
        self.sample_timescale = 1.0 - 1e-5

    @torch.no_grad()
    def sample_euler(
        self,
        x_T,
        unconditional_guidance_scale,
        has_null_indicator,
        t=(0.0, 1.0),
        **kwargs,
    ):
        """
        Euler solver for flow matching.
        Based on https://github.com/VinAIResearch/LFM/blob/main/sampler/karras_sample.py
        """
        t = torch.tensor(t) * self.sample_timescale
        sigma_min = 1e-5
        sigma_max = 1.0
        sigma_steps = torch.linspace(
            sigma_min, sigma_max, self.num_time_steps + 1, device=x_T.device
        )
        discrete_time_steps_for_step = torch.linspace(
            t[0], t[1], self.num_time_steps + 1, device=x_T.device
        )
        discrete_time_steps_to_eval_model_at = torch.linspace(
            t[0], t[1], self.num_time_steps, device=x_T.device
        )
        print("num_time_steps : " + str(self.num_time_steps))

        for i in range(self.num_time_steps):
            t_i = discrete_time_steps_to_eval_model_at[i]
            velocity = self.get_model_output_dimr(
                x_T,
                has_null_indicator=has_null_indicator,
                t_continuous=t_i.repeat(x_T.shape[0]),
                unconditional_guidance_scale=unconditional_guidance_scale,
            )
            if self.step_size_type == "step_in_dsigma":
                step_size = sigma_steps[i + 1] - sigma_steps[i]
            elif self.step_size_type == "step_in_dt":
                step_size = (
                    discrete_time_steps_for_step[i + 1]
                    - discrete_time_steps_for_step[i]
                )
            x_T = x_T + velocity * step_size

        intermediates = None
        return x_T, intermediates

    @torch.no_grad()
    def sample(self, *args, **kwargs):
        assert kwargs.get("ucg_schedule", None) is None
        assert kwargs.get("skip_type", None) is None
        assert kwargs.get("dynamic_threshold", None) is None
        assert kwargs.get("x0", None) is None
        assert kwargs.get("x_T") is not None
        assert kwargs.get("score_corrector", None) is None
        assert kwargs.get("normals_sequence", None) is None
        assert kwargs.get("callback", None) is None
        assert kwargs.get("quantize_x0", False) is False
        assert kwargs.get("eta", 0.0) == 0.0
        assert kwargs.get("mask", None) is None
        assert kwargs.get("noise_dropout", 0.0) == 0.0

        self.num_time_steps = kwargs.get("sample_steps")
        self.x_T_uncon = kwargs.get("x_T_uncon")

        samples, intermediates = super().sample(
            *args,
            sampling_method=self.sample_euler,
            do_make_schedule=False,
            **kwargs,
        )
        return samples, intermediates


class ODEFlowMatchingSolver(Solver):
    """
    ODE solver for flow matching that uses `dopri5`.
    Does not support controlling the number of time steps.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sample_timescale = 1.0 - 1e-5

    # sampling for inference
    @torch.no_grad()
    def sample_transport(
        self,
        x_T,
        unconditional_guidance_scale,
        has_null_indicator,
        t=(0.0, 1.0),
        ode_opts=None,
        **kwargs,
    ):
        num_evals = 0
        t = torch.tensor(t, device=x_T.device)
        # avoid a mutable default argument; the original `ode_opts={}` would
        # leak the injected options across calls
        ode_opts = dict(ode_opts or {})
        if "options" not in ode_opts:
            ode_opts["options"] = {}
        ode_opts["options"]["step_t"] = [self.sample_timescale + 1e-6]

        def ode_func(t, x_T):
            nonlocal num_evals
            num_evals += 1
            model_output = self.get_model_output_dimr(
                x_T,
                has_null_indicator=has_null_indicator,
                t_continuous=t.repeat(x_T.shape[0]),
                unconditional_guidance_scale=unconditional_guidance_scale,
            )
            return model_output

        z = torchdiffeq.odeint(
            ode_func,
            x_T,
            t * self.sample_timescale,
            **{"atol": 1e-5, "rtol": 1e-5, "method": "dopri5", **ode_opts},
        )
        # the first dimension of z holds the solution at each requested
        # timepoint; we only need the last one (t = 1, i.e., the image)
        z = z[-1]
        intermediates = None
        return z, intermediates

    @torch.no_grad()
    def sample(self, *args, **kwargs):
        assert kwargs.get("ucg_schedule", None) is None
        assert kwargs.get("skip_type", None) is None
        assert kwargs.get("dynamic_threshold", None) is None
        assert kwargs.get("x0", None) is None
        assert kwargs.get("x_T") is not None
        assert kwargs.get("score_corrector", None) is None
        assert kwargs.get("normals_sequence", None) is None
        assert kwargs.get("callback", None) is None
        assert kwargs.get("quantize_x0", False) is False
        assert kwargs.get("eta", 0.0) == 0.0
        assert kwargs.get("mask", None) is None
        assert kwargs.get("noise_dropout", 0.0) == 0.0

        samples, intermediates = super().sample(
            *args,
            sampling_method=self.sample_transport,
            do_make_schedule=False,
            **kwargs,
        )
        return samples, intermediates
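

if __name__ == "__main__":
    # Standalone sketch (illustrative assumption, not the original entry
    # point): exercise the demo helpers above, then mimic the Euler update
    # from ODEEulerFlowMatchingSolver.sample_euler with a dummy velocity field
    # in place of the trained network (the real solver calls
    # self.get_model_output_dimr).
    _demo_clip_loss()
    _demo_time_samplers()
    _demo_check_path()

    def dummy_velocity(x, t):
        return -x  # simple contracting field standing in for the model output

    x = torch.randn(2, 4)
    steps = 10
    sigma_steps = torch.linspace(1e-5, 1.0, steps + 1)
    eval_times = torch.linspace(0.0, 1.0 - 1e-5, steps)
    for i in range(steps):
        v = dummy_velocity(x, eval_times[i])
        x = x + v * (sigma_steps[i + 1] - sigma_steps[i])  # "step_in_dsigma" step
    print("integrated state norm:", x.norm().item())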