import os
import copy
import logging
from typing import List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
# from basicsr.utils import img2tensor

from diffusers import DDIMScheduler
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from transformers import (
    ClapTextModelWithProjection,
    RobertaTokenizer,
    RobertaTokenizerFast,
    SpeechT5HifiGan,
)

from src.utils.factory import (
    normalize_along_channel,
    project_onto_tangent_space,
    identity_projection,
)

log = logging.getLogger(__name__)
class Sampler(DiffusionPipeline):
    r"""Core of Audio Morphix: samples latents guided by the inverted latents and the cross-attention trajectory."""

    def __init__(
        self,
        vae: AutoencoderKL,
        tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
        text_encoder: ClapTextModelWithProjection,
        unet: UNet2DConditionModel,
        feature_estimator: UNet2DConditionModel,
        scheduler: DDIMScheduler,
        device: torch.device = torch.device("cpu"),
        precision: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            feature_estimator=feature_estimator,
            scheduler=scheduler,
        )
        self = self.to(torch_device=device, torch_dtype=precision)
        self._device = device
        # NOTE: `self.use_cross_attn` is read by `edit` and the `guidance_*`
        # methods but is not set here; configure it on the instance (True for
        # cross-attention conditioning, False for class-label conditioning)
        # before sampling.
    def edit(
        self,
        prompt: str,
        mode: str,
        edit_kwargs: dict,
        prompt_replace: Optional[str] = None,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 50,
        guidance_scale: Optional[float] = 7.5,
        latent: Optional[torch.FloatTensor] = None,
        start_time: int = 50,
        energy_scale: float = 0,
        SDE_strength: float = 0.4,
        SDE_strength_un: float = 0,
        latent_noise_ref: Optional[torch.FloatTensor] = None,
        bg_to_fg_ratio: float = 0.5,
        disable_tangent_proj: bool = False,
        alg: str = "D+",
    ):
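        """Run guided DDIM sampling for one editing request.

        `mode` selects the editing guidance ("move", "drag", "landmark",
        "appearance", "paste", "mix", "remove", or "style_transfer") and
        `edit_kwargs` carries the masks, feature indices, and weights that the
        corresponding `guidance_*` method expects. `latent` is the starting
        (inverted) latent and `latent_noise_ref` the reference inversion
        trajectory. Returns the edited latent.
        """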
log.info("Start Editing:") | |
self.alg = alg | |
# Select projection function | |
if disable_tangent_proj: | |
log.info("Use guidance directly.") | |
self.proj_fn = identity_projection | |
else: | |
log.info("Project guidance onto tangent space.") | |
self.proj_fn = project_onto_tangent_space | |
# Generate source text embedding | |
text_input = self._encode_text(prompt) | |
if prompt_replace is not None: | |
text_replace = self._encode_text(prompt_replace) | |
else: | |
text_replace = text_input | |
# Generate null text embedding for CFG | |
prompt_uncond = "" if negative_prompt is None else negative_prompt | |
text_uncond = self._encode_text(prompt_uncond) | |
# Text condition for the current trajectory | |
context = self._stack_text(text_uncond, text_input) | |
self.scheduler.set_timesteps(num_inference_steps) | |
dict_mask = edit_kwargs["dict_mask"] if "dict_mask" in edit_kwargs else None | |
        time_scale = num_inference_steps / 50  # rescale the step-index thresholds below (tuned for the default of 50 steps)
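        # Main denoising loop: classifier-free guidance, plus (when
        # `energy_scale` > 0) gradient-based editing guidance, regional SDE
        # noise, and "time-travel" repeats that re-noise and re-run a step.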
for i, t in enumerate(tqdm(self.scheduler.timesteps[-start_time:])): | |
next_timestep = min( | |
t | |
- self.scheduler.config.num_train_timesteps | |
// self.scheduler.num_inference_steps, | |
999, | |
) | |
next_timestep = max(next_timestep, 0) | |
if energy_scale == 0 or alg == "D": | |
repeat = 1 | |
elif int(20*time_scale) < i < int(30*time_scale) and i % 2 == 0: | |
repeat = 3 | |
else: | |
repeat = 1 | |
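            # Repeat selected middle steps so the editing guidance is applied
            # several times; see the re-noising block at the end of the loop.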
for ri in range(repeat): | |
latent_in = torch.cat([latent.unsqueeze(2)] * 2) | |
with torch.no_grad(): | |
noise_pred = self.unet( | |
latent_in, | |
t, | |
encoder_hidden_states=context["embed"] if self.use_cross_attn else None, | |
class_labels=None if self.use_cross_attn else context["embed"], | |
encoder_attention_mask=context["mask"] if self.use_cross_attn else None, | |
mask=dict_mask, | |
save_kv=False, | |
mode=mode, | |
iter_cur=i, | |
)["sample"].squeeze(2) | |
noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) | |
noise_pred = noise_pred_uncond + guidance_scale * ( | |
noise_prediction_text - noise_pred_uncond | |
) | |
if ( | |
energy_scale != 0 | |
and int(i*time_scale) < 30 | |
and (alg == "D" or i % 2 == 0 or i < int(10*time_scale)) | |
): | |
# editing guidance | |
noise_pred_org = noise_pred | |
if mode == "move": | |
guidance = self.guidance_move( | |
latent=latent, | |
latent_noise_ref=latent_noise_ref[-(i + 1)], | |
t=t, | |
context=text_input, | |
context_base=text_input, | |
energy_scale=energy_scale, | |
**edit_kwargs | |
) | |
elif mode == "drag": | |
guidance = self.guidance_drag( | |
latent=latent, | |
latent_noise_ref=latent_noise_ref[-(i + 1)], | |
t=t, | |
context=text_input, | |
context_base=text_input, | |
energy_scale=energy_scale, | |
**edit_kwargs | |
) | |
elif mode == "landmark": | |
guidance = self.guidance_landmark( | |
latent=latent, | |
latent_noise_ref=latent_noise_ref[-(i + 1)], | |
t=t, | |
context=text_input, | |
context_base=text_input, | |
energy_scale=energy_scale, | |
**edit_kwargs | |
) | |
elif mode == "appearance": | |
guidance = self.guidance_appearance( | |
latent=latent, | |
latent_noise_ref=latent_noise_ref[-(i + 1)], | |
t=t, | |
context=text_input, | |
context_base=text_input, | |
energy_scale=energy_scale, | |
**edit_kwargs | |
) | |
elif mode == "paste": | |
guidance = self.guidance_paste( | |
latent=latent, | |
latent_noise_ref=latent_noise_ref[-(i + 1)], | |
t=t, | |
context=text_input, | |
context_base=text_input, | |
context_replace=text_replace, | |
energy_scale=energy_scale, | |
**edit_kwargs | |
) | |
elif mode == "mix": | |
guidance = self.guidance_mix( | |
latent=latent, | |
latent_noise_ref=latent_noise_ref[-(i + 1)], | |
t=t, | |
context=text_input, | |
context_base=text_input, | |
context_replace=text_replace, | |
energy_scale=energy_scale, | |
**edit_kwargs | |
) | |
elif mode == "remove": | |
guidance = self.guidance_remove( | |
latent=latent, | |
latent_noise_ref=latent_noise_ref[-(i + 1)], | |
t=t, | |
context=text_input, | |
context_base=text_input, | |
context_replace=text_replace, | |
energy_scale=energy_scale, | |
**edit_kwargs | |
) | |
elif mode == "style_transfer": | |
guidance = self.guidance_style_transfer( | |
latent=latent, | |
latent_noise_ref=latent_noise_ref[-(i + 1)], | |
t=t, | |
context=text_input, | |
context_base=text_input, | |
energy_scale=energy_scale, | |
**edit_kwargs | |
) | |
# Project guidance onto z_t | |
guidance = self.proj_fn(guidance, latent) | |
noise_pred = noise_pred + guidance # NOTE: weighted sum? | |
else: | |
noise_pred_org = None | |
# zt->zt-1 | |
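                # DDIM update: z_{t-1} = sqrt(a_{t-1}) * x0_pred
                #   + sqrt(1 - a_{t-1} - sigma_t^2) * eps  (+ sigma_t * noise,
                #   added below only when the regional SDE is active)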
prev_timestep = ( | |
t | |
- self.scheduler.config.num_train_timesteps | |
// self.scheduler.num_inference_steps | |
) | |
alpha_prod_t = self.scheduler.alphas_cumprod[t] | |
alpha_prod_t_prev = ( | |
self.scheduler.alphas_cumprod[prev_timestep] | |
if prev_timestep >= 0 | |
else self.scheduler.final_alpha_cumprod | |
) | |
beta_prod_t = 1 - alpha_prod_t | |
if self.scheduler.config.prediction_type == "epsilon": | |
pred_original_sample = ( | |
latent - beta_prod_t ** (0.5) * noise_pred | |
) / alpha_prod_t ** (0.5) | |
pred_epsilon = noise_pred | |
pred_epsilon_org = noise_pred_org | |
elif self.scheduler.config.prediction_type == "v_prediction": | |
pred_original_sample = (alpha_prod_t**0.5) * latent - (beta_prod_t**0.5) * noise_pred | |
pred_epsilon = (alpha_prod_t**0.5) * noise_pred + (beta_prod_t**0.5) * latent | |
if noise_pred_org is not None: | |
pred_epsilon_org = (alpha_prod_t**0.5) * noise_pred_org + (beta_prod_t**0.5) * latent | |
else: | |
pred_epsilon_org = None | |
if int(10*time_scale) < i < int(20*time_scale): | |
eta, eta_rd = SDE_strength_un, SDE_strength | |
else: | |
eta, eta_rd = 0.0, 0.0 | |
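                # eta / eta_rd make the step stochastic only for this middle
                # range of timesteps (and only for alg "D+", see below).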
variance = self.scheduler._get_variance(t, prev_timestep) | |
std_dev_t = eta * variance ** (0.5) | |
std_dev_t_rd = eta_rd * variance ** (0.5) | |
if noise_pred_org is not None: | |
pred_sample_direction_rd = ( | |
1 - alpha_prod_t_prev - std_dev_t_rd**2 | |
) ** (0.5) * pred_epsilon_org | |
pred_sample_direction = ( | |
1 - alpha_prod_t_prev - std_dev_t**2 | |
) ** (0.5) * pred_epsilon_org | |
else: | |
pred_sample_direction_rd = ( | |
1 - alpha_prod_t_prev - std_dev_t_rd**2 | |
) ** (0.5) * pred_epsilon | |
pred_sample_direction = ( | |
1 - alpha_prod_t_prev - std_dev_t**2 | |
) ** (0.5) * pred_epsilon | |
latent_prev = ( | |
alpha_prod_t_prev ** (0.5) * pred_original_sample | |
+ pred_sample_direction | |
) | |
latent_prev_rd = ( | |
alpha_prod_t_prev ** (0.5) * pred_original_sample | |
+ pred_sample_direction_rd | |
) | |
# Regional SDE | |
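                # Blend two updates: outside the edited region the noise std is
                # eta, inside the region it is eta_rd.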
if (eta_rd > 0 or eta > 0) and alg == "D+": | |
variance_noise = torch.randn_like(latent_prev) | |
variance_rd = std_dev_t_rd * variance_noise | |
variance = std_dev_t * variance_noise | |
if mode == "move": | |
mask = ( | |
F.interpolate( | |
edit_kwargs["mask_x0"][None, None], | |
( | |
edit_kwargs["mask_cur"].shape[-2], | |
edit_kwargs["mask_cur"].shape[-1], | |
), | |
) | |
> 0.5 | |
).float() | |
mask = ((edit_kwargs["mask_cur"] + mask) > 0.5).float() | |
mask = ( | |
F.interpolate( | |
mask, (latent_prev.shape[-2], latent_prev.shape[-1]) | |
) | |
> 0.5 | |
).to(dtype=latent.dtype) | |
elif mode == "drag": | |
mask = F.interpolate( | |
edit_kwargs["mask_x0"][None, None], | |
(latent_prev[-1].shape[-2], latent_prev[-1].shape[-1]), | |
) | |
mask = (mask > 0).to(dtype=latent.dtype) | |
elif mode == "landmark": | |
mask = torch.ones_like(latent_prev) | |
elif ( | |
mode == "appearance" | |
or mode == "paste" | |
or mode == "remove" | |
or mode == "mix" | |
): | |
mask = F.interpolate( | |
edit_kwargs["mask_base_cur"].float(), | |
(latent_prev[-1].shape[-2], latent_prev[-1].shape[-1]), | |
) | |
mask = (mask > 0).to(dtype=latent.dtype) | |
latent_prev = (latent_prev + variance) * (1 - mask) + ( | |
latent_prev_rd + variance_rd | |
) * mask | |
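                # When repeating this step, re-noise the updated latent from
                # the previous timestep back to t (one DDIM inversion step)
                # before the next pass.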
if repeat > 1: | |
with torch.no_grad(): | |
alpha_prod_t = self.scheduler.alphas_cumprod[next_timestep] | |
alpha_prod_t_next = self.scheduler.alphas_cumprod[t] | |
beta_prod_t = 1 - alpha_prod_t | |
model_output = self.unet( | |
latent_prev.unsqueeze(2), | |
next_timestep, | |
class_labels=None if self.use_cross_attn else text_input["embed"], | |
encoder_hidden_states=text_input["embed"] if self.use_cross_attn else None, | |
encoder_attention_mask=text_input["mask"] if self.use_cross_attn else None, | |
mask=dict_mask, | |
save_kv=False, | |
mode=mode, | |
iter_cur=-2, | |
)["sample"].squeeze(2) | |
# Different scheduling options | |
if self.scheduler.config.prediction_type == "epsilon": | |
next_original_sample = ( | |
latent_prev - beta_prod_t**0.5 * model_output | |
) / alpha_prod_t**0.5 | |
pred_epsilon = model_output | |
elif self.scheduler.config.prediction_type == "v_prediction": | |
next_original_sample = (alpha_prod_t**0.5) * latent_prev - (beta_prod_t**0.5) * model_output | |
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * latent_prev | |
next_sample_direction = ( | |
1 - alpha_prod_t_next | |
) ** 0.5 * pred_epsilon | |
latent = ( | |
alpha_prod_t_next**0.5 * next_original_sample | |
+ next_sample_direction | |
) | |
latent = latent_prev | |
return latent | |
def guidance_move( | |
self, | |
mask_x0, | |
mask_x0_ref, | |
mask_x0_keep, | |
mask_tar, | |
mask_cur, | |
mask_keep, | |
mask_other, | |
mask_overlap, | |
mask_non_overlap, | |
latent, | |
latent_noise_ref, | |
t, | |
up_ft_index, | |
context, | |
context_base, | |
up_scale, | |
resize_scale_x, | |
resize_scale_y, | |
energy_scale, | |
w_edit, | |
w_content, | |
w_contrast, | |
w_inpaint, | |
): | |
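        """Guidance for the "move" edit.

        Encourages the features at the destination region (`mask_cur`) of the
        current latent to match the reference features at the source region
        (`mask_tar`), keeps the unedited region (`mask_other`) consistent with
        the original, and adds contrast/inpainting terms for the vacated
        (non-overlap) region.
        """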
cos = nn.CosineSimilarity(dim=1) | |
loss_scale = [0.5, 0.5] | |
with torch.no_grad(): | |
up_ft_tar = self.feature_estimator( | |
sample=latent_noise_ref.squeeze(2), | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context_base["embed"], | |
encoder_hidden_states=( | |
context_base["embed"] if self.use_cross_attn else None | |
), | |
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
up_ft_tar_org = copy.deepcopy(up_ft_tar) | |
for f_id in range(len(up_ft_tar_org)): | |
up_ft_tar_org[f_id] = F.interpolate( | |
up_ft_tar_org[f_id], | |
( | |
up_ft_tar_org[-1].shape[-2] * up_scale, | |
up_ft_tar_org[-1].shape[-1] * up_scale, | |
), | |
) | |
latent = latent.detach().requires_grad_(True) | |
for f_id in range(len(up_ft_tar)): | |
up_ft_tar[f_id] = F.interpolate( | |
up_ft_tar[f_id], | |
( | |
int(up_ft_tar[-1].shape[-2] * resize_scale_y * up_scale), | |
int(up_ft_tar[-1].shape[-1] * resize_scale_x * up_scale), | |
), | |
) | |
up_ft_cur = self.feature_estimator( | |
sample=latent, | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context["embed"], | |
encoder_hidden_states=( | |
context["embed"] if self.use_cross_attn else None | |
), | |
encoder_attention_mask=context["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_cur)): | |
up_ft_cur[f_id] = F.interpolate( | |
up_ft_cur[f_id], | |
( | |
up_ft_cur[-1].shape[-2] * up_scale, | |
up_ft_cur[-1].shape[-1] * up_scale, | |
), | |
) | |
# editing energy | |
loss_edit = 0 | |
for f_id in range(len(up_ft_tar)): | |
up_ft_cur_vec = ( | |
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)] | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
) # (x1, D) | |
up_ft_tar_vec = ( | |
up_ft_tar[f_id][mask_tar.repeat(1, up_ft_tar[f_id].shape[1], 1, 1)] | |
.view(up_ft_tar[f_id].shape[1], -1) | |
.permute(1, 0) | |
) # (x2, D) | |
            # Compute cosine similarity between `up_ft_cur_vec` and `up_ft_tar_vec`
# up_ft_cur_vec_norm = up_ft_cur_vec / (up_ft_cur_vec.norm(dim=1, keepdim=True) + 1e-8) | |
# up_ft_tar_vec_norm = up_ft_tar_vec / (up_ft_tar_vec.norm(dim=1, keepdim=True) + 1e-8) | |
# sim = torch.mm(up_ft_cur_vec_norm, up_ft_tar_vec_norm.T) | |
sim = cos(up_ft_cur_vec, up_ft_tar_vec) | |
# sim_global = cos( | |
# up_ft_cur_vec.mean(0, keepdim=True), up_ft_tar_vec.mean(0, keepdim=True) | |
# ) | |
loss_edit = loss_edit + (w_edit / (1 + 4 * sim.mean())) * loss_scale[f_id] | |
# Content energy | |
loss_con = 0 | |
for f_id in range(len(up_ft_tar_org)): | |
sim_other = cos(up_ft_tar_org[f_id], up_ft_cur[f_id])[0][mask_other[0, 0]] | |
loss_con = ( | |
loss_con + w_content / (1 + 4 * sim_other.mean()) * loss_scale[f_id] | |
) | |
if mask_x0_ref is not None: | |
mask_x0_ref_cur = ( | |
F.interpolate( | |
mask_x0_ref[None, None], | |
(mask_other.shape[-2], mask_other.shape[-1]), | |
) | |
> 0.5 | |
) | |
else: | |
mask_x0_ref_cur = mask_other | |
for f_id in range(len(up_ft_tar)): | |
# # Global | |
# up_ft_cur_non_overlap_contrast = ( | |
# up_ft_cur[f_id][ | |
# mask_non_overlap.repeat(1, up_ft_cur[f_id].shape[1], 1, 1) | |
# ] | |
# .view(up_ft_cur[f_id].shape[1], -1) | |
# .permute(1, 0) | |
# ) | |
# up_ft_tar_non_overlap_contrast = ( | |
# up_ft_tar_org[f_id][ | |
# mask_non_overlap.repeat(1, up_ft_tar_org[f_id].shape[1], 1, 1) | |
# ] | |
# .view(up_ft_tar_org[f_id].shape[1], -1) | |
# .permute(1, 0) | |
# ) | |
            # Masked mean pooling over the time axis before computing similarity
up_ft_cur_non_overlap_sum = torch.sum( | |
up_ft_cur[f_id] * mask_non_overlap.repeat(1, up_ft_cur[f_id].shape[1], 1, 1), | |
dim=-2, | |
) | |
up_ft_tar_non_overlap_sum = torch.sum( | |
up_ft_tar_org[f_id] | |
* mask_non_overlap.repeat(1, up_ft_tar_org[f_id].shape[1], 1, 1), | |
dim=-2, | |
) # feature for reference audio | |
mask_sum = torch.sum(mask_non_overlap, dim=-2) | |
up_ft_cur_non_overlap_contrast = ( | |
(up_ft_cur_non_overlap_sum / (mask_sum + 1e-8)) | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
            )  # avoid division by zero
up_ft_tar_non_overlap_contrast = ( | |
(up_ft_tar_non_overlap_sum / (mask_sum + 1e-8)) | |
.view(up_ft_tar_org[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
sim_non_overlap_contrast = ( | |
cos(up_ft_cur_non_overlap_contrast, up_ft_tar_non_overlap_contrast) + 1.0 | |
) / 2.0 | |
loss_con = loss_con + w_contrast * sim_non_overlap_contrast.mean() * loss_scale[f_id] | |
up_ft_cur_non_overlap_inpaint = ( | |
up_ft_cur[f_id][ | |
mask_non_overlap.repeat(1, up_ft_cur[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
.mean(0, keepdim=True) | |
) | |
up_ft_tar_non_overlap_inpaint = ( | |
up_ft_tar_org[f_id][ | |
mask_x0_ref_cur.repeat(1, up_ft_tar_org[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_tar_org[f_id].shape[1], -1) | |
.permute(1, 0) | |
.mean(0, keepdim=True) | |
) | |
sim_inpaint = ( | |
cos(up_ft_cur_non_overlap_inpaint, up_ft_tar_non_overlap_inpaint) + 1.0 | |
) / 2.0 | |
loss_con = loss_con + w_inpaint / (1 + 4 * sim_inpaint.mean()) | |
cond_grad_edit = torch.autograd.grad( | |
loss_edit * energy_scale, latent, retain_graph=True | |
)[0] | |
cond_grad_con = torch.autograd.grad(loss_con * energy_scale, latent)[0] | |
mask_edit1 = (mask_cur > 0.5).float() | |
mask_edit1 = ( | |
F.interpolate(mask_edit1, (latent.shape[-2], latent.shape[-1])) > 0 | |
).to(dtype=latent.dtype) | |
# mask_edit2 = ((mask_keep + mask_non_overlap.float()) > 0.5).float() | |
# mask_edit2 = ( | |
# F.interpolate(mask_edit2, (latent.shape[-2], latent.shape[-1])) > 0 | |
# ).to(dtype=latent.dtype) | |
        mask_edit2 = 1 - mask_edit1
guidance = ( | |
cond_grad_edit.detach() * 8e-2 * mask_edit1 | |
+ cond_grad_con.detach() * 8e-2 * mask_edit2 | |
) | |
self.feature_estimator.zero_grad() | |
return guidance | |
def guidance_drag( | |
self, | |
mask_x0, | |
mask_cur, | |
mask_tar, | |
mask_other, | |
latent, | |
latent_noise_ref, | |
t, | |
up_ft_index, | |
up_scale, | |
context, | |
context_base, | |
energy_scale, | |
w_edit, | |
w_inpaint, | |
w_content, | |
dict_mask=None, | |
): | |
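        """Guidance for the "drag" edit: matches the current features at each
        region in `mask_cur` to the reference features at the paired region in
        `mask_tar`, discourages the vacated (non-overlap) part of `mask_tar`
        from keeping its original content, and preserves `mask_other`."""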
cos = nn.CosineSimilarity(dim=1) | |
with torch.no_grad(): | |
up_ft_tar = self.feature_estimator( | |
sample=latent_noise_ref.squeeze(2), | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context_base["embed"], | |
encoder_hidden_states=( | |
context_base["embed"] if self.use_cross_attn else None | |
), | |
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_tar)): | |
up_ft_tar[f_id] = F.interpolate( | |
up_ft_tar[f_id], | |
( | |
up_ft_tar[-1].shape[-2] * up_scale, | |
up_ft_tar[-1].shape[-1] * up_scale, | |
), | |
) | |
latent = latent.detach().requires_grad_(True) | |
up_ft_cur = self.feature_estimator( | |
sample=latent, | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context["embed"], | |
encoder_hidden_states=( | |
context["embed"] if self.use_cross_attn else None | |
), | |
encoder_attention_mask=context["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_cur)): | |
up_ft_cur[f_id] = F.interpolate( | |
up_ft_cur[f_id], | |
( | |
up_ft_cur[-1].shape[-2] * up_scale, | |
up_ft_cur[-1].shape[-1] * up_scale, | |
), | |
) | |
# moving loss | |
loss_edit = 0 | |
for f_id in range(len(up_ft_tar)): | |
for mask_cur_i, mask_tar_i in zip(mask_cur, mask_tar): | |
up_ft_cur_vec = ( | |
up_ft_cur[f_id][ | |
mask_cur_i.repeat(1, up_ft_cur[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
up_ft_tar_vec = ( | |
up_ft_tar[f_id][ | |
mask_tar_i.repeat(1, up_ft_tar[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_tar[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
sim = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0 | |
loss_edit = loss_edit + w_edit / (1 + 4 * sim.mean()) | |
mask_overlap = ((mask_cur_i.float() + mask_tar_i.float()) > 1.5).float() | |
mask_non_overlap = (mask_tar_i.float() - mask_overlap) > 0.5 | |
up_ft_cur_non_overlap = ( | |
up_ft_cur[f_id][ | |
mask_non_overlap.repeat(1, up_ft_cur[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
up_ft_tar_non_overlap = ( | |
up_ft_tar[f_id][ | |
mask_non_overlap.repeat(1, up_ft_tar[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_tar[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
sim_non_overlap = ( | |
cos(up_ft_cur_non_overlap, up_ft_tar_non_overlap) + 1.0 | |
) / 2.0 | |
loss_edit = loss_edit + w_inpaint * sim_non_overlap.mean() | |
# consistency loss | |
loss_con = 0 | |
for f_id in range(len(up_ft_tar)): | |
sim_other = ( | |
cos(up_ft_tar[f_id], up_ft_cur[f_id])[0][mask_other[0, 0]] + 1.0 | |
) / 2.0 | |
loss_con = loss_con + w_content / (1 + 4 * sim_other.mean()) | |
loss_edit = loss_edit / len(up_ft_cur) / len(mask_cur) | |
loss_con = loss_con / len(up_ft_cur) | |
cond_grad_edit = torch.autograd.grad( | |
loss_edit * energy_scale, latent, retain_graph=True | |
)[0] | |
cond_grad_con = torch.autograd.grad(loss_con * energy_scale, latent)[0] | |
mask = F.interpolate( | |
mask_x0[None, None], | |
(cond_grad_edit[-1].shape[-2], cond_grad_edit[-1].shape[-1]), | |
) | |
mask = (mask > 0).to(dtype=latent.dtype) | |
guidance = ( | |
cond_grad_edit.detach() * 4e-2 * mask | |
+ cond_grad_con.detach() * 4e-2 * (1 - mask) | |
) | |
self.feature_estimator.zero_grad() | |
return guidance | |
def guidance_landmark( | |
self, | |
mask_cur, | |
mask_tar, | |
latent, | |
latent_noise_ref, | |
t, | |
up_ft_index, | |
up_scale, | |
context, | |
context_base, | |
energy_scale, | |
w_edit, | |
w_inpaint, | |
): | |
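        """Guidance for the "landmark" edit: matches the current features at
        `mask_cur` to the reference features at `mask_tar`; the resulting
        gradient is applied globally (no content-preservation term)."""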
cos = nn.CosineSimilarity(dim=1) | |
with torch.no_grad(): | |
up_ft_tar = self.feature_estimator( | |
sample=latent_noise_ref.squeeze(2), | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context_base["embed"], | |
encoder_hidden_states=( | |
context_base["embed"] if self.use_cross_attn else None | |
), | |
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_tar)): | |
up_ft_tar[f_id] = F.interpolate( | |
up_ft_tar[f_id], | |
( | |
up_ft_tar[-1].shape[-2] * up_scale, | |
up_ft_tar[-1].shape[-1] * up_scale, | |
), | |
) | |
latent = latent.detach().requires_grad_(True) | |
up_ft_cur = self.feature_estimator( | |
sample=latent, | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context["embed"], | |
encoder_hidden_states=( | |
context["embed"] if self.use_cross_attn else None | |
), | |
encoder_attention_mask=context["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_cur)): | |
up_ft_cur[f_id] = F.interpolate( | |
up_ft_cur[f_id], | |
( | |
up_ft_cur[-1].shape[-2] * up_scale, | |
up_ft_cur[-1].shape[-1] * up_scale, | |
), | |
) | |
# moving loss | |
loss_edit = 0 | |
for f_id in range(len(up_ft_tar)): | |
for mask_cur_i, mask_tar_i in zip(mask_cur, mask_tar): | |
up_ft_cur_vec = ( | |
up_ft_cur[f_id][ | |
mask_cur_i.repeat(1, up_ft_cur[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
up_ft_tar_vec = ( | |
up_ft_tar[f_id][ | |
mask_tar_i.repeat(1, up_ft_tar[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_tar[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
sim = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0 | |
loss_edit = loss_edit + w_edit / (1 + 4 * sim.mean()) | |
loss_edit = loss_edit / len(up_ft_cur) / len(mask_cur) | |
cond_grad_edit = torch.autograd.grad( | |
loss_edit * energy_scale, latent, retain_graph=True | |
)[0] | |
guidance = cond_grad_edit.detach() * 4e-2 | |
self.feature_estimator.zero_grad() | |
return guidance | |
def guidance_appearance( | |
self, | |
mask_base_cur, | |
mask_replace_cur, | |
latent, | |
latent_noise_ref, | |
t, | |
up_ft_index, | |
up_scale, | |
context, | |
context_base, | |
context_replace, | |
energy_scale, | |
dict_mask, | |
w_edit, | |
w_content, | |
): | |
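        """Guidance for the "appearance" edit: keeps features outside
        `mask_base_cur` close to the base recording and drives the mean feature
        inside `mask_base_cur` toward the mean reference feature inside
        `mask_replace_cur`."""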
cos = nn.CosineSimilarity(dim=1) | |
with torch.no_grad(): | |
up_ft_tar_base = self.feature_estimator( | |
sample=latent_noise_ref.squeeze(2)[::2], | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context_base["embed"], | |
encoder_hidden_states=( | |
context_base["embed"] if self.use_cross_attn else None | |
), | |
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_tar_base)): | |
up_ft_tar_base[f_id] = F.interpolate( | |
up_ft_tar_base[f_id], | |
( | |
up_ft_tar_base[-1].shape[-2] * up_scale, | |
up_ft_tar_base[-1].shape[-1] * up_scale, | |
), | |
) | |
with torch.no_grad(): | |
up_ft_tar_replace = self.feature_estimator( | |
sample=latent_noise_ref.squeeze(2)[1::2], | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context_replace["embed"], | |
encoder_hidden_states=( | |
context_replace["embed"] if self.use_cross_attn else None | |
), | |
encoder_attention_mask=context_replace["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_tar_replace)): | |
up_ft_tar_replace[f_id] = F.interpolate( | |
up_ft_tar_replace[f_id], | |
( | |
up_ft_tar_replace[-1].shape[-2] * up_scale, | |
up_ft_tar_replace[-1].shape[-1] * up_scale, | |
), | |
) | |
latent = latent.detach().requires_grad_(True) | |
up_ft_cur = self.feature_estimator( | |
sample=latent, | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context["embed"], | |
encoder_hidden_states=( | |
context["embed"] if self.use_cross_attn else None | |
), | |
encoder_attention_mask=context["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_cur)): | |
up_ft_cur[f_id] = F.interpolate( | |
up_ft_cur[f_id], | |
( | |
up_ft_cur[-1].shape[-2] * up_scale, | |
up_ft_cur[-1].shape[-1] * up_scale, | |
), | |
) | |
# for base content | |
loss_con = 0 | |
for f_id in range(len(up_ft_tar_base)): | |
mask_cur = (1 - mask_base_cur.float()) > 0.5 | |
up_ft_cur_vec = ( | |
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)] | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
up_ft_tar_vec = ( | |
up_ft_tar_base[f_id][ | |
mask_cur.repeat(1, up_ft_tar_base[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_tar_base[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
sim = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0 | |
loss_con = loss_con + w_content / (1 + 4 * sim.mean()) | |
# for replace content | |
loss_edit = 0 | |
for f_id in range(len(up_ft_tar_replace)): | |
mask_cur = mask_base_cur | |
mask_tar = mask_replace_cur | |
up_ft_cur_vec = ( | |
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)] | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
.mean(0, keepdim=True) | |
) | |
up_ft_tar_vec = ( | |
up_ft_tar_replace[f_id][ | |
mask_tar.repeat(1, up_ft_tar_replace[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_tar_replace[f_id].shape[1], -1) | |
.permute(1, 0) | |
.mean(0, keepdim=True) | |
) | |
sim_all = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0 | |
loss_edit = loss_edit + w_edit / (1 + 4 * sim_all.mean()) | |
cond_grad_con = torch.autograd.grad( | |
loss_con * energy_scale, latent, retain_graph=True | |
)[0] | |
cond_grad_edit = torch.autograd.grad(loss_edit * energy_scale, latent)[0] | |
mask = F.interpolate( | |
mask_base_cur.float(), | |
(cond_grad_edit[-1].shape[-2], cond_grad_edit[-1].shape[-1]), | |
) | |
mask = (mask > 0).to(dtype=latent.dtype) | |
guidance = ( | |
cond_grad_con.detach() * (1 - mask) * 4e-2 | |
+ cond_grad_edit.detach() * mask * 4e-2 | |
) | |
self.feature_estimator.zero_grad() | |
return guidance | |
def guidance_mix( | |
self, | |
mask_base_cur, | |
mask_replace_cur, | |
latent, | |
latent_noise_ref, | |
t, | |
up_ft_index, | |
up_scale, | |
context, | |
context_base, | |
context_replace, | |
energy_scale, | |
dict_mask, | |
w_edit, | |
w_content, | |
): | |
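        """Guidance for the "mix" edit: preserves the base content outside
        `mask_base_cur`; inside the region, balances similarity to the base
        features (weighted by `w_content`) and to the reference ("replace")
        features (weighted by `w_edit`)."""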
cos = nn.CosineSimilarity(dim=1) | |
# pcpt_loss_fn = LPIPS() | |
with torch.no_grad(): | |
up_ft_tar_base = self.feature_estimator( | |
sample=latent_noise_ref.squeeze(2)[::2], | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context_base["embed"], | |
encoder_hidden_states=context_base["embed"] if self.use_cross_attn else None, | |
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_tar_base)): | |
up_ft_tar_base[f_id] = F.interpolate( | |
up_ft_tar_base[f_id], | |
( | |
up_ft_tar_base[-1].shape[-2] * up_scale, | |
up_ft_tar_base[-1].shape[-1] * up_scale, | |
), | |
) | |
with torch.no_grad(): | |
up_ft_tar_replace = self.feature_estimator( | |
sample=latent_noise_ref.squeeze(2)[1::2], | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context_replace["embed"], | |
encoder_hidden_states=context_replace["embed"] if self.use_cross_attn else None, | |
encoder_attention_mask=context_replace["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_tar_replace)): | |
up_ft_tar_replace[f_id] = F.interpolate( | |
up_ft_tar_replace[f_id], | |
( | |
up_ft_tar_replace[-1].shape[-2] * up_scale, | |
up_ft_tar_replace[-1].shape[-1] * up_scale, | |
), | |
) | |
latent = latent.detach().requires_grad_(True) | |
up_ft_cur = self.feature_estimator( | |
sample=latent, | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context["embed"], | |
encoder_hidden_states=context["embed"] if self.use_cross_attn else None, | |
encoder_attention_mask=context["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_cur)): | |
up_ft_cur[f_id] = F.interpolate( | |
up_ft_cur[f_id], | |
( | |
up_ft_cur[-1].shape[-2] * up_scale, | |
up_ft_cur[-1].shape[-1] * up_scale, | |
), | |
) | |
# for base content | |
# loss_con = 0 | |
# for f_id in range(len(up_ft_tar_base)): | |
# mask_cur = (1-mask_base_cur.float())>0.5 | |
# up_ft_cur_vec = up_ft_cur[f_id][mask_cur.repeat(1,up_ft_cur[f_id].shape[1],1,1)].view(up_ft_cur[f_id].shape[1], -1).permute(1,0) | |
# up_ft_tar_vec = up_ft_tar_base[f_id][mask_cur.repeat(1,up_ft_tar_base[f_id].shape[1],1,1)].view(up_ft_tar_base[f_id].shape[1], -1).permute(1,0) | |
# sim = (cos(up_ft_cur_vec, up_ft_tar_vec)+1.)/2. | |
# loss_con = loss_con + w_content/(1+4*sim.mean()) | |
loss_con = 0 | |
for f_id in range(len(up_ft_tar_base)): | |
mask_cur = (1 - mask_base_cur.float()) > 0.5 | |
mask_cur_p = mask_base_cur.float() > 0.5 | |
up_ft_cur_vec = ( | |
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)] | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
up_ft_tar_vec = ( | |
up_ft_tar_base[f_id][ | |
mask_cur.repeat(1, up_ft_tar_base[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_tar_base[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
sim = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0 | |
loss_con = loss_con + w_content / (1 + 4 * sim.mean()) | |
# for replace content | |
loss_edit = 0 | |
for f_id in range(len(up_ft_tar_replace)): | |
mask_cur = mask_base_cur | |
up_ft_cur_vec = ( | |
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)] | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
up_ft_tar_vec = ( | |
up_ft_tar_replace[f_id][ | |
mask_replace_cur.repeat(1, up_ft_tar_replace[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_tar_replace[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
up_ft_tar_vec_b = ( | |
up_ft_tar_base[f_id][ | |
mask_cur.repeat(1, up_ft_tar_base[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_tar_base[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
# sim_all=0.8*((cos(up_ft_cur_vec, up_ft_tar_vec_b)+1.)/2.) + 0.2*((cos(up_ft_cur_vec, up_ft_tar_vec)+1.)/2.) | |
# # sim_all=0.7*((cos(up_ft_cur_vec, up_ft_tar_vec)+1.)/2.)+0.3*((cos(up_ft_cur_vec, up_ft_tar_vec_b)+1.)/2.) | |
# loss_edit=loss_edit+w_edit/(1+4*sim_all.mean()) | |
# NOTE: try to use Harmonic mean | |
sim_base = (cos(up_ft_cur_vec, up_ft_tar_vec_b) + 1.0) / 2.0 | |
sim_tar = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0 | |
loss_edit = ( | |
loss_edit | |
+ w_content / (1 + 4 * sim_base.mean()) | |
+ w_edit / (1 + 4 * sim_tar.mean()) | |
            )  # NOTE: empirically, a background-to-total weight ratio of about 0.7 works well
cond_grad_con = torch.autograd.grad( | |
loss_con * energy_scale, latent, retain_graph=True | |
)[0] | |
cond_grad_edit = torch.autograd.grad(loss_edit * energy_scale, latent)[0] | |
mask = F.interpolate( | |
mask_base_cur.float(), | |
(cond_grad_edit[-1].shape[-2], cond_grad_edit[-1].shape[-1]), | |
) | |
mask = (mask > 0).to(dtype=latent.dtype) | |
guidance = ( | |
cond_grad_con.detach() * (1 - mask) * 4e-2 | |
+ cond_grad_edit.detach() * mask * 4e-2 | |
) | |
self.feature_estimator.zero_grad() | |
return guidance | |
def guidance_paste( | |
self, | |
mask_base_cur, | |
mask_replace_cur, | |
latent, | |
latent_noise_ref, | |
t, | |
up_ft_index, | |
up_scale, | |
context, | |
context_base, | |
context_replace, | |
energy_scale, | |
dict_mask, | |
w_edit, | |
w_content, | |
): | |
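        """Guidance for the "paste" edit: preserves the base content outside
        `mask_base_cur` and matches the features inside it to the reference
        features taken from `mask_replace_cur`."""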
cos = nn.CosineSimilarity(dim=1) | |
with torch.no_grad(): | |
up_ft_tar_base = self.feature_estimator( | |
sample=latent_noise_ref.squeeze(2)[::2], | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context_base["embed"], | |
encoder_hidden_states=context_base["embed"] if self.use_cross_attn else None, | |
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_tar_base)): | |
up_ft_tar_base[f_id] = F.interpolate( | |
up_ft_tar_base[f_id], | |
( | |
up_ft_tar_base[-1].shape[-2] * up_scale, | |
up_ft_tar_base[-1].shape[-1] * up_scale, | |
), | |
) | |
with torch.no_grad(): | |
up_ft_tar_replace = self.feature_estimator( | |
sample=latent_noise_ref.squeeze(2)[1::2], | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context_replace["embed"], | |
encoder_hidden_states=context_replace["embed"] if self.use_cross_attn else None, | |
encoder_attention_mask=context_replace["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_tar_replace)): | |
up_ft_tar_replace[f_id] = F.interpolate( | |
up_ft_tar_replace[f_id], | |
( | |
up_ft_tar_replace[-1].shape[-2] * up_scale, | |
up_ft_tar_replace[-1].shape[-1] * up_scale, | |
), | |
) | |
latent = latent.detach().requires_grad_(True) | |
up_ft_cur = self.feature_estimator( | |
sample=latent, | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context["embed"], | |
encoder_hidden_states=context["embed"] if self.use_cross_attn else None, | |
encoder_attention_mask=context["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_cur)): | |
up_ft_cur[f_id] = F.interpolate( | |
up_ft_cur[f_id], | |
( | |
up_ft_cur[-1].shape[-2] * up_scale, | |
up_ft_cur[-1].shape[-1] * up_scale, | |
), | |
) | |
# for base content | |
loss_con = 0 | |
for f_id in range(len(up_ft_tar_base)): | |
mask_cur = (1 - mask_base_cur.float()) > 0.5 | |
mask_cur_p = mask_base_cur.float() > 0.5 | |
up_ft_cur_vec = ( | |
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)] | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
up_ft_tar_vec = ( | |
up_ft_tar_base[f_id][ | |
mask_cur.repeat(1, up_ft_tar_base[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_tar_base[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
sim = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0 | |
loss_con = loss_con + w_content / (1 + 4 * sim.mean()) | |
# for replace content | |
loss_edit = 0 | |
for f_id in range(len(up_ft_tar_replace)): | |
mask_cur = mask_base_cur | |
up_ft_cur_vec = ( | |
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)] | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
up_ft_tar_vec = ( | |
up_ft_tar_replace[f_id][ | |
mask_replace_cur.repeat(1, up_ft_tar_replace[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_tar_replace[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
sim_all = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0 | |
loss_edit = loss_edit + w_edit / (1 + 4 * sim_all.mean()) | |
cond_grad_con = torch.autograd.grad( | |
loss_con * energy_scale, latent, retain_graph=True | |
)[0] | |
cond_grad_edit = torch.autograd.grad(loss_edit * energy_scale, latent)[0] | |
mask = F.interpolate( | |
mask_base_cur.float(), | |
(cond_grad_edit[-1].shape[-2], cond_grad_edit[-1].shape[-1]), | |
) | |
mask = (mask > 0).to(dtype=latent.dtype) | |
guidance = ( | |
cond_grad_con.detach() * (1 - mask) * 4e-2 | |
+ cond_grad_edit.detach() * mask * 4e-2 | |
) | |
self.feature_estimator.zero_grad() | |
return guidance | |
def guidance_remove( | |
self, | |
mask_base_cur, | |
mask_replace_cur, | |
latent, | |
latent_noise_ref, | |
t, | |
up_ft_index, | |
up_scale, | |
context, | |
context_base, | |
context_replace, | |
energy_scale, | |
dict_mask, | |
w_edit, | |
w_contrast, | |
w_content, | |
): | |
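        """Guidance for the "remove" edit: preserves the base content outside
        `mask_base_cur`; inside the region, a contrast term penalizes
        similarity to the time-pooled features of the reference ("replace")
        recording while a consistency term keeps the region close to the base
        recording."""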
cos = nn.CosineSimilarity(dim=1) | |
with torch.no_grad(): | |
up_ft_tar_base = self.feature_estimator( | |
sample=latent_noise_ref.squeeze(2)[::2], | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context_base["embed"], | |
encoder_hidden_states=context_base["embed"] if self.use_cross_attn else None, | |
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_tar_base)): | |
up_ft_tar_base[f_id] = F.interpolate( | |
up_ft_tar_base[f_id], | |
( | |
up_ft_tar_base[-1].shape[-2] * up_scale, | |
up_ft_tar_base[-1].shape[-1] * up_scale, | |
), | |
) | |
# up_ft_tar_base[f_id] = normalize_along_channel(up_ft_tar_base[f_id]) # No improvement observed | |
with torch.no_grad(): | |
up_ft_tar_replace = self.feature_estimator( | |
sample=latent_noise_ref.squeeze(2)[1::2], | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context_replace["embed"], | |
encoder_hidden_states=context_replace["embed"] if self.use_cross_attn else None, | |
encoder_attention_mask=context_replace["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_tar_replace)): | |
up_ft_tar_replace[f_id] = F.interpolate( | |
up_ft_tar_replace[f_id], | |
( | |
up_ft_tar_replace[-1].shape[-2] * up_scale, | |
up_ft_tar_replace[-1].shape[-1] * up_scale, | |
), | |
) | |
# up_ft_tar_replace[f_id] = normalize_along_channel(up_ft_tar_replace[f_id]) | |
latent = latent.detach().requires_grad_(True) | |
up_ft_cur = self.feature_estimator( | |
sample=latent, | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context["embed"], | |
encoder_hidden_states=context["embed"] if self.use_cross_attn else None, | |
encoder_attention_mask=context["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_cur)): | |
up_ft_cur[f_id] = F.interpolate( | |
up_ft_cur[f_id], | |
( | |
up_ft_cur[-1].shape[-2] * up_scale, | |
up_ft_cur[-1].shape[-1] * up_scale, | |
), | |
) | |
# up_ft_cur[f_id] = normalize_along_channel(up_ft_cur[f_id]) | |
# for base content | |
loss_con = 0 | |
for f_id in range(len(up_ft_tar_base)): | |
mask_cur = (1 - mask_base_cur.float()) > 0.5 | |
up_ft_cur_vec = ( | |
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)] | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
up_ft_tar_vec = ( | |
up_ft_tar_base[f_id][ | |
mask_cur.repeat(1, up_ft_tar_base[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_tar_base[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
sim = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0 | |
loss_con = loss_con + w_content / (1 + 4 * sim.mean()) | |
# for replace content | |
loss_edit = 0 | |
for f_id in range(len(up_ft_tar_replace)): | |
mask_cur = mask_base_cur | |
# NOTE: Uncomment to get Global time&freq | |
# up_ft_cur_vec = up_ft_cur[f_id][mask_cur.repeat(1,up_ft_cur[f_id].shape[1],1,1)].view(up_ft_cur[f_id].shape[1], -1).permute(1,0) | |
# up_ft_tar_vec = up_ft_tar_replace[f_id][mask_replace_cur.repeat(1,up_ft_tar_replace[f_id].shape[1],1,1)].view(up_ft_tar_replace[f_id].shape[1], -1).permute(1,0) | |
# sim_all=((cos(up_ft_cur_vec.mean(0,keepdim=True), up_ft_tar_vec.mean(0,keepdim=True))+1.)/2.) | |
# Get a vec along time axis (global time) | |
up_ft_cur_vec_masked_sum = torch.sum( | |
up_ft_cur[f_id] * mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1), | |
dim=-2, | |
) | |
up_ft_tar_vec_masked_sum = torch.sum( | |
up_ft_tar_replace[f_id] | |
* mask_replace_cur.repeat(1, up_ft_tar_replace[f_id].shape[1], 1, 1), | |
dim=-2, | |
) # feature for reference audio | |
mask_sum = torch.sum(mask_cur, dim=-2) | |
up_ft_cur_vec = ( | |
(up_ft_cur_vec_masked_sum / (mask_sum + 1e-8)) | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
up_ft_tar_vec = ( | |
(up_ft_tar_vec_masked_sum / (mask_sum + 1e-8)) | |
.view(up_ft_tar_replace[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
sim_edit_contrast = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0 | |
# NOTE: begin to modify energy func | |
up_ft_cur_vec_base = ( | |
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)] | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
mask_cur_con = mask_cur | |
up_ft_tar_vec_base = ( | |
up_ft_tar_base[f_id][ | |
mask_cur_con.repeat(1, up_ft_tar_base[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_tar_base[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
# # local | |
# sim_edit_consist = (cos(up_ft_cur_vec_base, up_ft_tar_vec_base) + 1.0) / 2.0 | |
# global | |
            sim_edit_consist = (
                cos(
                    up_ft_cur_vec_base.mean(0, keepdim=True),
                    up_ft_tar_vec_base.mean(0, keepdim=True),
                )
                + 1.0
            ) / 2.0
# loss_edit = loss_edit - w_edit/(1+4*sim_all.mean()) + w_edit/(1+4*sim_edit_consist.mean()) # NOTE: decrease sim | |
# loss_edit = loss_edit + w_edit*sim_all.mean() | |
# loss_edit = loss_edit - 0.1*w_edit/(1+4*sim_all.mean()) + w_edit/(1+4*sim_edit_consist.mean()) | |
# loss_edit = loss_edit - 0.5*w_edit/(1+4*sim_all.mean()) # NOTE: this only affect local features not semantic | |
# loss_edit = loss_edit + 0.005*w_edit*sim_all.mean() + w_edit/(1+4*sim_edit_consist.mean()) # NOTE: local | |
loss_edit = ( | |
loss_edit | |
+ w_contrast * sim_edit_contrast.mean() | |
+ w_edit / (1 + 4 * sim_edit_consist.mean()) | |
) # NOTE: local | |
cond_grad_con = torch.autograd.grad( | |
loss_con * energy_scale, latent, retain_graph=True | |
)[0] | |
cond_grad_edit = torch.autograd.grad(loss_edit * energy_scale, latent)[0] | |
mask = F.interpolate( | |
mask_base_cur.float(), | |
(cond_grad_edit[-1].shape[-2], cond_grad_edit[-1].shape[-1]), | |
) | |
mask = (mask > 0).to(dtype=latent.dtype) | |
guidance = ( | |
cond_grad_con.detach() * (1 - mask) * 4e-2 | |
+ cond_grad_edit.detach() * mask * 4e-2 | |
) | |
self.feature_estimator.zero_grad() | |
return guidance | |
def guidance_style_transfer( | |
self, | |
mask_base_cur, | |
mask_replace_cur, | |
latent, | |
latent_noise_ref, | |
t, | |
up_ft_index, | |
up_scale, | |
context, | |
context_base, | |
energy_scale, | |
dict_mask, | |
w_edit, | |
w_content, | |
): | |
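        """Guidance for the "style_transfer" edit: only adds a content-
        preservation term outside `mask_base_cur`; no gradient guidance is
        applied inside the edited region by this method."""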
cos = nn.CosineSimilarity(dim=1) | |
with torch.no_grad(): | |
up_ft_tar_base = self.feature_estimator( | |
sample=latent_noise_ref.squeeze(2), | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context_base["embed"], | |
encoder_hidden_states=context_base["embed"] if self.use_cross_attn else None, | |
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_tar_base)): | |
up_ft_tar_base[f_id] = F.interpolate( | |
up_ft_tar_base[f_id], | |
( | |
up_ft_tar_base[-1].shape[-2] * up_scale, | |
up_ft_tar_base[-1].shape[-1] * up_scale, | |
), | |
) | |
latent = latent.detach().requires_grad_(True) | |
up_ft_cur = self.feature_estimator( | |
sample=latent, | |
timestep=t, | |
up_ft_indices=up_ft_index, | |
class_labels=None if self.use_cross_attn else context["embed"], | |
encoder_hidden_states=context["embed"] if self.use_cross_attn else None, | |
encoder_attention_mask=context["mask"] if self.use_cross_attn else None, | |
)["up_ft"] | |
for f_id in range(len(up_ft_cur)): | |
up_ft_cur[f_id] = F.interpolate( | |
up_ft_cur[f_id], | |
(up_ft_cur[-1].shape[-2] * up_scale, up_ft_cur[-1].shape[-1] * up_scale), | |
) | |
# for base content | |
loss_con = 0 | |
for f_id in range(len(up_ft_tar_base)): | |
mask_cur = (1 - mask_base_cur.float()) > 0.5 | |
mask_cur_p = mask_base_cur.float() > 0.5 | |
up_ft_cur_vec = ( | |
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)] | |
.view(up_ft_cur[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
up_ft_tar_vec = ( | |
up_ft_tar_base[f_id][ | |
mask_cur.repeat(1, up_ft_tar_base[f_id].shape[1], 1, 1) | |
] | |
.view(up_ft_tar_base[f_id].shape[1], -1) | |
.permute(1, 0) | |
) | |
sim = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0 | |
loss_con = loss_con + w_content / (1 + 4 * sim.mean()) | |
cond_grad_con = torch.autograd.grad( | |
loss_con * energy_scale, latent, retain_graph=True | |
)[0] | |
mask = F.interpolate( | |
mask_base_cur.float(), | |
(cond_grad_con[-1].shape[-2], cond_grad_con[-1].shape[-1]), | |
) | |
mask = (mask > 0).to(dtype=latent.dtype) | |
guidance = cond_grad_con.detach() * (1 - mask) * 4e-2 | |
self.feature_estimator.zero_grad() | |
return guidance | |
def _encode_text(self, text_input: str): | |
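        """Tokenize `text_input` and return the L2-normalized CLAP text
        embedding together with its boolean attention mask."""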
text_input = self.tokenizer( | |
[text_input], | |
padding="max_length", | |
max_length=self.tokenizer.model_max_length, # NOTE 77 | |
truncation=True, | |
return_tensors="pt", | |
) | |
        input_ids = text_input.input_ids.to(self._device)
        attn_mask = text_input.attention_mask.to(self._device)
        text_embeddings = self.text_encoder(input_ids, attention_mask=attn_mask)[0]
text_embeddings = F.normalize(text_embeddings, dim=-1) | |
boolean_attn_mask = (attn_mask == 1).to(self._device) | |
return {"embed": text_embeddings, "mask": boolean_attn_mask} | |
def _stack_text(self, text_input0, text_input1): | |
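        """Concatenate the unconditional and conditional text embeddings (and
        attention masks) along the batch dimension for classifier-free
        guidance."""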
text_embs0, text_mask0 = text_input0["embed"], text_input0["mask"] | |
text_embs1, text_mask1 = text_input1["embed"], text_input1["mask"] | |
out_embs = torch.cat( | |
[text_embs0.expand(*text_embs1.shape), text_embs1] | |
) | |
out_mask = torch.cat( | |
[text_mask0.expand(*text_mask1.shape), text_mask1] | |
) | |
return {"embed": out_embs, "mask": out_mask} | |