import copy
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from diffusers import DDIMScheduler
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from transformers import (
ClapTextModelWithProjection,
RobertaTokenizer,
RobertaTokenizerFast,
SpeechT5HifiGan,
)
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from typing import Optional, Union
from src.utils.factory import (
normalize_along_channel,
project_onto_tangent_space,
identity_projection,
)
log = logging.getLogger(__name__)
class Sampler(DiffusionPipeline):
r"""Core of Audio Morphix that samples latent with inversed latent and cross-attend trajactory."""
def __init__(
self,
vae: AutoencoderKL,
tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
text_encoder: ClapTextModelWithProjection,
unet: UNet2DConditionModel,
feature_estimator: UNet2DConditionModel,
scheduler: DDIMScheduler,
device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float32,
):
super().__init__()
self.register_modules(
vae=vae,
tokenizer=tokenizer,
text_encoder=text_encoder,
unet=unet,
feature_estimator=feature_estimator,
scheduler=scheduler,
)
        self.to(torch_device=device, torch_dtype=precision)
        self._device = device
        # `use_cross_attn` is read throughout but never set in this snippet;
        # assume text conditioning goes through cross-attention whenever the
        # UNet exposes a cross-attention dimension (a hypothetical default).
        self.use_cross_attn = getattr(unet.config, "cross_attention_dim", None) is not None
def edit(
self,
prompt: str,
mode,
edit_kwargs,
        prompt_replace: Optional[str] = None,
        negative_prompt: Optional[str] = None,
num_inference_steps: int = 50,
guidance_scale: Optional[float] = 7.5,
latent: Optional[torch.FloatTensor] = None,
start_time: int = 50,
energy_scale: float = 0,
SDE_strength: float = 0.4,
SDE_strength_un: float = 0,
latent_noise_ref: Optional[torch.FloatTensor] = None,
bg_to_fg_ratio: float = 0.5,
disable_tangent_proj: bool = False,
alg: str = "D+",
):
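        """Run guided DDIM sampling for a single editing operation.

        `mode` selects the energy function ("move", "drag", "landmark",
        "appearance", "paste", "mix", "remove", "style_transfer");
        `edit_kwargs` carries its masks and weights, `latent` is the inverted
        starting latent, and `latent_noise_ref` the inversion trajectory of
        the reference(s).

        Minimal sketch (hypothetical values; assumes all components are
        already loaded):

            sampler = Sampler(vae, tokenizer, text_encoder, unet,
                              feature_estimator, scheduler)
            edited = sampler.edit(
                prompt="a dog barking",
                mode="paste",
                edit_kwargs=edit_kwargs,   # masks, up_ft_index, weights, ...
                latent=latent_inv,         # inverted latent of the base audio
                latent_noise_ref=traj,     # stacked inversion trajectory
                energy_scale=0.5,
            )
        """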
log.info("Start Editing:")
self.alg = alg
# Select projection function
if disable_tangent_proj:
log.info("Use guidance directly.")
self.proj_fn = identity_projection
else:
log.info("Project guidance onto tangent space.")
self.proj_fn = project_onto_tangent_space
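        # Assumption: `project_onto_tangent_space` removes the component of the
        # guidance parallel to the current latent z_t (keeping the update on
        # the same noise shell), while `identity_projection` passes the
        # guidance through unchanged; both come from src.utils.factory.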
# Generate source text embedding
text_input = self._encode_text(prompt)
if prompt_replace is not None:
text_replace = self._encode_text(prompt_replace)
else:
text_replace = text_input
# Generate null text embedding for CFG
prompt_uncond = "" if negative_prompt is None else negative_prompt
text_uncond = self._encode_text(prompt_uncond)
# Text condition for the current trajectory
context = self._stack_text(text_uncond, text_input)
self.scheduler.set_timesteps(num_inference_steps)
dict_mask = edit_kwargs["dict_mask"] if "dict_mask" in edit_kwargs else None
        time_scale = num_inference_steps / 50  # rescale the step-indexed schedules below (tuned for 50 steps)
for i, t in enumerate(tqdm(self.scheduler.timesteps[-start_time:])):
next_timestep = min(
t
- self.scheduler.config.num_train_timesteps
// self.scheduler.num_inference_steps,
999,
)
next_timestep = max(next_timestep, 0)
if energy_scale == 0 or alg == "D":
repeat = 1
elif int(20*time_scale) < i < int(30*time_scale) and i % 2 == 0:
repeat = 3
else:
repeat = 1
for ri in range(repeat):
latent_in = torch.cat([latent.unsqueeze(2)] * 2)
with torch.no_grad():
noise_pred = self.unet(
latent_in,
t,
encoder_hidden_states=context["embed"] if self.use_cross_attn else None,
class_labels=None if self.use_cross_attn else context["embed"],
encoder_attention_mask=context["mask"] if self.use_cross_attn else None,
mask=dict_mask,
save_kv=False,
mode=mode,
iter_cur=i,
)["sample"].squeeze(2)
noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_prediction_text - noise_pred_uncond
)
                if (
                    energy_scale != 0
                    and i < int(30 * time_scale)
                    and (alg == "D" or i % 2 == 0 or i < int(10 * time_scale))
                ):
# editing guidance
noise_pred_org = noise_pred
if mode == "move":
guidance = self.guidance_move(
latent=latent,
latent_noise_ref=latent_noise_ref[-(i + 1)],
t=t,
context=text_input,
context_base=text_input,
energy_scale=energy_scale,
**edit_kwargs
)
elif mode == "drag":
guidance = self.guidance_drag(
latent=latent,
latent_noise_ref=latent_noise_ref[-(i + 1)],
t=t,
context=text_input,
context_base=text_input,
energy_scale=energy_scale,
**edit_kwargs
)
elif mode == "landmark":
guidance = self.guidance_landmark(
latent=latent,
latent_noise_ref=latent_noise_ref[-(i + 1)],
t=t,
context=text_input,
context_base=text_input,
energy_scale=energy_scale,
**edit_kwargs
)
elif mode == "appearance":
guidance = self.guidance_appearance(
latent=latent,
latent_noise_ref=latent_noise_ref[-(i + 1)],
t=t,
context=text_input,
context_base=text_input,
energy_scale=energy_scale,
**edit_kwargs
)
elif mode == "paste":
guidance = self.guidance_paste(
latent=latent,
latent_noise_ref=latent_noise_ref[-(i + 1)],
t=t,
context=text_input,
context_base=text_input,
context_replace=text_replace,
energy_scale=energy_scale,
**edit_kwargs
)
elif mode == "mix":
guidance = self.guidance_mix(
latent=latent,
latent_noise_ref=latent_noise_ref[-(i + 1)],
t=t,
context=text_input,
context_base=text_input,
context_replace=text_replace,
energy_scale=energy_scale,
**edit_kwargs
)
elif mode == "remove":
guidance = self.guidance_remove(
latent=latent,
latent_noise_ref=latent_noise_ref[-(i + 1)],
t=t,
context=text_input,
context_base=text_input,
context_replace=text_replace,
energy_scale=energy_scale,
**edit_kwargs
)
elif mode == "style_transfer":
guidance = self.guidance_style_transfer(
latent=latent,
latent_noise_ref=latent_noise_ref[-(i + 1)],
t=t,
context=text_input,
context_base=text_input,
energy_scale=energy_scale,
**edit_kwargs
)
# Project guidance onto z_t
guidance = self.proj_fn(guidance, latent)
noise_pred = noise_pred + guidance # NOTE: weighted sum?
else:
noise_pred_org = None
# zt->zt-1
prev_timestep = (
t
- self.scheduler.config.num_train_timesteps
// self.scheduler.num_inference_steps
)
alpha_prod_t = self.scheduler.alphas_cumprod[t]
alpha_prod_t_prev = (
self.scheduler.alphas_cumprod[prev_timestep]
if prev_timestep >= 0
else self.scheduler.final_alpha_cumprod
)
beta_prod_t = 1 - alpha_prod_t
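                # DDIM update (Song et al., 2021):
                #   x0_hat  = (z_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
                #   z_{t-1} = sqrt(a_{t-1}) * x0_hat
                #             + sqrt(1 - a_{t-1} - sigma_t^2) * eps
                #             + sigma_t * noise
                # where sigma_t = eta * sqrt(var_t); eta is non-zero only in
                # the regional-SDE window handled below.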
if self.scheduler.config.prediction_type == "epsilon":
pred_original_sample = (
latent - beta_prod_t ** (0.5) * noise_pred
) / alpha_prod_t ** (0.5)
pred_epsilon = noise_pred
pred_epsilon_org = noise_pred_org
elif self.scheduler.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * latent - (beta_prod_t**0.5) * noise_pred
pred_epsilon = (alpha_prod_t**0.5) * noise_pred + (beta_prod_t**0.5) * latent
if noise_pred_org is not None:
pred_epsilon_org = (alpha_prod_t**0.5) * noise_pred_org + (beta_prod_t**0.5) * latent
else:
pred_epsilon_org = None
if int(10*time_scale) < i < int(20*time_scale):
eta, eta_rd = SDE_strength_un, SDE_strength
else:
eta, eta_rd = 0.0, 0.0
variance = self.scheduler._get_variance(t, prev_timestep)
std_dev_t = eta * variance ** (0.5)
std_dev_t_rd = eta_rd * variance ** (0.5)
if noise_pred_org is not None:
pred_sample_direction_rd = (
1 - alpha_prod_t_prev - std_dev_t_rd**2
) ** (0.5) * pred_epsilon_org
pred_sample_direction = (
1 - alpha_prod_t_prev - std_dev_t**2
) ** (0.5) * pred_epsilon_org
else:
pred_sample_direction_rd = (
1 - alpha_prod_t_prev - std_dev_t_rd**2
) ** (0.5) * pred_epsilon
pred_sample_direction = (
1 - alpha_prod_t_prev - std_dev_t**2
) ** (0.5) * pred_epsilon
latent_prev = (
alpha_prod_t_prev ** (0.5) * pred_original_sample
+ pred_sample_direction
)
latent_prev_rd = (
alpha_prod_t_prev ** (0.5) * pred_original_sample
+ pred_sample_direction_rd
)
# Regional SDE
if (eta_rd > 0 or eta > 0) and alg == "D+":
variance_noise = torch.randn_like(latent_prev)
variance_rd = std_dev_t_rd * variance_noise
variance = std_dev_t * variance_noise
if mode == "move":
mask = (
F.interpolate(
edit_kwargs["mask_x0"][None, None],
(
edit_kwargs["mask_cur"].shape[-2],
edit_kwargs["mask_cur"].shape[-1],
),
)
> 0.5
).float()
mask = ((edit_kwargs["mask_cur"] + mask) > 0.5).float()
mask = (
F.interpolate(
mask, (latent_prev.shape[-2], latent_prev.shape[-1])
)
> 0.5
).to(dtype=latent.dtype)
elif mode == "drag":
mask = F.interpolate(
edit_kwargs["mask_x0"][None, None],
(latent_prev[-1].shape[-2], latent_prev[-1].shape[-1]),
)
mask = (mask > 0).to(dtype=latent.dtype)
elif mode == "landmark":
mask = torch.ones_like(latent_prev)
elif (
mode == "appearance"
or mode == "paste"
or mode == "remove"
or mode == "mix"
):
mask = F.interpolate(
edit_kwargs["mask_base_cur"].float(),
(latent_prev[-1].shape[-2], latent_prev[-1].shape[-1]),
)
mask = (mask > 0).to(dtype=latent.dtype)
latent_prev = (latent_prev + variance) * (1 - mask) + (
latent_prev_rd + variance_rd
) * mask
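                # When a timestep is repeated, re-invert z_{t-1} back to z_t
                # with one DDIM inversion step so the guidance can be applied
                # again at the same noise level.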
if repeat > 1:
with torch.no_grad():
alpha_prod_t = self.scheduler.alphas_cumprod[next_timestep]
alpha_prod_t_next = self.scheduler.alphas_cumprod[t]
beta_prod_t = 1 - alpha_prod_t
model_output = self.unet(
latent_prev.unsqueeze(2),
next_timestep,
class_labels=None if self.use_cross_attn else text_input["embed"],
encoder_hidden_states=text_input["embed"] if self.use_cross_attn else None,
encoder_attention_mask=text_input["mask"] if self.use_cross_attn else None,
mask=dict_mask,
save_kv=False,
mode=mode,
iter_cur=-2,
)["sample"].squeeze(2)
# Different scheduling options
if self.scheduler.config.prediction_type == "epsilon":
next_original_sample = (
latent_prev - beta_prod_t**0.5 * model_output
) / alpha_prod_t**0.5
pred_epsilon = model_output
elif self.scheduler.config.prediction_type == "v_prediction":
next_original_sample = (alpha_prod_t**0.5) * latent_prev - (beta_prod_t**0.5) * model_output
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * latent_prev
next_sample_direction = (
1 - alpha_prod_t_next
) ** 0.5 * pred_epsilon
latent = (
alpha_prod_t_next**0.5 * next_original_sample
+ next_sample_direction
)
latent = latent_prev
return latent
def guidance_move(
self,
mask_x0,
mask_x0_ref,
mask_x0_keep,
mask_tar,
mask_cur,
mask_keep,
mask_other,
mask_overlap,
mask_non_overlap,
latent,
latent_noise_ref,
t,
up_ft_index,
context,
context_base,
up_scale,
resize_scale_x,
resize_scale_y,
energy_scale,
w_edit,
w_content,
w_contrast,
w_inpaint,
):
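        """Gradient guidance for the "move" edit.

        The edit energy matches features inside `mask_cur` to the (resized)
        reference features inside `mask_tar`; the content energy preserves
        features over `mask_other` and adds contrast/inpainting terms over
        the vacated non-overlap region. Returns d(energy)/d(latent), scaled
        and masked per region.
        """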
cos = nn.CosineSimilarity(dim=1)
loss_scale = [0.5, 0.5]
with torch.no_grad():
up_ft_tar = self.feature_estimator(
sample=latent_noise_ref.squeeze(2),
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context_base["embed"],
encoder_hidden_states=(
context_base["embed"] if self.use_cross_attn else None
),
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None,
)["up_ft"]
up_ft_tar_org = copy.deepcopy(up_ft_tar)
for f_id in range(len(up_ft_tar_org)):
up_ft_tar_org[f_id] = F.interpolate(
up_ft_tar_org[f_id],
(
up_ft_tar_org[-1].shape[-2] * up_scale,
up_ft_tar_org[-1].shape[-1] * up_scale,
),
)
latent = latent.detach().requires_grad_(True)
for f_id in range(len(up_ft_tar)):
up_ft_tar[f_id] = F.interpolate(
up_ft_tar[f_id],
(
int(up_ft_tar[-1].shape[-2] * resize_scale_y * up_scale),
int(up_ft_tar[-1].shape[-1] * resize_scale_x * up_scale),
),
)
up_ft_cur = self.feature_estimator(
sample=latent,
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context["embed"],
encoder_hidden_states=(
context["embed"] if self.use_cross_attn else None
),
encoder_attention_mask=context["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_cur)):
up_ft_cur[f_id] = F.interpolate(
up_ft_cur[f_id],
(
up_ft_cur[-1].shape[-2] * up_scale,
up_ft_cur[-1].shape[-1] * up_scale,
),
)
# editing energy
loss_edit = 0
for f_id in range(len(up_ft_tar)):
up_ft_cur_vec = (
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)]
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
) # (x1, D)
up_ft_tar_vec = (
up_ft_tar[f_id][mask_tar.repeat(1, up_ft_tar[f_id].shape[1], 1, 1)]
.view(up_ft_tar[f_id].shape[1], -1)
.permute(1, 0)
) # (x2, D)
            # Compute cosine similarity between `up_ft_cur_vec` and `up_ft_tar_vec`
# up_ft_cur_vec_norm = up_ft_cur_vec / (up_ft_cur_vec.norm(dim=1, keepdim=True) + 1e-8)
# up_ft_tar_vec_norm = up_ft_tar_vec / (up_ft_tar_vec.norm(dim=1, keepdim=True) + 1e-8)
# sim = torch.mm(up_ft_cur_vec_norm, up_ft_tar_vec_norm.T)
sim = cos(up_ft_cur_vec, up_ft_tar_vec)
# sim_global = cos(
# up_ft_cur_vec.mean(0, keepdim=True), up_ft_tar_vec.mean(0, keepdim=True)
# )
loss_edit = loss_edit + (w_edit / (1 + 4 * sim.mean())) * loss_scale[f_id]
# Content energy
loss_con = 0
for f_id in range(len(up_ft_tar_org)):
sim_other = cos(up_ft_tar_org[f_id], up_ft_cur[f_id])[0][mask_other[0, 0]]
loss_con = (
loss_con + w_content / (1 + 4 * sim_other.mean()) * loss_scale[f_id]
)
if mask_x0_ref is not None:
mask_x0_ref_cur = (
F.interpolate(
mask_x0_ref[None, None],
(mask_other.shape[-2], mask_other.shape[-1]),
)
> 0.5
)
else:
mask_x0_ref_cur = mask_other
for f_id in range(len(up_ft_tar)):
# # Global
# up_ft_cur_non_overlap_contrast = (
# up_ft_cur[f_id][
# mask_non_overlap.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)
# ]
# .view(up_ft_cur[f_id].shape[1], -1)
# .permute(1, 0)
# )
# up_ft_tar_non_overlap_contrast = (
# up_ft_tar_org[f_id][
# mask_non_overlap.repeat(1, up_ft_tar_org[f_id].shape[1], 1, 1)
# ]
# .view(up_ft_tar_org[f_id].shape[1], -1)
# .permute(1, 0)
# )
            # Frequency-pooled similarity (masked mean over the frequency axis)
up_ft_cur_non_overlap_sum = torch.sum(
up_ft_cur[f_id] * mask_non_overlap.repeat(1, up_ft_cur[f_id].shape[1], 1, 1),
dim=-2,
)
up_ft_tar_non_overlap_sum = torch.sum(
up_ft_tar_org[f_id]
* mask_non_overlap.repeat(1, up_ft_tar_org[f_id].shape[1], 1, 1),
dim=-2,
) # feature for reference audio
mask_sum = torch.sum(mask_non_overlap, dim=-2)
up_ft_cur_non_overlap_contrast = (
(up_ft_cur_non_overlap_sum / (mask_sum + 1e-8))
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
            )  # avoid division by zero
up_ft_tar_non_overlap_contrast = (
(up_ft_tar_non_overlap_sum / (mask_sum + 1e-8))
.view(up_ft_tar_org[f_id].shape[1], -1)
.permute(1, 0)
)
sim_non_overlap_contrast = (
cos(up_ft_cur_non_overlap_contrast, up_ft_tar_non_overlap_contrast) + 1.0
) / 2.0
loss_con = loss_con + w_contrast * sim_non_overlap_contrast.mean() * loss_scale[f_id]
up_ft_cur_non_overlap_inpaint = (
up_ft_cur[f_id][
mask_non_overlap.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)
]
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
.mean(0, keepdim=True)
)
up_ft_tar_non_overlap_inpaint = (
up_ft_tar_org[f_id][
mask_x0_ref_cur.repeat(1, up_ft_tar_org[f_id].shape[1], 1, 1)
]
.view(up_ft_tar_org[f_id].shape[1], -1)
.permute(1, 0)
.mean(0, keepdim=True)
)
sim_inpaint = (
cos(up_ft_cur_non_overlap_inpaint, up_ft_tar_non_overlap_inpaint) + 1.0
) / 2.0
loss_con = loss_con + w_inpaint / (1 + 4 * sim_inpaint.mean())
cond_grad_edit = torch.autograd.grad(
loss_edit * energy_scale, latent, retain_graph=True
)[0]
cond_grad_con = torch.autograd.grad(loss_con * energy_scale, latent)[0]
mask_edit1 = (mask_cur > 0.5).float()
mask_edit1 = (
F.interpolate(mask_edit1, (latent.shape[-2], latent.shape[-1])) > 0
).to(dtype=latent.dtype)
# mask_edit2 = ((mask_keep + mask_non_overlap.float()) > 0.5).float()
# mask_edit2 = (
# F.interpolate(mask_edit2, (latent.shape[-2], latent.shape[-1])) > 0
# ).to(dtype=latent.dtype)
        mask_edit2 = 1 - mask_edit1
guidance = (
cond_grad_edit.detach() * 8e-2 * mask_edit1
+ cond_grad_con.detach() * 8e-2 * mask_edit2
)
self.feature_estimator.zero_grad()
return guidance
def guidance_drag(
self,
mask_x0,
mask_cur,
mask_tar,
mask_other,
latent,
latent_noise_ref,
t,
up_ft_index,
up_scale,
context,
context_base,
energy_scale,
w_edit,
w_inpaint,
w_content,
dict_mask=None,
):
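        """Gradient guidance for the "drag" edit.

        For each handle/target mask pair, pulls current features at
        `mask_cur` toward reference features at `mask_tar`, decorrelates the
        vacated non-overlap region (inpainting term), and preserves content
        over `mask_other`.
        """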
cos = nn.CosineSimilarity(dim=1)
with torch.no_grad():
up_ft_tar = self.feature_estimator(
sample=latent_noise_ref.squeeze(2),
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context_base["embed"],
encoder_hidden_states=(
context_base["embed"] if self.use_cross_attn else None
),
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_tar)):
up_ft_tar[f_id] = F.interpolate(
up_ft_tar[f_id],
(
up_ft_tar[-1].shape[-2] * up_scale,
up_ft_tar[-1].shape[-1] * up_scale,
),
)
latent = latent.detach().requires_grad_(True)
up_ft_cur = self.feature_estimator(
sample=latent,
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context["embed"],
encoder_hidden_states=(
context["embed"] if self.use_cross_attn else None
),
encoder_attention_mask=context["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_cur)):
up_ft_cur[f_id] = F.interpolate(
up_ft_cur[f_id],
(
up_ft_cur[-1].shape[-2] * up_scale,
up_ft_cur[-1].shape[-1] * up_scale,
),
)
# moving loss
loss_edit = 0
for f_id in range(len(up_ft_tar)):
for mask_cur_i, mask_tar_i in zip(mask_cur, mask_tar):
up_ft_cur_vec = (
up_ft_cur[f_id][
mask_cur_i.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)
]
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
)
up_ft_tar_vec = (
up_ft_tar[f_id][
mask_tar_i.repeat(1, up_ft_tar[f_id].shape[1], 1, 1)
]
.view(up_ft_tar[f_id].shape[1], -1)
.permute(1, 0)
)
sim = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0
loss_edit = loss_edit + w_edit / (1 + 4 * sim.mean())
mask_overlap = ((mask_cur_i.float() + mask_tar_i.float()) > 1.5).float()
mask_non_overlap = (mask_tar_i.float() - mask_overlap) > 0.5
up_ft_cur_non_overlap = (
up_ft_cur[f_id][
mask_non_overlap.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)
]
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
)
up_ft_tar_non_overlap = (
up_ft_tar[f_id][
mask_non_overlap.repeat(1, up_ft_tar[f_id].shape[1], 1, 1)
]
.view(up_ft_tar[f_id].shape[1], -1)
.permute(1, 0)
)
sim_non_overlap = (
cos(up_ft_cur_non_overlap, up_ft_tar_non_overlap) + 1.0
) / 2.0
loss_edit = loss_edit + w_inpaint * sim_non_overlap.mean()
# consistency loss
loss_con = 0
for f_id in range(len(up_ft_tar)):
sim_other = (
cos(up_ft_tar[f_id], up_ft_cur[f_id])[0][mask_other[0, 0]] + 1.0
) / 2.0
loss_con = loss_con + w_content / (1 + 4 * sim_other.mean())
loss_edit = loss_edit / len(up_ft_cur) / len(mask_cur)
loss_con = loss_con / len(up_ft_cur)
cond_grad_edit = torch.autograd.grad(
loss_edit * energy_scale, latent, retain_graph=True
)[0]
cond_grad_con = torch.autograd.grad(loss_con * energy_scale, latent)[0]
mask = F.interpolate(
mask_x0[None, None],
(cond_grad_edit[-1].shape[-2], cond_grad_edit[-1].shape[-1]),
)
mask = (mask > 0).to(dtype=latent.dtype)
guidance = (
cond_grad_edit.detach() * 4e-2 * mask
+ cond_grad_con.detach() * 4e-2 * (1 - mask)
)
self.feature_estimator.zero_grad()
return guidance
def guidance_landmark(
self,
mask_cur,
mask_tar,
latent,
latent_noise_ref,
t,
up_ft_index,
up_scale,
context,
context_base,
energy_scale,
w_edit,
w_inpaint,
):
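        """Gradient guidance for the "landmark" edit: pulls current features
        at each `mask_cur` toward the reference features at the matching
        `mask_tar` (edit energy only, no content-preservation term).
        """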
cos = nn.CosineSimilarity(dim=1)
with torch.no_grad():
up_ft_tar = self.feature_estimator(
sample=latent_noise_ref.squeeze(2),
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context_base["embed"],
encoder_hidden_states=(
context_base["embed"] if self.use_cross_attn else None
),
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_tar)):
up_ft_tar[f_id] = F.interpolate(
up_ft_tar[f_id],
(
up_ft_tar[-1].shape[-2] * up_scale,
up_ft_tar[-1].shape[-1] * up_scale,
),
)
latent = latent.detach().requires_grad_(True)
up_ft_cur = self.feature_estimator(
sample=latent,
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context["embed"],
encoder_hidden_states=(
context["embed"] if self.use_cross_attn else None
),
encoder_attention_mask=context["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_cur)):
up_ft_cur[f_id] = F.interpolate(
up_ft_cur[f_id],
(
up_ft_cur[-1].shape[-2] * up_scale,
up_ft_cur[-1].shape[-1] * up_scale,
),
)
# moving loss
loss_edit = 0
for f_id in range(len(up_ft_tar)):
for mask_cur_i, mask_tar_i in zip(mask_cur, mask_tar):
up_ft_cur_vec = (
up_ft_cur[f_id][
mask_cur_i.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)
]
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
)
up_ft_tar_vec = (
up_ft_tar[f_id][
mask_tar_i.repeat(1, up_ft_tar[f_id].shape[1], 1, 1)
]
.view(up_ft_tar[f_id].shape[1], -1)
.permute(1, 0)
)
sim = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0
loss_edit = loss_edit + w_edit / (1 + 4 * sim.mean())
loss_edit = loss_edit / len(up_ft_cur) / len(mask_cur)
cond_grad_edit = torch.autograd.grad(
loss_edit * energy_scale, latent, retain_graph=True
)[0]
guidance = cond_grad_edit.detach() * 4e-2
self.feature_estimator.zero_grad()
return guidance
def guidance_appearance(
self,
mask_base_cur,
mask_replace_cur,
latent,
latent_noise_ref,
t,
up_ft_index,
up_scale,
context,
context_base,
context_replace,
energy_scale,
dict_mask,
w_edit,
w_content,
):
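        """Gradient guidance for the "appearance" edit: preserves features
        outside `mask_base_cur` and pulls the mean feature inside it toward
        the mean reference feature inside `mask_replace_cur`.
        """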
cos = nn.CosineSimilarity(dim=1)
with torch.no_grad():
up_ft_tar_base = self.feature_estimator(
sample=latent_noise_ref.squeeze(2)[::2],
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context_base["embed"],
encoder_hidden_states=(
context_base["embed"] if self.use_cross_attn else None
),
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_tar_base)):
up_ft_tar_base[f_id] = F.interpolate(
up_ft_tar_base[f_id],
(
up_ft_tar_base[-1].shape[-2] * up_scale,
up_ft_tar_base[-1].shape[-1] * up_scale,
),
)
with torch.no_grad():
up_ft_tar_replace = self.feature_estimator(
sample=latent_noise_ref.squeeze(2)[1::2],
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context_replace["embed"],
encoder_hidden_states=(
context_replace["embed"] if self.use_cross_attn else None
),
encoder_attention_mask=context_replace["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_tar_replace)):
up_ft_tar_replace[f_id] = F.interpolate(
up_ft_tar_replace[f_id],
(
up_ft_tar_replace[-1].shape[-2] * up_scale,
up_ft_tar_replace[-1].shape[-1] * up_scale,
),
)
latent = latent.detach().requires_grad_(True)
up_ft_cur = self.feature_estimator(
sample=latent,
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context["embed"],
encoder_hidden_states=(
context["embed"] if self.use_cross_attn else None
),
encoder_attention_mask=context["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_cur)):
up_ft_cur[f_id] = F.interpolate(
up_ft_cur[f_id],
(
up_ft_cur[-1].shape[-2] * up_scale,
up_ft_cur[-1].shape[-1] * up_scale,
),
)
# for base content
loss_con = 0
for f_id in range(len(up_ft_tar_base)):
mask_cur = (1 - mask_base_cur.float()) > 0.5
up_ft_cur_vec = (
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)]
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
)
up_ft_tar_vec = (
up_ft_tar_base[f_id][
mask_cur.repeat(1, up_ft_tar_base[f_id].shape[1], 1, 1)
]
.view(up_ft_tar_base[f_id].shape[1], -1)
.permute(1, 0)
)
sim = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0
loss_con = loss_con + w_content / (1 + 4 * sim.mean())
# for replace content
loss_edit = 0
for f_id in range(len(up_ft_tar_replace)):
mask_cur = mask_base_cur
mask_tar = mask_replace_cur
up_ft_cur_vec = (
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)]
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
.mean(0, keepdim=True)
)
up_ft_tar_vec = (
up_ft_tar_replace[f_id][
mask_tar.repeat(1, up_ft_tar_replace[f_id].shape[1], 1, 1)
]
.view(up_ft_tar_replace[f_id].shape[1], -1)
.permute(1, 0)
.mean(0, keepdim=True)
)
sim_all = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0
loss_edit = loss_edit + w_edit / (1 + 4 * sim_all.mean())
cond_grad_con = torch.autograd.grad(
loss_con * energy_scale, latent, retain_graph=True
)[0]
cond_grad_edit = torch.autograd.grad(loss_edit * energy_scale, latent)[0]
mask = F.interpolate(
mask_base_cur.float(),
(cond_grad_edit[-1].shape[-2], cond_grad_edit[-1].shape[-1]),
)
mask = (mask > 0).to(dtype=latent.dtype)
guidance = (
cond_grad_con.detach() * (1 - mask) * 4e-2
+ cond_grad_edit.detach() * mask * 4e-2
)
self.feature_estimator.zero_grad()
return guidance
def guidance_mix(
self,
mask_base_cur,
mask_replace_cur,
latent,
latent_noise_ref,
t,
up_ft_index,
up_scale,
context,
context_base,
context_replace,
energy_scale,
dict_mask,
w_edit,
w_content,
):
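        """Gradient guidance for the "mix" edit: preserves features outside
        `mask_base_cur`; inside it, the edit energy rewards similarity to
        both the base features (weight `w_content`) and the replacement
        features (weight `w_edit`).
        """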
cos = nn.CosineSimilarity(dim=1)
# pcpt_loss_fn = LPIPS()
with torch.no_grad():
up_ft_tar_base = self.feature_estimator(
sample=latent_noise_ref.squeeze(2)[::2],
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context_base["embed"],
encoder_hidden_states=context_base["embed"] if self.use_cross_attn else None,
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_tar_base)):
up_ft_tar_base[f_id] = F.interpolate(
up_ft_tar_base[f_id],
(
up_ft_tar_base[-1].shape[-2] * up_scale,
up_ft_tar_base[-1].shape[-1] * up_scale,
),
)
with torch.no_grad():
up_ft_tar_replace = self.feature_estimator(
sample=latent_noise_ref.squeeze(2)[1::2],
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context_replace["embed"],
encoder_hidden_states=context_replace["embed"] if self.use_cross_attn else None,
encoder_attention_mask=context_replace["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_tar_replace)):
up_ft_tar_replace[f_id] = F.interpolate(
up_ft_tar_replace[f_id],
(
up_ft_tar_replace[-1].shape[-2] * up_scale,
up_ft_tar_replace[-1].shape[-1] * up_scale,
),
)
latent = latent.detach().requires_grad_(True)
up_ft_cur = self.feature_estimator(
sample=latent,
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context["embed"],
encoder_hidden_states=context["embed"] if self.use_cross_attn else None,
encoder_attention_mask=context["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_cur)):
up_ft_cur[f_id] = F.interpolate(
up_ft_cur[f_id],
(
up_ft_cur[-1].shape[-2] * up_scale,
up_ft_cur[-1].shape[-1] * up_scale,
),
)
# for base content
# loss_con = 0
# for f_id in range(len(up_ft_tar_base)):
# mask_cur = (1-mask_base_cur.float())>0.5
# up_ft_cur_vec = up_ft_cur[f_id][mask_cur.repeat(1,up_ft_cur[f_id].shape[1],1,1)].view(up_ft_cur[f_id].shape[1], -1).permute(1,0)
# up_ft_tar_vec = up_ft_tar_base[f_id][mask_cur.repeat(1,up_ft_tar_base[f_id].shape[1],1,1)].view(up_ft_tar_base[f_id].shape[1], -1).permute(1,0)
# sim = (cos(up_ft_cur_vec, up_ft_tar_vec)+1.)/2.
# loss_con = loss_con + w_content/(1+4*sim.mean())
loss_con = 0
for f_id in range(len(up_ft_tar_base)):
mask_cur = (1 - mask_base_cur.float()) > 0.5
up_ft_cur_vec = (
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)]
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
)
up_ft_tar_vec = (
up_ft_tar_base[f_id][
mask_cur.repeat(1, up_ft_tar_base[f_id].shape[1], 1, 1)
]
.view(up_ft_tar_base[f_id].shape[1], -1)
.permute(1, 0)
)
sim = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0
loss_con = loss_con + w_content / (1 + 4 * sim.mean())
# for replace content
loss_edit = 0
for f_id in range(len(up_ft_tar_replace)):
mask_cur = mask_base_cur
up_ft_cur_vec = (
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)]
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
)
up_ft_tar_vec = (
up_ft_tar_replace[f_id][
mask_replace_cur.repeat(1, up_ft_tar_replace[f_id].shape[1], 1, 1)
]
.view(up_ft_tar_replace[f_id].shape[1], -1)
.permute(1, 0)
)
up_ft_tar_vec_b = (
up_ft_tar_base[f_id][
mask_cur.repeat(1, up_ft_tar_base[f_id].shape[1], 1, 1)
]
.view(up_ft_tar_base[f_id].shape[1], -1)
.permute(1, 0)
)
# sim_all=0.8*((cos(up_ft_cur_vec, up_ft_tar_vec_b)+1.)/2.) + 0.2*((cos(up_ft_cur_vec, up_ft_tar_vec)+1.)/2.)
# # sim_all=0.7*((cos(up_ft_cur_vec, up_ft_tar_vec)+1.)/2.)+0.3*((cos(up_ft_cur_vec, up_ft_tar_vec_b)+1.)/2.)
# loss_edit=loss_edit+w_edit/(1+4*sim_all.mean())
            # NOTE: try a harmonic-mean-style combination of the two terms
sim_base = (cos(up_ft_cur_vec, up_ft_tar_vec_b) + 1.0) / 2.0
sim_tar = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0
loss_edit = (
loss_edit
+ w_content / (1 + 4 * sim_base.mean())
+ w_edit / (1 + 4 * sim_tar.mean())
            )  # NOTE: empirically, 0.7 is a good background-to-all ratio
cond_grad_con = torch.autograd.grad(
loss_con * energy_scale, latent, retain_graph=True
)[0]
cond_grad_edit = torch.autograd.grad(loss_edit * energy_scale, latent)[0]
mask = F.interpolate(
mask_base_cur.float(),
(cond_grad_edit[-1].shape[-2], cond_grad_edit[-1].shape[-1]),
)
mask = (mask > 0).to(dtype=latent.dtype)
guidance = (
cond_grad_con.detach() * (1 - mask) * 4e-2
+ cond_grad_edit.detach() * mask * 4e-2
)
self.feature_estimator.zero_grad()
return guidance
def guidance_paste(
self,
mask_base_cur,
mask_replace_cur,
latent,
latent_noise_ref,
t,
up_ft_index,
up_scale,
context,
context_base,
context_replace,
energy_scale,
dict_mask,
w_edit,
w_content,
):
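        """Gradient guidance for the "paste" edit: preserves features outside
        `mask_base_cur` and matches features inside it to the replacement
        reference features inside `mask_replace_cur`.
        """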
cos = nn.CosineSimilarity(dim=1)
with torch.no_grad():
up_ft_tar_base = self.feature_estimator(
sample=latent_noise_ref.squeeze(2)[::2],
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context_base["embed"],
encoder_hidden_states=context_base["embed"] if self.use_cross_attn else None,
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_tar_base)):
up_ft_tar_base[f_id] = F.interpolate(
up_ft_tar_base[f_id],
(
up_ft_tar_base[-1].shape[-2] * up_scale,
up_ft_tar_base[-1].shape[-1] * up_scale,
),
)
with torch.no_grad():
up_ft_tar_replace = self.feature_estimator(
sample=latent_noise_ref.squeeze(2)[1::2],
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context_replace["embed"],
encoder_hidden_states=context_replace["embed"] if self.use_cross_attn else None,
encoder_attention_mask=context_replace["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_tar_replace)):
up_ft_tar_replace[f_id] = F.interpolate(
up_ft_tar_replace[f_id],
(
up_ft_tar_replace[-1].shape[-2] * up_scale,
up_ft_tar_replace[-1].shape[-1] * up_scale,
),
)
latent = latent.detach().requires_grad_(True)
up_ft_cur = self.feature_estimator(
sample=latent,
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context["embed"],
encoder_hidden_states=context["embed"] if self.use_cross_attn else None,
encoder_attention_mask=context["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_cur)):
up_ft_cur[f_id] = F.interpolate(
up_ft_cur[f_id],
(
up_ft_cur[-1].shape[-2] * up_scale,
up_ft_cur[-1].shape[-1] * up_scale,
),
)
# for base content
loss_con = 0
for f_id in range(len(up_ft_tar_base)):
mask_cur = (1 - mask_base_cur.float()) > 0.5
up_ft_cur_vec = (
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)]
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
)
up_ft_tar_vec = (
up_ft_tar_base[f_id][
mask_cur.repeat(1, up_ft_tar_base[f_id].shape[1], 1, 1)
]
.view(up_ft_tar_base[f_id].shape[1], -1)
.permute(1, 0)
)
sim = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0
loss_con = loss_con + w_content / (1 + 4 * sim.mean())
# for replace content
loss_edit = 0
for f_id in range(len(up_ft_tar_replace)):
mask_cur = mask_base_cur
up_ft_cur_vec = (
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)]
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
)
up_ft_tar_vec = (
up_ft_tar_replace[f_id][
mask_replace_cur.repeat(1, up_ft_tar_replace[f_id].shape[1], 1, 1)
]
.view(up_ft_tar_replace[f_id].shape[1], -1)
.permute(1, 0)
)
sim_all = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0
loss_edit = loss_edit + w_edit / (1 + 4 * sim_all.mean())
cond_grad_con = torch.autograd.grad(
loss_con * energy_scale, latent, retain_graph=True
)[0]
cond_grad_edit = torch.autograd.grad(loss_edit * energy_scale, latent)[0]
mask = F.interpolate(
mask_base_cur.float(),
(cond_grad_edit[-1].shape[-2], cond_grad_edit[-1].shape[-1]),
)
mask = (mask > 0).to(dtype=latent.dtype)
guidance = (
cond_grad_con.detach() * (1 - mask) * 4e-2
+ cond_grad_edit.detach() * mask * 4e-2
)
self.feature_estimator.zero_grad()
return guidance
def guidance_remove(
self,
mask_base_cur,
mask_replace_cur,
latent,
latent_noise_ref,
t,
up_ft_index,
up_scale,
context,
context_base,
context_replace,
energy_scale,
dict_mask,
w_edit,
w_contrast,
w_content,
):
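        """Gradient guidance for the "remove" edit: preserves features outside
        `mask_base_cur`; inside it, pushes frequency-pooled features away from
        the reference to be removed (contrast term) while keeping a global
        consistency with the base recording.
        """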
cos = nn.CosineSimilarity(dim=1)
with torch.no_grad():
up_ft_tar_base = self.feature_estimator(
sample=latent_noise_ref.squeeze(2)[::2],
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context_base["embed"],
encoder_hidden_states=context_base["embed"] if self.use_cross_attn else None,
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_tar_base)):
up_ft_tar_base[f_id] = F.interpolate(
up_ft_tar_base[f_id],
(
up_ft_tar_base[-1].shape[-2] * up_scale,
up_ft_tar_base[-1].shape[-1] * up_scale,
),
)
# up_ft_tar_base[f_id] = normalize_along_channel(up_ft_tar_base[f_id]) # No improvement observed
with torch.no_grad():
up_ft_tar_replace = self.feature_estimator(
sample=latent_noise_ref.squeeze(2)[1::2],
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context_replace["embed"],
encoder_hidden_states=context_replace["embed"] if self.use_cross_attn else None,
encoder_attention_mask=context_replace["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_tar_replace)):
up_ft_tar_replace[f_id] = F.interpolate(
up_ft_tar_replace[f_id],
(
up_ft_tar_replace[-1].shape[-2] * up_scale,
up_ft_tar_replace[-1].shape[-1] * up_scale,
),
)
# up_ft_tar_replace[f_id] = normalize_along_channel(up_ft_tar_replace[f_id])
latent = latent.detach().requires_grad_(True)
up_ft_cur = self.feature_estimator(
sample=latent,
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context["embed"],
encoder_hidden_states=context["embed"] if self.use_cross_attn else None,
encoder_attention_mask=context["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_cur)):
up_ft_cur[f_id] = F.interpolate(
up_ft_cur[f_id],
(
up_ft_cur[-1].shape[-2] * up_scale,
up_ft_cur[-1].shape[-1] * up_scale,
),
)
# up_ft_cur[f_id] = normalize_along_channel(up_ft_cur[f_id])
# for base content
loss_con = 0
for f_id in range(len(up_ft_tar_base)):
mask_cur = (1 - mask_base_cur.float()) > 0.5
up_ft_cur_vec = (
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)]
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
)
up_ft_tar_vec = (
up_ft_tar_base[f_id][
mask_cur.repeat(1, up_ft_tar_base[f_id].shape[1], 1, 1)
]
.view(up_ft_tar_base[f_id].shape[1], -1)
.permute(1, 0)
)
sim = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0
loss_con = loss_con + w_content / (1 + 4 * sim.mean())
# for replace content
loss_edit = 0
for f_id in range(len(up_ft_tar_replace)):
mask_cur = mask_base_cur
            # NOTE: uncomment for a global time-frequency similarity
# up_ft_cur_vec = up_ft_cur[f_id][mask_cur.repeat(1,up_ft_cur[f_id].shape[1],1,1)].view(up_ft_cur[f_id].shape[1], -1).permute(1,0)
# up_ft_tar_vec = up_ft_tar_replace[f_id][mask_replace_cur.repeat(1,up_ft_tar_replace[f_id].shape[1],1,1)].view(up_ft_tar_replace[f_id].shape[1], -1).permute(1,0)
# sim_all=((cos(up_ft_cur_vec.mean(0,keepdim=True), up_ft_tar_vec.mean(0,keepdim=True))+1.)/2.)
            # Pool over the frequency axis to get one vector per time frame (global time)
up_ft_cur_vec_masked_sum = torch.sum(
up_ft_cur[f_id] * mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1),
dim=-2,
)
up_ft_tar_vec_masked_sum = torch.sum(
up_ft_tar_replace[f_id]
* mask_replace_cur.repeat(1, up_ft_tar_replace[f_id].shape[1], 1, 1),
dim=-2,
) # feature for reference audio
mask_sum = torch.sum(mask_cur, dim=-2)
up_ft_cur_vec = (
(up_ft_cur_vec_masked_sum / (mask_sum + 1e-8))
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
)
up_ft_tar_vec = (
(up_ft_tar_vec_masked_sum / (mask_sum + 1e-8))
.view(up_ft_tar_replace[f_id].shape[1], -1)
.permute(1, 0)
)
sim_edit_contrast = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0
# NOTE: begin to modify energy func
up_ft_cur_vec_base = (
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)]
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
)
mask_cur_con = mask_cur
up_ft_tar_vec_base = (
up_ft_tar_base[f_id][
mask_cur_con.repeat(1, up_ft_tar_base[f_id].shape[1], 1, 1)
]
.view(up_ft_tar_base[f_id].shape[1], -1)
.permute(1, 0)
)
# # local
# sim_edit_consist = (cos(up_ft_cur_vec_base, up_ft_tar_vec_base) + 1.0) / 2.0
# global
            sim_edit_consist = (
                cos(
                    up_ft_cur_vec_base.mean(0, keepdim=True),
                    up_ft_tar_vec_base.mean(0, keepdim=True),
                )
                + 1.0
            ) / 2.0
# loss_edit = loss_edit - w_edit/(1+4*sim_all.mean()) + w_edit/(1+4*sim_edit_consist.mean()) # NOTE: decrease sim
# loss_edit = loss_edit + w_edit*sim_all.mean()
# loss_edit = loss_edit - 0.1*w_edit/(1+4*sim_all.mean()) + w_edit/(1+4*sim_edit_consist.mean())
# loss_edit = loss_edit - 0.5*w_edit/(1+4*sim_all.mean()) # NOTE: this only affect local features not semantic
# loss_edit = loss_edit + 0.005*w_edit*sim_all.mean() + w_edit/(1+4*sim_edit_consist.mean()) # NOTE: local
loss_edit = (
loss_edit
+ w_contrast * sim_edit_contrast.mean()
+ w_edit / (1 + 4 * sim_edit_consist.mean())
) # NOTE: local
cond_grad_con = torch.autograd.grad(
loss_con * energy_scale, latent, retain_graph=True
)[0]
cond_grad_edit = torch.autograd.grad(loss_edit * energy_scale, latent)[0]
mask = F.interpolate(
mask_base_cur.float(),
(cond_grad_edit[-1].shape[-2], cond_grad_edit[-1].shape[-1]),
)
mask = (mask > 0).to(dtype=latent.dtype)
guidance = (
cond_grad_con.detach() * (1 - mask) * 4e-2
+ cond_grad_edit.detach() * mask * 4e-2
)
self.feature_estimator.zero_grad()
return guidance
def guidance_style_transfer(
self,
mask_base_cur,
mask_replace_cur,
latent,
latent_noise_ref,
t,
up_ft_index,
up_scale,
context,
context_base,
energy_scale,
dict_mask,
w_edit,
w_content,
):
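        """Gradient guidance for the "style_transfer" edit: a pure
        content-preservation energy outside `mask_base_cur`; the masked
        region itself receives no explicit guidance.
        """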
cos = nn.CosineSimilarity(dim=1)
with torch.no_grad():
up_ft_tar_base = self.feature_estimator(
sample=latent_noise_ref.squeeze(2),
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context_base["embed"],
encoder_hidden_states=context_base["embed"] if self.use_cross_attn else None,
encoder_attention_mask=context_base["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_tar_base)):
up_ft_tar_base[f_id] = F.interpolate(
up_ft_tar_base[f_id],
(
up_ft_tar_base[-1].shape[-2] * up_scale,
up_ft_tar_base[-1].shape[-1] * up_scale,
),
)
latent = latent.detach().requires_grad_(True)
up_ft_cur = self.feature_estimator(
sample=latent,
timestep=t,
up_ft_indices=up_ft_index,
class_labels=None if self.use_cross_attn else context["embed"],
encoder_hidden_states=context["embed"] if self.use_cross_attn else None,
encoder_attention_mask=context["mask"] if self.use_cross_attn else None,
)["up_ft"]
for f_id in range(len(up_ft_cur)):
up_ft_cur[f_id] = F.interpolate(
up_ft_cur[f_id],
(up_ft_cur[-1].shape[-2] * up_scale, up_ft_cur[-1].shape[-1] * up_scale),
)
# for base content
loss_con = 0
for f_id in range(len(up_ft_tar_base)):
mask_cur = (1 - mask_base_cur.float()) > 0.5
up_ft_cur_vec = (
up_ft_cur[f_id][mask_cur.repeat(1, up_ft_cur[f_id].shape[1], 1, 1)]
.view(up_ft_cur[f_id].shape[1], -1)
.permute(1, 0)
)
up_ft_tar_vec = (
up_ft_tar_base[f_id][
mask_cur.repeat(1, up_ft_tar_base[f_id].shape[1], 1, 1)
]
.view(up_ft_tar_base[f_id].shape[1], -1)
.permute(1, 0)
)
sim = (cos(up_ft_cur_vec, up_ft_tar_vec) + 1.0) / 2.0
loss_con = loss_con + w_content / (1 + 4 * sim.mean())
cond_grad_con = torch.autograd.grad(
loss_con * energy_scale, latent, retain_graph=True
)[0]
mask = F.interpolate(
mask_base_cur.float(),
(cond_grad_con[-1].shape[-2], cond_grad_con[-1].shape[-1]),
)
mask = (mask > 0).to(dtype=latent.dtype)
guidance = cond_grad_con.detach() * (1 - mask) * 4e-2
self.feature_estimator.zero_grad()
return guidance
def _encode_text(self, text_input: str):
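        """Tokenize and embed a prompt with the CLAP text encoder; returns
        L2-normalized embeddings and a boolean attention mask."""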
        tokenized = self.tokenizer(
            [text_input],
            padding="max_length",
            max_length=self.tokenizer.model_max_length,  # NOTE: 77
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokenized.input_ids.to(self._device)
        attn_mask = tokenized.attention_mask.to(self._device)
        text_embeddings = self.text_encoder(input_ids, attention_mask=attn_mask)[0]
        text_embeddings = F.normalize(text_embeddings, dim=-1)
        boolean_attn_mask = attn_mask == 1
        return {"embed": text_embeddings, "mask": boolean_attn_mask}
def _stack_text(self, text_input0, text_input1):
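        """Stack unconditional and conditional text embeddings along the
        batch dimension for classifier-free guidance."""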
text_embs0, text_mask0 = text_input0["embed"], text_input0["mask"]
text_embs1, text_mask1 = text_input1["embed"], text_input1["mask"]
out_embs = torch.cat(
[text_embs0.expand(*text_embs1.shape), text_embs1]
)
out_mask = torch.cat(
[text_mask0.expand(*text_mask1.shape), text_mask1]
)
return {"embed": out_embs, "mask": out_mask}
# class AudioLDMSampler(_Sampler):
# def __init__(
# self,
# vae: AutoencoderKL,
# tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
# text_encoder: ClapTextModelWithProjection,
# unet: UNet2DConditionModel,
# feature_estimator: UNet2DConditionModel,
# scheduler: DDIMScheduler,
# vocoder: SpeechT5HifiGan,
# device: torch.device = torch.device("cpu"),
# precision: torch.dtype = torch.float32,
# ):
# super().__init__()
# self.register_modules(
# vae=vae,
# tokenizer=tokenizer,
# text_encoder=text_encoder,
# unet=unet,
# estimator=feature_estimator,
# scheduler=scheduler,
# vocoder=vocoder,
# )
# self = self.to(device=device, dtype=precision)
# class TangoSampler(_Sampler):
# def __init__(
# self,
# vae: AutoencoderKL,
# text_encoder: ClapTextModelWithProjection,
# tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
# unet: UNet2DConditionModel,
# scheduler: KarrasDiffusionSchedulers,
# vocoder: SpeechT5HifiGan,
# ):
# super().__init__()
# self.register_modules(
# vae=vae,
# text_encoder=text_encoder,
# tokenizer=tokenizer,
# unet=unet,
# scheduler=scheduler,
# vocoder=vocoder,
# )
# self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)