test / modules /textual_inversion /textual_inversion.py
bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
from typing import List, Union
import os
import time
from collections import namedtuple
import torch
import safetensors.torch
from PIL import Image
from modules import shared, devices, sd_models, errors
from modules.textual_inversion.image_embedding import embedding_from_b64, extract_image_data_embed
from modules.files_cache import directory_files, directory_mtime, extension_filter
# Trace logger: routed to shared.log.trace only when the SD_TI_DEBUG environment variable is set, otherwise a no-op.
if os.environ.get('SD_TI_DEBUG', None) is None:
    debug = lambda *args, **kwargs: None # pylint: disable=unnecessary-lambda-assignment
else:
    debug = shared.log.trace
debug('Trace: TEXTUAL INVERSION')
# One token's vectors for the first (clip_l) and second (clip_g) text encoder; clip_g may be None.
TokenToAdd = namedtuple("TokenToAdd", "clip_l clip_g")
# A template file found under shared.opts.embeddings_templates_dir: bare filename and full path.
TextualInversionTemplate = namedtuple("TextualInversionTemplate", "name path")
# filename -> TextualInversionTemplate; refilled in place by list_textual_inversion_templates().
textual_inversion_templates = {}
def list_textual_inversion_templates():
    """Rescan the templates directory and return the refreshed template map.

    Walks ``shared.opts.embeddings_templates_dir`` recursively. The module-level
    ``textual_inversion_templates`` dict is cleared and refilled in place, so existing
    references observe the update. Entries are keyed by bare filename; a later file with
    the same name replaces an earlier one.
    """
    textual_inversion_templates.clear()
    for folder, _subfolders, names in os.walk(shared.opts.embeddings_templates_dir):
        textual_inversion_templates.update(
            (name, TextualInversionTemplate(name, os.path.join(folder, name))) for name in names
        )
    return textual_inversion_templates
def list_embeddings(*dirs):
    """Return the candidate embedding files found in ``dirs``.

    A file is kept when all of the following hold:
    - ``extension_filter`` accepts its extension. .safetensors and .pt are always accepted;
      .png/.webp/.jxl/.avif/.bin are accepted only when the backend is not diffusers.
    - Its stem does not end in ``.preview``.
    - It has a non-zero size.

    Files that cannot be stat'ed (e.g. broken symlinks, or files removed while listing) are
    skipped instead of aborting the whole listing.
    """
    extra_exts = ['.PNG', '.WEBP', '.JXL', '.AVIF', '.BIN'] if shared.backend != shared.Backend.DIFFUSERS else []
    is_ext = extension_filter(['.SAFETENSORS', '.PT'] + extra_exts)

    def is_not_preview(fp):
        # presumably `<name>.preview.<ext>` files are preview images, not embeddings
        return not os.path.splitext(fp)[0].upper().endswith('.PREVIEW')

    def is_non_empty(fp):
        try:
            return os.stat(fp).st_size > 0
        except OSError:  # vanished or unreadable file: skip it rather than fail the listing
            return False

    return [fp for fp in directory_files(*dirs) if is_ext(fp) and is_not_preview(fp) and is_non_empty(fp)]
class Embedding:
    """One textual-inversion embedding: its vectors plus the metadata saved with it."""

    def __init__(self, vec, name, filename=None, step=None):
        self.vec = vec  # embedding tensor; None when used only as a placeholder
        self.name = name
        self.tag = name
        self.step = step
        self.filename = filename
        if filename is None:
            self.basename = None
        else:
            try:
                self.basename = os.path.relpath(filename, shared.opts.embeddings_dir)
            except ValueError:
                # os.path.relpath raises on Windows when the two paths are on different drives
                self.basename = filename
        self.shape = None  # width of each vector; set by the loader
        self.vectors = 0  # number of vectors; set by the loader
        self.cached_checksum = None
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.optimizer_state_dict = None

    def save(self, filename):
        """Write this embedding to ``filename`` with ``torch.save``.

        When ``shared.opts.save_optimizer_state`` is enabled and optimizer state is present,
        that state is also written to ``<filename>.optim`` together with this embedding's checksum.
        """
        embedding_data = {
            "string_to_token": {"*": 265},
            "string_to_param": {"*": self.vec},
            "name": self.name,
            "step": self.step,
            "sd_checkpoint": self.sd_checkpoint,
            "sd_checkpoint_name": self.sd_checkpoint_name,
        }
        torch.save(embedding_data, filename)
        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
            optimizer_saved_dict = {
                'hash': self.checksum(),
                'optimizer_state_dict': self.optimizer_state_dict,
            }
            torch.save(optimizer_saved_dict, f"{filename}.optim")

    def checksum(self):
        """Return a 4-hex-digit hash of ``self.vec``, computed once and then cached."""
        if self.cached_checksum is not None:
            return self.cached_checksum

        def const_hash(a):
            r = 0
            for v in a:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
        return self.cached_checksum
