Spaces:
Paused
Paused
import argparse | |
import os | |
import torch | |
from diffusers.loaders import LoraLoaderMixin | |
from safetensors.torch import load_file | |
from collections import OrderedDict | |
import json | |
# This was just used to match the VAE keys to the diffusers keys.
# You probably won't need this. Unless they change them.... again... again
# On second thought, you probably will.
# Device and dtype constants; nothing in the visible part of the script reads them.
device = torch.device('cpu')
dtype = torch.float32

# Both positionals take the same argparse options. nargs='+' means each one
# requires at least one path and is parsed into a list of strings.
_path_list_options = {'nargs': '+', 'type': str}

parser = argparse.ArgumentParser()
parser.add_argument('file_1', help='Path to first safe tensor file', **_path_list_options)
parser.add_argument('file_2', help='Path to second safe tensor file', **_path_list_options)

args = parser.parse_args()
# Flag left switched off; nothing in the visible part of the script reads it.
find_matches = False

# Only the first path given for each positional is loaded; any additional
# paths supplied on the command line are ignored.
state_dict_file_1 = load_file(args.file_1[0])
state_dict_file_2 = load_file(args.file_2[0])

# Key names of each file, in the order the loaded mapping yields them.
state_dict_1_keys = [*state_dict_file_1]
state_dict_2_keys = [*state_dict_file_2]
# Compare the key names of the two state dicts.
#
# Membership is resolved with sets, so the comparison stays linear in the
# number of keys instead of rescanning a list once per key. The key lists come
# from mapping keys and are therefore free of duplicates, so set arithmetic
# followed by sorted() yields exactly the same sorted lists as appending the
# matches one by one and sorting afterwards.
_key_set_1 = set(state_dict_1_keys)
_key_set_2 = set(state_dict_2_keys)

# Keys that appear in file 1 only.
keys_not_in_state_dict_2 = sorted(_key_set_1 - _key_set_2)
# Keys that appear in file 2 only.
keys_not_in_state_dict_1 = sorted(_key_set_2 - _key_set_1)
# Keys that appear in both files.
keys_in_both = sorted(_key_set_1 & _key_set_2)
# Serialise the three sorted key lists as a single pretty-printed JSON string.
json_data = json.dumps(
    {
        "both": keys_in_both,
        "not_in_state_dict_2": keys_not_in_state_dict_2,
        "not_in_state_dict_1": keys_not_in_state_dict_1,
    },
    indent=4,
)
# Values stored under the keys that only the second file has, in sorted key order.
remaining_diffusers_values = OrderedDict(
    (name, state_dict_file_2[name]) for name in keys_not_in_state_dict_1
)
# Debug aid: print(remaining_diffusers_values.keys())

# Values stored under the keys that only the first file has, in sorted key order.
remaining_ldm_values = OrderedDict(
    (name, state_dict_file_1[name]) for name in keys_not_in_state_dict_2
)
# Debug aid: print(json_data)
# All output goes to <project_root>/config, where the project root is the
# directory two levels above this script.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
config_dir = os.path.join(project_root, 'config')
# Create the folder up front so the open() calls below do not fail with
# FileNotFoundError when the config directory does not exist yet.
os.makedirs(config_dir, exist_ok=True)

json_save_path = os.path.join(config_dir, 'keys.json')
# Additional output paths; nothing in the visible part of the script writes them.
json_matched_save_path = os.path.join(config_dir, 'matched.json')
json_duped_save_path = os.path.join(config_dir, 'duped.json')

state_dict_1_filename = os.path.basename(args.file_1[0])
state_dict_2_filename = os.path.basename(args.file_2[0])

# Save the key names of each input in its own file, named after the input.
# NOTE(review): two inputs that share a basename map to the same output path,
# so the second key list replaces the first.
with open(os.path.join(config_dir, f'{state_dict_1_filename}.json'), 'w') as f:
    json.dump(state_dict_1_keys, f, indent=4)
with open(os.path.join(config_dir, f'{state_dict_2_filename}.json'), 'w') as f:
    json.dump(state_dict_2_keys, f, indent=4)

# Save the combined comparison report.
with open(json_save_path, 'w') as f:
    f.write(json_data)