Spaces:
Paused
Paused
import json | |
import os | |
from collections import OrderedDict | |
import safetensors | |
import torch | |
from typing import TYPE_CHECKING | |
from safetensors.torch import save_file | |
from toolkit.metadata import get_meta_for_safetensors | |
if TYPE_CHECKING: | |
from toolkit.stable_diffusion_model import StableDiffusion | |
from toolkit.config_modules import EmbeddingConfig | |
# this is a frankenstein mix of automatic1111 and my own code | |
class Embedding:
    """Textual-inversion embedding bound to a model's tokenizer(s) and text encoder(s).

    On construction it adds ``embed_config.tokens`` placeholder tokens (the trigger word
    plus ``<trigger>_1``, ``<trigger>_2``, ...) to every tokenizer. It then resizes each
    text encoder's input-embedding table to the new tokenizer length and initialises the
    new rows from the embeddings of ``embed_config.init_words``. If fewer init tokens
    exist than placeholders, the rest are padded with ``*``.

    When ``sd.tokenizer`` / ``sd.text_encoder`` are lists (dual-encoder models; ``sd.is_xl``
    selects the two-vector save format), every tokenizer/encoder pair gets its own set of
    placeholder tokens.
    """

    def __init__(
            self,
            sd: 'StableDiffusion',
            embed_config: 'EmbeddingConfig',
            state_dict: OrderedDict = None,
    ):
        """
        Args:
            sd: model wrapper exposing ``tokenizer``, ``text_encoder`` (each a single
                object or a list) and ``is_xl``.
            embed_config: provides ``trigger``, ``tokens`` (number of vectors) and
                ``init_words``.
            state_dict: currently unused; kept for interface compatibility.

        Raises:
            ValueError: if a tokenizer did not add exactly ``embed_config.tokens`` new
                tokens (i.e. one of the placeholder tokens already exists).
        """
        self.name = embed_config.trigger
        self.sd = sd
        self.trigger = embed_config.trigger
        self.embed_config = embed_config
        self.step = 0

        # Placeholder tokens: the trigger itself, then "<trigger>_1" .. "<trigger>_{tokens-1}"
        # as dummy tokens holding the extra vectors of a multi-vector embedding.
        placeholder_tokens = [self.embed_config.trigger]
        placeholder_tokens += [f"{self.embed_config.trigger}_{i}" for i in range(1, self.embed_config.tokens)]

        # Normalise single and dual tokenizer / text-encoder setups to lists.
        self.tokenizer_list = self.sd.tokenizer if isinstance(self.sd.tokenizer, list) else [self.sd.tokenizer]
        self.text_encoder_list = self.sd.text_encoder if isinstance(self.sd.text_encoder, list) else [
            self.sd.text_encoder]

        # One entry per tokenizer/encoder pair.
        self.placeholder_token_ids = []
        self.embedding_tokens = []

        print(f"Adding {placeholder_tokens} tokens to tokenizer")
        print(f"Adding {self.embed_config.tokens} tokens to tokenizer")

        for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
            num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
            if num_added_tokens != self.embed_config.tokens:
                raise ValueError(
                    f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
                    f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
                )

            # Initialiser ids: truncate to the number of vectors, or pad with the id(s) of "*".
            init_token_ids = tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
            if len(init_token_ids) > self.embed_config.tokens:
                init_token_ids = init_token_ids[:self.embed_config.tokens]
            elif len(init_token_ids) < self.embed_config.tokens:
                pad_token_id = tokenizer.encode(["*"], add_special_tokens=False)
                init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))

            placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False)
            self.placeholder_token_ids.append(placeholder_token_ids)

            # Resize the token embeddings as we are adding new special tokens to the tokenizer.
            text_encoder.resize_token_embeddings(len(tokenizer))

            # Initialise the newly added placeholder rows with the initialiser tokens' embeddings.
            token_embeds = text_encoder.get_input_embeddings().weight.data
            with torch.no_grad():
                for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids):
                    token_embeds[token_id] = token_embeds[initializer_token_id].clone()

            # Space-joined placeholder tokens. During training this replaces "[name]" in prompts;
            # at inference the pipeline expands the trigger automatically.
            self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids)))

        # Back up the full embedding tables so restore_embeddings can undo changes to non-placeholder rows.
        self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]

    def restore_embeddings(self):
        """Reset every embedding row except the placeholder tokens to its backed-up value.

        ``get_trainable_params`` exposes whole embedding tables, so this presumably runs
        after each optimizer step. It ensures only the placeholder rows actually train.
        """
        with torch.no_grad():
            for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(
                    self.text_encoder_list,
                    self.tokenizer_list,
                    self.orig_embeds_params,
                    self.placeholder_token_ids,
            ):
                index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
                # Mark exactly the placeholder ids as trainable (no contiguity assumption).
                index_no_updates[placeholder_token_ids] = False
                weight = text_encoder.get_input_embeddings().weight
                weight[index_no_updates] = orig_embeds[index_no_updates]

    def get_trainable_params(self):
        """Return the input-embedding parameters of every text encoder as a flat list."""
        params = []
        for text_encoder in self.text_encoder_list:
            params.extend(text_encoder.get_input_embeddings().parameters())
        return params

    def _get_vec(self, text_encoder_idx=0):
        """Return the placeholder-token embedding rows of one text encoder.

        The rows are stacked along a new first dimension, one per placeholder token,
        in placeholder order.
        """
        token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
        return torch.stack(
            [token_embeds[token_id] for token_id in self.placeholder_token_ids[text_encoder_idx]],
            dim=0,
        )

    def _set_vec(self, new_vector, text_encoder_idx=0):
        """Write ``new_vector[i]`` into the i-th placeholder row of one text encoder.

        The write goes through ``weight.data``, so it is not recorded by autograd.
        """
        token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
        for i in range(new_vector.shape[0]):
            token_embeds[self.placeholder_token_ids[text_encoder_idx][i]] = new_vector[i].clone()

    @property
    def vec(self):
        """Placeholder embedding vectors of the first text encoder (read/write)."""
        return self._get_vec(0)

    @vec.setter
    def vec(self, new_vector):
        self._set_vec(new_vector, 0)

    @property
    def vec2(self):
        """Placeholder embedding vectors of the second text encoder (dual-encoder models only)."""
        return self._get_vec(1)

    @vec2.setter
    def vec2(self, new_vector):
        self._set_vec(new_vector, 1)

    # diffusers automatically expands the token meaning test123 becomes test123 test123_1 test123_2 etc
    # however, on training we don't use that pipeline, so we have to do it ourselves
    def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
        """Replace trigger placeholders in ``prompt`` with the embedding's token(s).

        Args:
            prompt: prompt text.
            expand_token: if True, insert the full space-joined placeholder tokens;
                otherwise insert only the trigger word.
            to_replace_list: extra placeholder strings to replace, in addition to
                "[name]" and "[trigger]". This list is not modified.
            add_if_not_present: prepend the replacement if it does not occur in the
                resulting prompt.

        Returns:
            The rewritten prompt. A warning is printed if the replacement occurs more
            than once.
        """
        embedding_tokens = self.embedding_tokens[0]  # should be the same for every tokenizer
        replace_with = embedding_tokens if expand_token else self.trigger

        default_replacements = ["[name]", "[trigger]"]
        # Build a new list (never mutate the caller's) and de-duplicate while keeping a
        # deterministic order: caller-provided entries first, then the defaults.
        replacements = list(dict.fromkeys(list(to_replace_list or []) + default_replacements))

        output_prompt = prompt
        for to_replace in replacements:
            output_prompt = output_prompt.replace(to_replace, replace_with)

        num_instances = output_prompt.count(replace_with)
        if num_instances == 0 and add_if_not_present:
            # add it to the beginning of the prompt
            output_prompt = replace_with + " " + output_prompt
        if num_instances > 1:
            print(
                f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")

        return output_prompt

    def state_dict(self):
        """Return the embedding vectors keyed for saving.

        XL models produce ``clip_l`` / ``clip_g`` (first / second encoder). Other models
        produce ``emb_params``.
        """
        state_dict = OrderedDict()
        if self.sd.is_xl:
            state_dict['clip_l'] = self.vec
            state_dict['clip_g'] = self.vec2
        else:
            state_dict['emb_params'] = self.vec
        return state_dict

    def save(self, filename):
        """Save the embedding to ``filename`` (``.pt``, ``.bin`` or ``.safetensors``).

        ``.pt``/``.bin`` files get a ``torch.save``'d dict holding only ``self.vec`` (first
        encoder). ``.safetensors`` files get ``state_dict()`` tensors, and the remaining
        fields go into the metadata as JSON strings.

        Raises:
            ValueError: if the extension is not one of the supported ones.
        """
        # todo check to see how to get the vector out of the embedding
        embedding_data = {
            # fixed value, presumably mirroring the automatic1111 file layout; not derived from the tokenizer
            "string_to_token": {"*": 265},
            "string_to_param": {"*": self.vec},
            "name": self.name,
            "step": self.step,
            # todo get these
            "sd_checkpoint": None,
            "sd_checkpoint_name": None,
            "notes": None,
        }
        # TODO we do not currently support this. Check how auto is doing it. Only safetensors supported for sdxl
        if filename.endswith(('.pt', '.bin')):
            torch.save(embedding_data, filename)
        elif filename.endswith('.safetensors'):
            state_dict = self.state_dict()
            # add all embedding data (except the tensors) to the metadata, JSON-encoded like every other value
            metadata = OrderedDict({k: json.dumps(v) for k, v in embedding_data.items() if k != "string_to_param"})
            metadata["string_to_param"] = json.dumps({"*": "emb_params"})
            save_meta = get_meta_for_safetensors(metadata, name=self.name)
            save_file(state_dict, filename, metadata=save_meta)
        else:
            raise ValueError(
                f"Unsupported embedding file extension for {filename}; expected .pt, .bin or .safetensors")

    @staticmethod
    def _try_json(value):
        """Decode a JSON-encoded metadata value, returning it unchanged if it is not valid JSON."""
        try:
            return json.loads(value)
        except (TypeError, ValueError):
            return value

    def load_embedding_from_file(self, file_path, device):
        """Load embedding vectors (and the training step, when stored) from a file.

        Supported inputs:
        - ``.pt``/``.bin``: automatic1111-style dicts with ``string_to_param``, or
          diffusers-concept dicts holding a single tensor.
        - ``.safetensors``: this class's format (``emb_params``, or ``clip_l``/``clip_g``
          for XL) or a plain single-tensor file.

        Preview images (``*.preview.png`` etc.) and other extensions are ignored.

        Args:
            file_path: path to the embedding file.
            device: device the loaded vectors are moved to (as float32) before being
                written into the embedding table.

        Raises:
            Exception: ``.pt``/``.bin`` on XL models, or unidentifiable contents.
            ValueError: if the file holds more or fewer than one embedding term.
        """
        path = os.path.realpath(file_path)
        filename = os.path.basename(path)
        name, ext = os.path.splitext(filename)
        ext = ext.upper()
        tensors = {}
        # raw safetensors header metadata; stays None for other formats
        metadata = None

        if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
            _, second_ext = os.path.splitext(name)
            if second_ext.upper() == '.PREVIEW':
                return

        if ext in ['.BIN', '.PT']:
            # todo check this
            if self.sd.is_xl:
                raise Exception("XL not supported yet for bin, pt")
            # NOTE(security): torch.load unpickles the file; only load embeddings from trusted sources.
            data = torch.load(path, map_location="cpu")
        elif ext == '.SAFETENSORS':
            with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f:
                metadata = f.metadata()
                for k in f.keys():
                    tensors[k] = f.get_tensor(k)
            if metadata and 'string_to_param' in metadata and 'emb_params' in tensors:
                # our format
                data = {k: self._try_json(v) for k, v in metadata.items()}
                data['string_to_param'] = {'*': tensors['emb_params']}
            else:
                # old format
                data = tensors
        else:
            return

        if self.sd.is_xl:
            self.vec = tensors['clip_l'].detach().to(device, dtype=torch.float32)
            self.vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32)
            # `save` stores 'step' in the safetensors metadata, not among the tensors.
            if 'step' in data:
                step_source = data
            else:
                step_source = {k: self._try_json(v) for k, v in (metadata or {}).items()}
            if 'step' in step_source:
                self.step = int(step_source['step'])
            return

        # textual inversion embeddings
        if isinstance(data, dict) and 'string_to_param' in data:
            param_dict = data['string_to_param']
            if hasattr(param_dict, '_parameters'):
                # fix for torch 1.12.1 loading saved file from torch 1.11
                param_dict = getattr(param_dict, '_parameters')
            if len(param_dict) != 1:
                raise ValueError(f'embedding file must contain exactly one term, found {len(param_dict)}')
            emb = next(iter(param_dict.values()))
        # diffuser concepts
        elif isinstance(data, dict) and data and isinstance(next(iter(data.values())), torch.Tensor):
            if len(data) != 1:
                raise ValueError(f'embedding file must contain exactly one term, found {len(data)}')
            emb = next(iter(data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)
        else:
            raise Exception(
                f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")

        if 'step' in data:
            self.step = int(data['step'])
        self.vec = emb.detach().to(device, dtype=torch.float32)