class DirWithTextualInversionEmbeddings:
    """One embeddings directory plus the modification time recorded at the last update()."""

    def __init__(self, path):
        self.path = path
        self.mtime = None  # None until update() records the directory's mtime

    def has_changed(self):
        """Return True when the directory exists and its mtime differs from the recorded one."""
        exists = os.path.isdir(self.path)
        return exists and directory_mtime(self.path) != self.mtime

    def update(self):
        """Record the directory's current mtime; a missing directory leaves the record unchanged."""
        if os.path.isdir(self.path):
            self.mtime = directory_mtime(self.path)
def convert_embedding(tensor, text_encoder, text_encoder_2):
    """Approximate second-encoder vectors for an embedding that only has first-encoder vectors.

    How it works:
    - Each row of ``tensor`` is matched to the row of ``text_encoder``'s input-embedding matrix
      with the highest cosine similarity (NaNs treated as 0).
    - If the best similarity is below 0.707 (about 45 degrees), index 0 is used instead.
    - The chosen indices are looked up in ``text_encoder_2``'s input-embedding matrix, and the
      resulting rows are returned.
    """
    with torch.no_grad():
        # only read from, so no clone is needed; .to() copies only when the device differs
        clip_l_embeds = text_encoder.get_input_embeddings().weight.data.to(device=devices.device)
        tensor = tensor.to(device=devices.device)
        indices = []
        for vec in tensor:
            similarity = torch.nan_to_num(torch.cosine_similarity(vec.unsqueeze(0), clip_l_embeds))
            best_value, best_index = torch.max(similarity, 0)
            if best_value < 0.707: # Arbitrary similarity to cutoff, here 45 degrees
                best_index = torch.zeros_like(best_index) # Use SDXL padding vector 0
            indices.append(best_index)
        clip_g_embeds = text_encoder_2.get_input_embeddings().weight.data
        # the index tensor must be on the same device as the indexed weights, which may differ from devices.device
        output = clip_g_embeds[torch.stack(indices).to(clip_g_embeds.device)]
    return output
