Spaces:
Paused
Paused
import argparse | |
import os | |
# add project root to sys path | |
import sys | |
from tqdm import tqdm | |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
import torch | |
from diffusers.loaders import LoraLoaderMixin | |
from safetensors.torch import load_file | |
from collections import OrderedDict | |
import json | |
from toolkit.config_modules import ModelConfig | |
from toolkit.paths import KEYMAPS_ROOT | |
from toolkit.saving import convert_state_dict_to_ldm_with_mapping, get_ldm_state_dict_from_diffusers | |
from toolkit.stable_diffusion_model import StableDiffusion | |
# this was just used to match the vae keys to the diffusers keys | |
# you probably won't need this. Unless they change them.... again... again
# on second thought, you probably will | |
# Project root: the parent of the directory that contains this script.
_script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(_script_dir)

# The comparison runs on the CPU with 32-bit floats.
device = torch.device('cpu')
dtype = torch.float32
# Command-line interface: one or more positional model paths (this script only
# reads the first one) plus two flags that select the model family.
parser = argparse.ArgumentParser(
    description=(
        'Load an LDM checkpoint, convert it to the diffusers format and back, '
        'then compare the resulting keys and values with the original file.'
    )
)
# require at least one model file
parser.add_argument(
    'file_1',
    nargs='+',
    type=str,
    help='Path to an LDM model'
)
parser.add_argument(
    '--is_xl',
    action='store_true',
    help='Is the model an XL model'
)
parser.add_argument(
    '--is_v2',
    action='store_true',
    help='Is the model a v2 model'
)
args = parser.parse_args()

# NOTE(review): this flag is not read anywhere in the visible script.
find_matches = False
# Read the checkpoint file directly; its keys are the reference key set that
# the round-tripped state dict is compared against later.
print("Loading model")
checkpoint_path = args.file_1[0]
state_dict_file_1 = load_file(checkpoint_path)
state_dict_1_keys = list(state_dict_file_1)

# Load the same checkpoint through the toolkit's StableDiffusion wrapper.
# NOTE(review): only --is_xl is forwarded to ModelConfig; --is_v2 is not.
# Confirm whether ModelConfig needs it when a v2 checkpoint is tested.
print("Loading model into diffusers format")
model_config = ModelConfig(name_or_path=checkpoint_path, is_xl=args.is_xl)
sd = StableDiffusion(model_config=model_config, device=device)
sd.load_model()
# Keymap asset locations.
# NOTE(review): both paths always point at the SDXL keymap files, whatever
# flags were given, and neither is read in the visible part of the script.
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')

print("Converting model back to LDM")

# --is_xl wins over --is_v2; with neither flag the version string is '1'.
version_string = 'sdxl' if args.is_xl else ('2' if args.is_v2 else '1')

# convert the diffusers state dict back into an LDM-style state dict
state_dict_file_2 = get_ldm_state_dict_from_diffusers(
    sd.state_dict(), version_string, device='cpu', dtype=dtype
)
# Keys of the round-tripped (diffusers -> LDM) state dict.
state_dict_2_keys = list(state_dict_file_2.keys())

# Compare through sets: the previous list-membership tests inside loops were
# quadratic in the number of keys. State-dict keys are unique, so the set
# operations yield exactly the same sorted lists.
_keys_1 = set(state_dict_1_keys)
_keys_2 = set(state_dict_2_keys)

# Keys only in the original file, only in the round-tripped dict, and in both.
keys_not_in_state_dict_2 = sorted(_keys_1 - _keys_2)
keys_not_in_state_dict_1 = sorted(_keys_2 - _keys_1)
keys_in_both = sorted(_keys_1 & _keys_2)
# Report the comparison. When both key sets are identical the tensors are
# compared value by value; otherwise only the key differences are reported.
# NOTE(review): the nesting of this section was reconstructed from a
# whitespace-flattened source. Confirm that the exit belongs to the
# value-mismatch branch and that the key dumps below run on every other path.
config_dir = os.path.join(project_root, 'config')
# open(..., 'w') does not create missing directories, so make sure it exists.
os.makedirs(config_dir, exist_ok=True)

if not keys_not_in_state_dict_2 and not keys_not_in_state_dict_1:
    print("All keys match!")
    print("Checking values...")
    mismatch_keys = []
    loss = torch.nn.MSELoss()
    tolerance = 1e-6
    for key in tqdm(keys_in_both):
        # Compute the loss once and reuse it for both the test and the report.
        key_loss = loss(state_dict_file_1[key], state_dict_file_2[key])
        if key_loss > tolerance:
            print(f"Values for key {key} don't match!")
            print(f"Loss: {key_loss}")
            mismatch_keys.append(key)

    if not mismatch_keys:
        print("All values match!")
    else:
        print("Some values don't match!")
        print(mismatch_keys)
        mismatched_path = os.path.join(config_dir, 'mismatch.json')
        with open(mismatched_path, 'w') as f:
            json.dump(mismatch_keys, f, indent=4)
        # sys.exit is the supported way to stop a program; the bare exit()
        # helper is only injected by the site module for interactive use.
        # The exit status stays 0, as before.
        sys.exit(0)
else:
    print("Keys don't match!, generating info...")

# Summary of the key comparison, serialised once and written at the end.
json_data = json.dumps(
    {
        "both": keys_in_both,
        "not_in_state_dict_2": keys_not_in_state_dict_2,
        "not_in_state_dict_1": keys_not_in_state_dict_1,
    },
    indent=4,
)

# Tensors that exist on one side only, in sorted key order.
# NOTE(review): neither mapping is read in the visible part of the script.
remaining_diffusers_values = OrderedDict(
    (key, state_dict_file_2[key]) for key in keys_not_in_state_dict_1
)
remaining_ldm_values = OrderedDict(
    (key, state_dict_file_1[key]) for key in keys_not_in_state_dict_2
)

json_save_path = os.path.join(config_dir, 'keys.json')
# NOTE(review): the next two paths are defined but not written to here.
json_matched_save_path = os.path.join(config_dir, 'matched.json')
json_duped_save_path = os.path.join(config_dir, 'duped.json')

state_dict_1_filename = os.path.basename(args.file_1[0])

# save key names for each in own file
with open(os.path.join(config_dir, f'{state_dict_1_filename}.json'), 'w') as f:
    json.dump(state_dict_1_keys, f, indent=4)

with open(os.path.join(config_dir, f'{state_dict_1_filename}_loop.json'), 'w') as f:
    json.dump(state_dict_2_keys, f, indent=4)

with open(json_save_path, 'w') as f:
    f.write(json_data)