AudioMorphix / src /utils /inversion.py
JinhuaL1ANG's picture
v1
9a6dac6
import torch
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from typing import Optional, Union, Tuple, List, Callable, Dict
class DDIMInversion:
def __init__(self, model, NUM_DDIM_STEPS):
self.model = model
self.model.scheduler.set_timesteps(NUM_DDIM_STEPS)
self.NUM_DDIM_STEPS = NUM_DDIM_STEPS
self.prompt = None
def next_step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
prediction_type: str = "v_prediction",
):
timestep, next_timestep = (
min(
timestep
- self.scheduler.config.num_train_timesteps
// self.scheduler.num_inference_steps,
999,
),
timestep,
)
alpha_prod_t = (
self.scheduler.alphas_cumprod[timestep]
if timestep >= 0
else self.scheduler.final_alpha_cumprod
)
alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
beta_prod_t = 1 - alpha_prod_t
if prediction_type == "epsilon":
next_original_sample = (
sample - beta_prod_t**0.5 * model_output
) / alpha_prod_t**0.5
next_epsilon = model_output
elif prediction_type == "v_prediction":
next_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
next_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {prediction_type} must be one of `epsilon` or"
" `v_prediction`"
)
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * next_epsilon
next_sample = (
alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction
)
return next_sample
def get_noise_pred_single(
self, latents, t, cond_embeddings, cond_masks, iter_cur, save_kv=True, mode="drag"
):
boolean_cond_masks = (cond_masks == 1).to(cond_masks.device)
try:
noise_pred = self.model.unet(
latents,
t,
encoder_hidden_states=(
cond_embeddings if self.model.use_cross_attn else None
),
class_labels=None if self.model.use_cross_attn else cond_embeddings,
encoder_attention_mask=boolean_cond_masks if self.model.use_cross_attn else None,
iter_cur=iter_cur,
mode=mode,
save_kv=save_kv,
)["sample"]
except TypeError as e:
print(f"Warning: {e}")
noise_pred = self.model.unet(
latents,
t,
encoder_hidden_states=(
cond_embeddings if self.model.use_cross_attn else None
),
class_labels=None if self.model.use_cross_attn else cond_embeddings,
encoder_attention_mask=boolean_cond_masks if self.model.use_cross_attn else None,
)["sample"]
return noise_pred
@torch.no_grad()
def init_prompt(self, prompt: str, emb_im=None):
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = self.model.text_encoder.device
if not isinstance(prompt, list):
prompt = [prompt]
text_input = self.model.tokenizer(
prompt,
padding="max_length",
max_length=self.model.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
input_ids, attn_masks = text_input.input_ids.to(device), text_input.attention_mask.to(device)
text_embeddings = self.model.text_encoder(
input_ids, attention_mask=attn_masks,
)[0]
text_embeddings = F.normalize(text_embeddings, dim=-1)
if emb_im is not None:
raise NotImplementedError
self.text_embeddings = torch.cat([text_embeddings, emb_im], dim=1)
else:
self.text_embeddings = text_embeddings
self.text_masks = attn_masks
self.prompt = prompt
@torch.no_grad()
def ddim_loop(self, latent, save_kv=True, mode="drag", prediction_type="v_prediction"):
cond_embeddings = self.text_embeddings
cond_masks = self.text_masks
all_latent = [latent]
latent = latent.clone().detach()
print("DDIM Inversion:")
for i in tqdm(range(self.NUM_DDIM_STEPS)):
t = self.model.scheduler.timesteps[
len(self.model.scheduler.timesteps) - i - 1
]
noise_pred = self.get_noise_pred_single(
latent,
t,
cond_embeddings,
cond_masks,
iter_cur=len(self.model.scheduler.timesteps) - i - 1,
save_kv=save_kv,
mode=mode,
)
latent = self.next_step(noise_pred, t, latent, prediction_type=prediction_type)
all_latent.append(latent)
return all_latent
@property
def scheduler(self):
return self.model.scheduler
def invert(self, ddim_latents, prompt: str, emb_im=None, save_kv=True, mode="drag", prediction_type="v_prediction"):
self.init_prompt(prompt, emb_im=emb_im)
ddim_latents = self.ddim_loop(ddim_latents, save_kv=save_kv, mode=mode, prediction_type=prediction_type)
return ddim_latents