Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import numpy as np | |
import torch.nn.functional as F | |
from tqdm import tqdm | |
from typing import Optional, Union, Tuple, List, Callable, Dict | |
class DDIMInversion: | |
def __init__(self, model, NUM_DDIM_STEPS): | |
self.model = model | |
self.model.scheduler.set_timesteps(NUM_DDIM_STEPS) | |
self.NUM_DDIM_STEPS = NUM_DDIM_STEPS | |
self.prompt = None | |
def next_step( | |
self, | |
model_output: Union[torch.FloatTensor, np.ndarray], | |
timestep: int, | |
sample: Union[torch.FloatTensor, np.ndarray], | |
prediction_type: str = "v_prediction", | |
): | |
timestep, next_timestep = ( | |
min( | |
timestep | |
- self.scheduler.config.num_train_timesteps | |
// self.scheduler.num_inference_steps, | |
999, | |
), | |
timestep, | |
) | |
alpha_prod_t = ( | |
self.scheduler.alphas_cumprod[timestep] | |
if timestep >= 0 | |
else self.scheduler.final_alpha_cumprod | |
) | |
alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] | |
beta_prod_t = 1 - alpha_prod_t | |
if prediction_type == "epsilon": | |
next_original_sample = ( | |
sample - beta_prod_t**0.5 * model_output | |
) / alpha_prod_t**0.5 | |
next_epsilon = model_output | |
elif prediction_type == "v_prediction": | |
next_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output | |
next_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample | |
else: | |
raise ValueError( | |
f"prediction_type given as {prediction_type} must be one of `epsilon` or" | |
" `v_prediction`" | |
) | |
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * next_epsilon | |
next_sample = ( | |
alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction | |
) | |
return next_sample | |
def get_noise_pred_single( | |
self, latents, t, cond_embeddings, cond_masks, iter_cur, save_kv=True, mode="drag" | |
): | |
boolean_cond_masks = (cond_masks == 1).to(cond_masks.device) | |
try: | |
noise_pred = self.model.unet( | |
latents, | |
t, | |
encoder_hidden_states=( | |
cond_embeddings if self.model.use_cross_attn else None | |
), | |
class_labels=None if self.model.use_cross_attn else cond_embeddings, | |
encoder_attention_mask=boolean_cond_masks if self.model.use_cross_attn else None, | |
iter_cur=iter_cur, | |
mode=mode, | |
save_kv=save_kv, | |
)["sample"] | |
except TypeError as e: | |
print(f"Warning: {e}") | |
noise_pred = self.model.unet( | |
latents, | |
t, | |
encoder_hidden_states=( | |
cond_embeddings if self.model.use_cross_attn else None | |
), | |
class_labels=None if self.model.use_cross_attn else cond_embeddings, | |
encoder_attention_mask=boolean_cond_masks if self.model.use_cross_attn else None, | |
)["sample"] | |
return noise_pred | |
def init_prompt(self, prompt: str, emb_im=None): | |
# device = "cuda" if torch.cuda.is_available() else "cpu" | |
device = self.model.text_encoder.device | |
if not isinstance(prompt, list): | |
prompt = [prompt] | |
text_input = self.model.tokenizer( | |
prompt, | |
padding="max_length", | |
max_length=self.model.tokenizer.model_max_length, | |
truncation=True, | |
return_tensors="pt", | |
) | |
input_ids, attn_masks = text_input.input_ids.to(device), text_input.attention_mask.to(device) | |
text_embeddings = self.model.text_encoder( | |
input_ids, attention_mask=attn_masks, | |
)[0] | |
text_embeddings = F.normalize(text_embeddings, dim=-1) | |
if emb_im is not None: | |
raise NotImplementedError | |
self.text_embeddings = torch.cat([text_embeddings, emb_im], dim=1) | |
else: | |
self.text_embeddings = text_embeddings | |
self.text_masks = attn_masks | |
self.prompt = prompt | |
def ddim_loop(self, latent, save_kv=True, mode="drag", prediction_type="v_prediction"): | |
cond_embeddings = self.text_embeddings | |
cond_masks = self.text_masks | |
all_latent = [latent] | |
latent = latent.clone().detach() | |
print("DDIM Inversion:") | |
for i in tqdm(range(self.NUM_DDIM_STEPS)): | |
t = self.model.scheduler.timesteps[ | |
len(self.model.scheduler.timesteps) - i - 1 | |
] | |
noise_pred = self.get_noise_pred_single( | |
latent, | |
t, | |
cond_embeddings, | |
cond_masks, | |
iter_cur=len(self.model.scheduler.timesteps) - i - 1, | |
save_kv=save_kv, | |
mode=mode, | |
) | |
latent = self.next_step(noise_pred, t, latent, prediction_type=prediction_type) | |
all_latent.append(latent) | |
return all_latent | |
def scheduler(self): | |
return self.model.scheduler | |
def invert(self, ddim_latents, prompt: str, emb_im=None, save_kv=True, mode="drag", prediction_type="v_prediction"): | |
self.init_prompt(prompt, emb_im=emb_im) | |
ddim_latents = self.ddim_loop(ddim_latents, save_kv=save_kv, mode=mode, prediction_type=prediction_type) | |
return ddim_latents | |