class EmbeddingDatabase:
    """Registry of the textual-inversion embeddings available to the current model.

    Attributes:
    - ``word_embeddings``: embedding name -> Embedding.
    - ``skipped_embeddings``: embeddings that were found but not loaded.
    - ``ids_lookup``: first token id of an embedding name -> (token ids, Embedding) candidates,
      ordered longest first by register_embedding, as consumed by find_embedding_at_position.
    """

    def __init__(self):
        self.ids_lookup = {}
        self.word_embeddings = {}
        self.skipped_embeddings = {}
        self.expected_shape = -1  # vector width required by load_from_file; -1 accepts any width
        self.embedding_dirs = {}  # path -> DirWithTextualInversionEmbeddings
        self.previously_displayed_embeddings = ()
        self.embeddings_used = []

    def add_embedding_dir(self, path):
        """Start tracking ``path`` as an embeddings directory."""
        self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)

    def clear_embedding_dirs(self):
        """Stop tracking every embeddings directory."""
        self.embedding_dirs.clear()

    def register_embedding(self, embedding, model):
        """Add ``embedding`` to the registry and index it by the token ids of its name.

        The name is tokenized via ``model.cond_stage_model.tokenize`` when present, otherwise via
        ``model.tokenizer.convert_tokens_to_ids``.

        Returns:
            ``embedding``.
        Raises:
            RuntimeError: if ``model`` has neither ``cond_stage_model`` nor ``tokenizer``.
        """
        if hasattr(model, 'cond_stage_model'):
            ids = model.cond_stage_model.tokenize([embedding.name])[0]
        elif hasattr(model, 'tokenizer'):
            ids = model.tokenizer.convert_tokens_to_ids(embedding.name)
        else:
            # previously this path failed later with UnboundLocalError on `ids`
            raise RuntimeError(f'Cannot register embedding "{embedding.name}": model has no tokenizer')
        if not isinstance(ids, list):
            ids = [ids]
        # register by name only once tokenization succeeded, so both maps stay consistent
        self.word_embeddings[embedding.name] = embedding
        first_id = ids[0]
        candidates = self.ids_lookup.get(first_id, [])
        self.ids_lookup[first_id] = sorted(candidates + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
        return embedding

    def get_expected_shape(self):
        """Return the vector width embeddings must have for the loaded model.

        Returns 0 on the diffusers backend or when no model is loaded. Otherwise returns
        ``shape[1]`` of what ``cond_stage_model.encode_embedding_init_text(",", 1)`` produces.
        """
        if shared.backend == shared.Backend.DIFFUSERS:
            return 0
        if shared.sd_model is None:
            shared.log.error('Model not loaded')
            return 0
        vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
        return vec.shape[1]

    def load_diffusers_embedding(self, filename: Union[str, List[str]]):
        """Load embedding files into the diffusers model ``shared.sd_model`` as new tokenizer tokens.

        ``filename`` is a single path or a list of paths. Row ``i`` of an embedding's ``clip_l``
        tensor becomes token ``name`` (i == 0) or ``name_i``, but only for rows whose width matches
        ``text_encoder``'s embedding width. The rows, plus ``clip_g`` rows for SDXL, are written
        into the text encoders' input-embedding matrices. Files that are not loaded are recorded
        in ``self.skipped_embeddings``.

        Returns:
            The number of names added to ``self.word_embeddings``; 0 when no suitable model is loaded.
        """
        _loaded_pre = len(self.word_embeddings)
        embeddings_to_load = []
        loaded_embeddings = {}
        skipped_embeddings = []
        if shared.sd_model is None:
            return 0
        tokenizer = getattr(shared.sd_model, 'tokenizer', None)
        tokenizer_2 = getattr(shared.sd_model, 'tokenizer_2', None)
        clip_l = getattr(shared.sd_model, 'text_encoder', None)
        clip_g = getattr(shared.sd_model, 'text_encoder_2', None)
        if clip_g and tokenizer_2:
            model_type = 'SDXL'
        elif clip_l and tokenizer:
            model_type = 'SD'
        else:
            return 0
        # a single path must not be exploded into characters by list()
        filenames = [filename] if isinstance(filename, str) else list(filename)
        if not filenames:
            return 0
        exts = [".SAFETENSORS", '.BIN', '.PT', '.PNG', '.WEBP', '.JXL', '.AVIF']
        # tokens are only added after both passes below, so the vocabulary can be read once
        tokenizer_vocab = tokenizer.get_vocab()

        # pass 1: filter by extension and note names already present in the vocabulary
        for fullname in filenames:
            fn, ext = os.path.splitext(os.path.basename(fullname))
            name = os.path.basename(fn)
            embedding = Embedding(vec=None, name=name, filename=fullname)
            if ext.upper() not in exts:
                skipped_embeddings.append(embedding)
                debug(f'Embedding skipped: "{name}" extension `{ext}` is invalid, expected one of: {exts}')
                continue
            if name in tokenizer_vocab or f"{name}_1" in tokenizer_vocab:
                loaded_embeddings[name] = embedding
                debug(f'Embedding already loaded: {name}')
            embeddings_to_load.append(embedding)
        # order by extension preference; a later file with an already-used name is rejected as a duplicate
        embeddings_to_load = sorted(embeddings_to_load, key=lambda e: exts.index(os.path.splitext(e.filename)[1].upper()))

        # pass 2: read vectors and collect the tokens to add
        tokens_to_add = {}
        for embedding in embeddings_to_load:
            try:
                if embedding.name in tokens_to_add or embedding.name in loaded_embeddings:
                    raise ValueError('duplicate token')
                embeddings_dict = {}
                _, ext = os.path.splitext(embedding.filename)
                if ext.upper() in ['.SAFETENSORS']:
                    with safetensors.torch.safe_open(embedding.filename, framework="pt") as f: # type: ignore
                        for k in f.keys():
                            embeddings_dict[k] = f.get_tensor(k)
                else: # fallback for sd1.5 pt embeddings
                    embeddings_dict["clip_l"] = self.load_from_file(embedding.filename, embedding.filename)
                # load_from_file returns None for preview images and unrecognised files
                if embeddings_dict.get('clip_l') is None:
                    raise ValueError('Invalid Embedding, dict missing required key `clip_l`')
                if 'clip_g' not in embeddings_dict and model_type == "SDXL" and shared.opts.diffusers_convert_embed:
                    embeddings_dict["clip_g"] = convert_embedding(embeddings_dict["clip_l"], clip_l, clip_g)
                embedding_type = 'SDXL' if 'clip_g' in embeddings_dict else 'SD'
                if embedding_type != model_type:
                    raise ValueError(f'Unable to load {embedding_type} Embedding "{embedding.name}" into {model_type} Model')
                _tokens_to_add = {}
                expected_width = len(clip_l.get_input_embeddings().weight.data[0])
                for i in range(len(embeddings_dict["clip_l"])):
                    if expected_width == len(embeddings_dict["clip_l"][i]):
                        token = embedding.name if i == 0 else f"{embedding.name}_{i}"
                        if token in tokenizer_vocab:
                            raise RuntimeError(f'Multi-Vector Embedding would add pre-existing Token in Vocabulary: {token}')
                        if token in tokens_to_add:
                            raise RuntimeError(f'Multi-Vector Embedding would add duplicate Token to Add: {token}')
                        _tokens_to_add[token] = TokenToAdd(
                            embeddings_dict["clip_l"][i],
                            embeddings_dict["clip_g"][i] if 'clip_g' in embeddings_dict else None
                        )
                if not _tokens_to_add:
                    raise ValueError('no valid tokens to add')
                tokens_to_add.update(_tokens_to_add)
                loaded_embeddings[embedding.name] = embedding
            except Exception as e:
                debug(f"Embedding loading: {embedding.filename} {e}")
                continue

        # add the new tokens and write their vectors into the text encoders
        if len(tokens_to_add) > 0:
            tokenizer.add_tokens(list(tokens_to_add.keys()))
            clip_l.resize_token_embeddings(len(tokenizer))
            if model_type == 'SDXL':
                tokenizer_2.add_tokens(list(tokens_to_add.keys())) # type: ignore
                clip_g.resize_token_embeddings(len(tokenizer_2)) # type: ignore
            unk_token_id = tokenizer.convert_tokens_to_ids(tokenizer.unk_token)
            for token, data in tokens_to_add.items():
                token_id = tokenizer.convert_tokens_to_ids(token)
                if token_id > unk_token_id:
                    clip_l.get_input_embeddings().weight.data[token_id] = data.clip_l
                    if model_type == 'SDXL':
                        clip_g.get_input_embeddings().weight.data[token_id] = data.clip_g # type: ignore
        for embedding in loaded_embeddings.values():
            if not embedding:
                continue
            self.register_embedding(embedding, shared.sd_model)
            if embedding in embeddings_to_load:
                embeddings_to_load.remove(embedding)
        skipped_embeddings.extend(embeddings_to_load)
        for embedding in skipped_embeddings:
            if loaded_embeddings.get(embedding.name, None) == embedding:
                continue
            self.skipped_embeddings[embedding.name] = embedding
        try:  # best-effort trace output only
            if model_type == 'SD':
                debug(f"Embeddings loaded: text-encoder={shared.sd_model.text_encoder.get_input_embeddings().weight.data.shape[0]}")
            if model_type == 'SDXL':
                debug(f"Embeddings loaded: text-encoder-1={shared.sd_model.text_encoder.get_input_embeddings().weight.data.shape[0]} text-encoder-2={shared.sd_model.text_encoder_2.get_input_embeddings().weight.data.shape[0]}")
        except Exception:
            pass
        return len(self.word_embeddings) - _loaded_pre

    def load_from_file(self, path, filename):
        """Load one embedding file from ``path``; ``filename`` supplies the name and extension.

        Return value depends on backend:
        - Diffusers backend: returns the raw embedding tensor.
        - Other backends: registers the embedding (or records it in ``skipped_embeddings`` when
          its width differs from ``expected_shape``) and returns None.

        Also returns None for preview images, images without embedded data, unsupported
        extensions, and multi-key diffusers-concept dicts (those are recorded as skipped).

        Raises:
            RuntimeError: if the contents are not recognised as a textual-inversion embedding,
                or a ``string_to_param`` file holds more than one term.
        """
        # basename: callers may pass a full path as ``filename``
        name, ext = os.path.splitext(os.path.basename(filename))
        ext = ext.upper()
        if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
            if '.preview' in filename.lower():
                return None
            with Image.open(path) as embed_image:  # close the file handle once the data is read
                if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                    data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
                else:
                    data = extract_image_data_embed(embed_image)
            if not data: # if data is None, means this is not an embeding, just a preview image
                return None
        elif ext in ['.BIN', '.PT']:
            # NOTE(review): torch.load unpickles the file and can execute code; only load trusted embeddings
            data = torch.load(path, map_location="cpu")
        elif ext in ['.SAFETENSORS']:
            data = safetensors.torch.load_file(path, device="cpu")
        else:
            return None
        if 'string_to_param' in data: # textual inversion embeddings
            param_dict = data['string_to_param']
            param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
            if len(param_dict) != 1:
                raise RuntimeError('embedding file has multiple terms in it')
            emb = next(iter(param_dict.values()))
        elif isinstance(data, dict) and data and isinstance(next(iter(data.values())), torch.Tensor): # diffuser concepts
            if len(data) != 1:
                self.skipped_embeddings[name] = Embedding(None, name=name, filename=path)
                return None
            emb = next(iter(data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)
        else:
            raise RuntimeError(f"Couldn't identify {filename} as textual inversion embedding")
        if shared.backend == shared.Backend.DIFFUSERS:
            return emb
        vec = emb.detach().to(devices.device, dtype=torch.float32)
        embedding = Embedding(vec=vec, name=name, filename=path)
        embedding.tag = data.get('name', None)
        embedding.step = data.get('step', None)
        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
        embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
        embedding.vectors = vec.shape[0]
        embedding.shape = vec.shape[-1]
        if self.expected_shape == -1 or self.expected_shape == embedding.shape:
            self.register_embedding(embedding, shared.sd_model)
        else:
            self.skipped_embeddings[name] = embedding
        return None

    def load_from_dir(self, embdir):
        """Load every embedding file listed in ``embdir.path``.

        Does nothing when no model is loaded or the directory is missing. Per-file errors on the
        non-diffusers path are reported via ``errors.display`` and do not stop the loop.
        """
        if sd_models.model_data.sd_model is None:
            shared.log.info('Skipping embeddings load: model not loaded')
            return
        if not os.path.isdir(embdir.path):
            return
        file_paths = list_embeddings(embdir.path)
        if shared.backend == shared.Backend.DIFFUSERS:
            self.load_diffusers_embedding(file_paths)
            return
        for file_path in file_paths:
            fn = os.path.basename(file_path)
            try:
                self.load_from_file(file_path, fn)
            except Exception as e:
                errors.display(e, f'Load embeding={fn}')

    def load_textual_inversion_embeddings(self, force_reload=False):
        """Reload all embeddings from the tracked directories.

        Skipped unless ``force_reload`` is set or a tracked directory has changed. After loading,
        ``word_embeddings`` is re-sorted by lower-cased name, in place.
        """
        if shared.sd_model is None:
            return
        t0 = time.time()
        if not force_reload and not any(embdir.has_changed() for embdir in self.embedding_dirs.values()):
            return
        self.ids_lookup.clear()
        self.word_embeddings.clear()
        self.skipped_embeddings.clear()
        self.embeddings_used.clear()
        self.expected_shape = self.get_expected_shape()
        for embdir in self.embedding_dirs.values():
            self.load_from_dir(embdir)
            embdir.update()
        # re-sort word_embeddings because load_from_dir may not load in alphabetic order.
        # using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it.
        sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
        self.word_embeddings.clear()
        self.word_embeddings.update(sorted_word_embeddings)
        displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
        if self.previously_displayed_embeddings != displayed_embeddings:
            self.previously_displayed_embeddings = displayed_embeddings
            t1 = time.time()
            shared.log.info(f"Load embeddings: loaded={len(self.word_embeddings)} skipped={len(self.skipped_embeddings)} time={t1-t0:.2f}")

    def find_embedding_at_position(self, tokens, offset):
        """Return ``(embedding, token_count)`` for the first candidate whose ids match at ``tokens[offset]``.

        Candidates are tried in ``ids_lookup`` order (longest first when added via
        register_embedding). Returns ``(None, None)`` when nothing matches.
        """
        possible_matches = self.ids_lookup.get(tokens[offset], None)
        if possible_matches is None:
            return None, None
        for ids, embedding in possible_matches:
            if tokens[offset:offset + len(ids)] == ids:
                return embedding, len(ids)
        return None, None