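"""
Convert a LoRA + textual-inversion embedding produced by a cog/Replicate SDXL
trainer (diffusers key naming) into kohya-compatible safetensors files. The
unet key translation is driven by the keymap JSON in toolkit/keymaps.
"""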
import json
import os
from collections import OrderedDict

import torch
from safetensors import safe_open
from safetensors.torch import save_file

device = torch.device('cpu')
# [diffusers] -> kohya text-encoder names
embedding_mapping = {
    'text_encoders_0': 'clip_l',
    'text_encoders_1': 'clip_g'
}
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
KEYMAP_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps')
sdxl_keymap_path = os.path.join(KEYMAP_ROOT, 'stable_diffusion_locon_sdxl.json')
# load the ldm (kohya) -> diffusers keymap
with open(sdxl_keymap_path, 'r') as f:
    ldm_diffusers_keymap = json.load(f)['ldm_diffusers_keymap']

# invert the key/value pairs so we can look up ldm keys from diffusers keys
diffusers_ldm_keymap = {v: k for k, v in ldm_diffusers_keymap.items()}

def get_ldm_key(diffuser_key):
    # normalize a diffusers lora key into the form used by the keymap,
    # then translate it to the ldm (kohya) key
    diffuser_key = f"lora_unet_{diffuser_key.replace('.', '_')}"
    diffuser_key = diffuser_key.replace('_lora_down_weight', '.lora_down.weight')
    diffuser_key = diffuser_key.replace('_lora_up_weight', '.lora_up.weight')
    diffuser_key = diffuser_key.replace('_alpha', '.alpha')
    diffuser_key = diffuser_key.replace('_processor_to_', '_to_')
    diffuser_key = diffuser_key.replace('_to_out.', '_to_out_0.')
    if diffuser_key in diffusers_ldm_keymap:
        return diffusers_ldm_keymap[diffuser_key]
    else:
        raise KeyError(f"Key {diffuser_key} not found in keymap")
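
# Illustrative trace (the key name is hypothetical; the final mapping comes
# from the keymap JSON): a diffusers key like
#   'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight'
# is first normalized to
#   'lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight'
# and then looked up in diffusers_ldm_keymap.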

def convert_cog(lora_path, embedding_path):
    embedding_state_dict = OrderedDict()
    lora_state_dict = OrderedDict()
    # remap the diffusers text-encoder names to the kohya ones (clip_l / clip_g)
    with safe_open(embedding_path, framework="pt", device='cpu') as f:
        keys = list(f.keys())
        for key in keys:
            new_key = embedding_mapping[key]
            embedding_state_dict[new_key] = f.get_tensor(key)
    with safe_open(lora_path, framework="pt", device='cpu') as f:
        keys = list(f.keys())
        # determine the lora rank first; check the first 3 linear layers just to be safe
        lora_rank = None
        num_checked = 0
        for key in keys:
            tensor = f.get_tensor(key)
            if len(tensor.shape) == 2:
                # assume the smaller dim of a 2-dim weight is the lora rank
                this_dim = min(tensor.shape)
                if lora_rank is None:
                    lora_rank = this_dim
                elif lora_rank != this_dim:
                    raise ValueError(f"lora rank is not consistent, got {tensor.shape}")
                num_checked += 1
                if num_checked >= 3:
                    break
        if lora_rank is None:
            raise ValueError("could not determine lora rank: no 2-dim weights found")
        for key in keys:
            new_key = get_ldm_key(key)
            tensor = f.get_tensor(key)
            if new_key.endswith('.lora_down.weight'):
                # diffusers checkpoints do not store alpha; they effectively use an
                # alpha multiplier of 1, which corresponds to alpha == lora rank,
                # so write a 1-element tensor holding the rank
                alpha_key = new_key.replace('.lora_down.weight', '.alpha')
                lora_state_dict[alpha_key] = torch.ones(1).to(tensor.device, tensor.dtype) * lora_rank
            lora_state_dict[new_key] = tensor

    return lora_state_dict, embedding_state_dict
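
# Example programmatic use (file names are illustrative):
#   lora_sd, emb_sd = convert_cog('lora.safetensors', 'embeddings.pti')
#   save_file(lora_sd, 'lora_kohya.safetensors')
#   save_file(emb_sd, 'embeddings_kohya.safetensors')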

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        'lora_path',
        type=str,
        help='Path to lora file'
    )
    parser.add_argument(
        'embedding_path',
        type=str,
        help='Path to embedding file'
    )
    parser.add_argument(
        '--lora_output',
        type=str,
        default="lora_output",
        help='Path to write the converted lora safetensors file'
    )
    parser.add_argument(
        '--embedding_output',
        type=str,
        default="embedding_output",
        help='Path to write the converted embedding safetensors file'
    )
    args = parser.parse_args()

    lora_state_dict, embedding_state_dict = convert_cog(args.lora_path, args.embedding_path)

    # save the converted state dicts
    save_file(lora_state_dict, args.lora_output)
    save_file(embedding_state_dict, args.embedding_output)
    print(f"Saved lora to {args.lora_output}")
    print(f"Saved embedding to {args.embedding_output}")