from typing import TYPE_CHECKING, Mapping, Any

import torch
import weakref

from toolkit.config_modules import AdapterConfig
from toolkit.models.clip_fusion import ZipperBlock
from toolkit.models.zipper_resampler import ZipperModule
from toolkit.prompt_utils import PromptEmbeds
from toolkit.train_tools import get_torch_dtype

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion

from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPVisionModel
)

from toolkit.resampler import Resampler

import torch.nn as nn


class Embedder(nn.Module):
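    """Small MLP that projects an image-encoder embedding into a fixed number of
    text-encoder token embeddings (num_output_tokens x output_dim)."""
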
    def __init__(
            self,
            num_input_tokens: int = 1,
            input_dim: int = 1024,
            num_output_tokens: int = 8,
            output_dim: int = 768,
            mid_dim: int = 1024
    ):
        super(Embedder, self).__init__()
        self.num_output_tokens = num_output_tokens
        self.num_input_tokens = num_input_tokens
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layer_norm = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, mid_dim)
        self.gelu = nn.GELU()
        # self.fc2 = nn.Linear(mid_dim, mid_dim)
        self.fc2 = nn.Linear(mid_dim, mid_dim)
        self.fc2.weight.data.zero_()
        self.layer_norm2 = nn.LayerNorm(mid_dim)
        self.fc3 = nn.Linear(mid_dim, mid_dim)
        self.gelu2 = nn.GELU()
        self.fc4 = nn.Linear(mid_dim, output_dim * num_output_tokens)
        # set the weights to 0
        self.fc3.weight.data.zero_()
        self.fc4.weight.data.zero_()
        # self.static_tokens = nn.Parameter(torch.zeros(num_output_tokens, output_dim))
        # self.scaler = nn.Parameter(torch.zeros(num_output_tokens, output_dim))

    def forward(self, x):
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        x = self.layer_norm(x)
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.fc2(x)
        x = self.layer_norm2(x)
        x = self.fc3(x)
        x = self.gelu2(x)
        x = self.fc4(x)
        x = x.view(-1, self.num_output_tokens, self.output_dim)
        return x


class ClipVisionAdapter(torch.nn.Module):
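    """Conditions the text encoder(s) on a CLIP vision encoder.

    Registers `num_tokens` placeholder tokens (the trigger word plus dummy tokens)
    in each tokenizer, then on every forward pass projects CLIP image embeddings
    through an `Embedder` and writes the result directly into the text encoder's
    input-embedding rows for those placeholder tokens.
    """
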
    def __init__(self, sd: 'StableDiffusion', adapter_config: AdapterConfig):
        super().__init__()
        self.config = adapter_config
        self.trigger = adapter_config.trigger
        self.trigger_class_name = adapter_config.trigger_class_name
        self.sd_ref: weakref.ref = weakref.ref(sd)
        # assumed default: when the image condition is dropped, use scaled random noise
        # rather than zeros (see get_clip_image_embeds_from_tensors)
        self.clip_noise_zero = True
        # embedding stuff
        self.text_encoder_list = sd.text_encoder if isinstance(sd.text_encoder, list) else [sd.text_encoder]
        self.tokenizer_list = sd.tokenizer if isinstance(sd.tokenizer, list) else [sd.tokenizer]

        placeholder_tokens = [self.trigger]

        # add dummy tokens for multi-vector
        additional_tokens = []
        for i in range(1, self.config.num_tokens):
            additional_tokens.append(f"{self.trigger}_{i}")
        placeholder_tokens += additional_tokens

        # handle dual tokenizer
        self.tokenizer_list = self.sd_ref().tokenizer if isinstance(self.sd_ref().tokenizer, list) else [
            self.sd_ref().tokenizer]
        self.text_encoder_list = self.sd_ref().text_encoder if isinstance(self.sd_ref().text_encoder, list) else [
            self.sd_ref().text_encoder]

        self.placeholder_token_ids = []
        self.embedding_tokens = []

        print(f"Adding {placeholder_tokens} tokens to tokenizer")
        print(f"Adding {self.config.num_tokens} tokens to tokenizer")

        for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):

            num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
            if num_added_tokens != self.config.num_tokens:
                raise ValueError(
                    f"The tokenizer already contains the token {self.trigger}. Please pass a different"
                    f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
                )

            # Convert the initializer_token, placeholder_token to ids
            init_token_ids = tokenizer.encode(self.config.trigger_class_name, add_special_tokens=False)
            # if the class name encodes to more ids than we have embedding tokens, truncate; if fewer, pad with "*"
            if len(init_token_ids) > self.config.num_tokens:
                init_token_ids = init_token_ids[:self.config.num_tokens]
            elif len(init_token_ids) < self.config.num_tokens:
                pad_token_id = tokenizer.encode(["*"], add_special_tokens=False)
                init_token_ids += pad_token_id * (self.config.num_tokens - len(init_token_ids))

            placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False)
            self.placeholder_token_ids.append(placeholder_token_ids)

            # Resize the token embeddings as we are adding new special tokens to the tokenizer
            text_encoder.resize_token_embeddings(len(tokenizer))

            # Initialise the newly added placeholder tokens with the embeddings of the initializer tokens
            token_embeds = text_encoder.get_input_embeddings().weight.data
            with torch.no_grad():
                for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids):
                    token_embeds[token_id] = token_embeds[initializer_token_id].clone()

            # replace "[name]" with this string during training; it is generated automatically in the pipeline at inference
            self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids)))

        # backup text encoder embeddings
        self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]

        try:
            self.clip_image_processor = CLIPImageProcessor.from_pretrained(self.config.image_encoder_path)
        except EnvironmentError:
            self.clip_image_processor = CLIPImageProcessor()
        self.device = self.sd_ref().unet.device
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            self.config.image_encoder_path,
            ignore_mismatched_sizes=True
        ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        if self.config.train_image_encoder:
            self.image_encoder.train()
        else:
            self.image_encoder.eval()

        # max_seq_len = CLIP tokens + CLS token
        image_encoder_state_dict = self.image_encoder.state_dict()
        in_tokens = 257
        if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
            # clip
            in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])

        if hasattr(self.image_encoder.config, 'hidden_sizes'):
            embedding_dim = self.image_encoder.config.hidden_sizes[-1]
        else:
            embedding_dim = self.image_encoder.config.hidden_size

        if self.config.clip_layer == 'image_embeds':
            in_tokens = 1
            embedding_dim = self.image_encoder.config.projection_dim

        self.embedder = Embedder(
            num_output_tokens=self.config.num_tokens,
            num_input_tokens=in_tokens,
            input_dim=embedding_dim,
            output_dim=self.sd_ref().unet.config['cross_attention_dim'],
            mid_dim=embedding_dim * self.config.num_tokens,
        ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        self.embedder.train()

    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        state_dict = {
            'embedder': self.embedder.state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
        }
        if self.config.train_image_encoder:
            state_dict['image_encoder'] = self.image_encoder.state_dict(
                *args, destination=destination, prefix=prefix,
                keep_vars=keep_vars)
        return state_dict

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        self.embedder.load_state_dict(state_dict["embedder"], strict=strict)
        if self.config.train_image_encoder and 'image_encoder' in state_dict:
            self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)

    def parameters(self, *args, **kwargs):
        yield from self.embedder.parameters(*args, **kwargs)

    def named_parameters(self, *args, **kwargs):
        yield from self.embedder.named_parameters(*args, **kwargs)

    def get_clip_image_embeds_from_tensors(
            self, tensors_0_1: torch.Tensor, drop=False,
            is_training=False,
            has_been_preprocessed=False
    ) -> torch.Tensor:
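        """Run the CLIP vision model on image tensors and return the embedding layer
        selected by `config.clip_layer`.

        `tensors_0_1` is expected in the 0-1 range unless `has_been_preprocessed` is
        True (already resized and normalized pixel values). With `drop=True` the image
        condition is replaced by scaled random noise (or zeros) for unconditional passes.
        """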
        with torch.no_grad():
            if not has_been_preprocessed:
                # tensors should be 0-1
                if tensors_0_1.ndim == 3:
                    tensors_0_1 = tensors_0_1.unsqueeze(0)
                # training tensors are 0 - 1
                tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)

                # if images are out of this range throw error
                if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
                    raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
                        tensors_0_1.min(), tensors_0_1.max()
                    ))
                # unconditional
                if drop:
                    if self.clip_noise_zero:
                        tensors_0_1 = torch.rand_like(tensors_0_1).detach()
                        noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
                                                 dtype=get_torch_dtype(self.sd_ref().dtype))
                        tensors_0_1 = tensors_0_1 * noise_scale
                    else:
                        tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
                    # tensors_0_1 = tensors_0_1 * 0
                clip_image = self.clip_image_processor(
                    images=tensors_0_1,
                    return_tensors="pt",
                    do_resize=True,
                    do_rescale=False,
                ).pixel_values
            else:
                if drop:
                    # scale the noise down
                    if self.clip_noise_zero:
                        tensors_0_1 = torch.rand_like(tensors_0_1).detach()
                        noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
                                                 dtype=get_torch_dtype(self.sd_ref().dtype))
                        tensors_0_1 = tensors_0_1 * noise_scale
                    else:
                        tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
                    # tensors_0_1 = tensors_0_1 * 0
                    # re-normalize the 0-1 replacement the same way the processor would
                    mean = torch.tensor(self.clip_image_processor.image_mean).to(
                        self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
                    ).detach()
                    std = torch.tensor(self.clip_image_processor.image_std).to(
                        self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
                    ).detach()
                    tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
                    clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
                else:
                    clip_image = tensors_0_1
            clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()

        with torch.set_grad_enabled(is_training):
            if is_training:
                self.image_encoder.train()
            else:
                self.image_encoder.eval()
            clip_output = self.image_encoder(clip_image, output_hidden_states=True)
            if self.config.clip_layer == 'penultimate_hidden_states':
                # they skip last layer for ip+
                # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
                clip_image_embeds = clip_output.hidden_states[-2]
            elif self.config.clip_layer == 'last_hidden_state':
                clip_image_embeds = clip_output.hidden_states[-1]
            else:
                clip_image_embeds = clip_output.image_embeds
        return clip_image_embeds

    def set_vec(self, new_vector, text_encoder_idx=0):
        # Get the embedding layer
        embedding_layer = self.text_encoder_list[text_encoder_idx].get_input_embeddings()

        # Indices to replace in the embeddings
        indices_to_replace = self.placeholder_token_ids[text_encoder_idx]

        # Replace the specified embeddings with new_vector
        for idx in indices_to_replace:
            vector_idx = idx - indices_to_replace[0]
            embedding_layer.weight[idx] = new_vector[vector_idx]

    def forward(self, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
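        """Project CLIP image embeddings to placeholder-token embeddings and write them
        into the text encoder(s) in place via `set_vec`. With two text encoders (SDXL),
        the projected vector is split along the channel dim by each encoder's hidden size.
        """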
        clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        if clip_image_embeds.ndim == 2:
            # expand the token dimension
            clip_image_embeds = clip_image_embeds.unsqueeze(1)
        image_prompt_embeds = self.embedder(clip_image_embeds)
        # todo add support for multiple batch sizes
        if image_prompt_embeds.shape[0] != 1:
            raise ValueError("Batch size must be 1 for embedder for now")

        # output on sd1.5 is bs, num_tokens, 768
        if len(self.text_encoder_list) == 1:
            # add it to the text encoder
            self.set_vec(image_prompt_embeds[0], text_encoder_idx=0)
        elif len(self.text_encoder_list) == 2:
            if self.text_encoder_list[0].config.hidden_size + self.text_encoder_list[1].config.hidden_size != \
                    image_prompt_embeds.shape[2]:
                raise ValueError("Something went wrong. The embeddings do not match the text encoder sizes")
            # sdxl variants
            # image_prompt_embeds = 2048
            # te1 = 768
            # te2 = 1280
            te1_embeds = image_prompt_embeds[:, :, :self.text_encoder_list[0].config.hidden_size]
            te2_embeds = image_prompt_embeds[:, :, self.text_encoder_list[0].config.hidden_size:]
            self.set_vec(te1_embeds[0], text_encoder_idx=0)
            self.set_vec(te2_embeds[0], text_encoder_idx=1)
        else:
            raise ValueError("Unsupported number of text encoders")
        # just a place to put a breakpoint
        pass

    def restore_embeddings(self):
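        """Restore every text-encoder embedding row except the placeholder tokens from
        the backup taken in __init__, so only the injected tokens are ever modified."""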
        # Let's make sure we don't update any embedding weights besides the newly added token
        for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(
                self.text_encoder_list,
                self.tokenizer_list,
                self.orig_embeds_params,
                self.placeholder_token_ids
        ):
            index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
            index_no_updates[
                min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
            with torch.no_grad():
                text_encoder.get_input_embeddings().weight[
                    index_no_updates
                ] = orig_embeds[index_no_updates]
            # detach it all
            text_encoder.get_input_embeddings().weight.detach_()

    def enable_gradient_checkpointing(self):
        self.image_encoder.gradient_checkpointing = True

    def inject_trigger_into_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
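        """Replace `[name]` / `[trigger]` (plus any strings in `to_replace_list`) with the
        trigger token, or with the full multi-token string when `expand_token` is True.
        If the trigger is not found and `add_if_not_present` is set, it is prepended."""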
        output_prompt = prompt
        embedding_tokens = self.embedding_tokens[0]  # should be the same for all tokenizers
        default_replacements = ["[name]", "[trigger]"]

        replace_with = embedding_tokens if expand_token else self.trigger
        if to_replace_list is None:
            to_replace_list = default_replacements
        else:
            to_replace_list += default_replacements

        # remove duplicates
        to_replace_list = list(set(to_replace_list))

        # replace them all
        for to_replace in to_replace_list:
            # replace it
            output_prompt = output_prompt.replace(to_replace, replace_with)

        # see how many times replace_with is in the prompt
        num_instances = output_prompt.count(replace_with)

        if num_instances == 0 and add_if_not_present:
            # add it to the beginning of the prompt
            output_prompt = replace_with + " " + output_prompt

        if num_instances > 1:
            print(
                f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")

        return output_prompt

    # reverses the injection, swapping the trigger tokens back to the class name. Useful for normalizations
    def inject_trigger_class_name_into_prompt(self, prompt):
        output_prompt = prompt
        embedding_tokens = self.embedding_tokens[0]  # should be the same for all tokenizers
        default_replacements = ["[name]", "[trigger]", embedding_tokens, self.trigger]

        replace_with = self.config.trigger_class_name
        to_replace_list = default_replacements

        # remove duplicates
        to_replace_list = list(set(to_replace_list))

        # replace them all
        for to_replace in to_replace_list:
            # replace it
            output_prompt = output_prompt.replace(to_replace, replace_with)

        # see how many times replace_with is in the prompt
        num_instances = output_prompt.count(replace_with)

        if num_instances > 1:
            print(
                f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")

        return output_prompt
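

# Minimal smoke test for the Embedder projection on its own. This is only an
# illustrative sketch: the sizes below are assumptions (a single pooled CLIP image
# embedding of width 1024 projected to 4 tokens of a 768-dim cross-attention space)
# and are not tied to any real adapter config.
if __name__ == "__main__":
    embedder = Embedder(
        num_input_tokens=1,
        input_dim=1024,
        num_output_tokens=4,
        output_dim=768,
        mid_dim=1024,
    )
    # fake pooled image embedding for a batch of 1; 2-D inputs get a token dim added in forward()
    dummy_image_embeds = torch.randn(1, 1024)
    tokens = embedder(dummy_image_embeds)
    print(tokens.shape)  # expected: torch.Size([1, 4, 768]), one row per placeholder token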