John6666 commited on
Commit
57b53d4
·
verified ·
1 Parent(s): d47d03d

Upload 12 files

Browse files
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Convert diffusers SDXL repo to single Safetensors V2
3
  emoji: 🐶
4
  colorFrom: yellow
5
  colorTo: red
 
1
  ---
2
+ title: Convert HF Diffusers repo to single safetensors file V2 (for SDXL / SD 1.5 / LoRA)
3
  emoji: 🐶
4
  colorFrom: yellow
5
  colorTo: red
app.py CHANGED
@@ -3,30 +3,37 @@ import os
3
  from convert_repo_to_safetensors_gr import convert_repo_to_safetensors_multi, clear_safetensors
4
  os.environ['HF_OUTPUT_REPO'] = 'John6666/safetensors_converting_test'
5
 
6
- css = """"""
 
 
 
7
 
8
  with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css, delete_cache=(60, 3600)) as demo:
9
- gr.Markdown(
10
- f"""
11
- - [A CLI version of this tool is available here](https://huggingface.co/spaces/John6666/convert_repo_to_safetensors/tree/main/local).
12
- """)
13
  with gr.Column():
14
  repo_id = gr.Textbox(label="Repo ID", placeholder="author/model", value="", lines=1)
15
  is_upload = gr.Checkbox(label="Upload safetensors to HF Repo", info="Fast download, but files will be public.", value=False)
16
  with gr.Accordion("Advanced", open=False):
17
- dtype = gr.Radio(label="Output data type", choices=["fp16", "fp32", "bf16", "default"], value="fp16")
18
- with gr.Row():
19
- hf_token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1)
20
- gr.Markdown("Your token is available at [hf.co/settings/tokens](https://huggingface.co/settings/tokens).")
21
- with gr.Row():
22
- newrepo_id = gr.Textbox(label="Upload repo ID", placeholder="yourid/newrepo", value="", max_lines=1)
23
- newrepo_type = gr.Radio(label="Upload repo type", choices=["model", "dataset"], value="model")
24
- is_private = gr.Checkbox(label="Create / Use private repo", value=True)
 
25
  uploaded_urls = gr.CheckboxGroup(visible=False, choices=[], value=None) # hidden
26
  run_button = gr.Button(value="Convert")
27
  st_file = gr.Files(label="Output", interactive=False)
28
  st_md = gr.Markdown()
29
  delete_button = gr.Button(value="Delete Safetensors")
 
 
 
 
 
 
30
 
31
  gr.on(
32
  triggers=[repo_id.submit, run_button.click],
 
3
  from convert_repo_to_safetensors_gr import convert_repo_to_safetensors_multi, clear_safetensors
4
  os.environ['HF_OUTPUT_REPO'] = 'John6666/safetensors_converting_test'
5
 
6
+ css = """
7
+ .title { text-align: center; !important; }
8
+ .footer { text-align: center; !important; }
9
+ """
10
 
11
  with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css, delete_cache=(60, 3600)) as demo:
12
+ gr.Markdown("# HF Diffusers repo to WebUI/ComfyUI single safetensors file converter (for SDXL / SD 1.5 / LoRA)", elem_classes="title")
 
 
 
13
  with gr.Column():
14
  repo_id = gr.Textbox(label="Repo ID", placeholder="author/model", value="", lines=1)
15
  is_upload = gr.Checkbox(label="Upload safetensors to HF Repo", info="Fast download, but files will be public.", value=False)
16
  with gr.Accordion("Advanced", open=False):
17
+ dtype = gr.Radio(label="Output data type", choices=["fp16", "fp32", "bf16", "fp8", "default"], value="fp16")
18
+ with gr.Accordion("Upload to your repo", open=True):
19
+ with gr.Row():
20
+ hf_token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1)
21
+ gr.Markdown("Your token is available at [hf.co/settings/tokens](https://huggingface.co/settings/tokens).")
22
+ with gr.Row():
23
+ newrepo_id = gr.Textbox(label="Upload repo ID", placeholder="yourid/newrepo", value="", max_lines=1)
24
+ newrepo_type = gr.Radio(label="Upload repo type", choices=["model", "dataset"], value="model")
25
+ is_private = gr.Checkbox(label="Create private repo", value=True)
26
  uploaded_urls = gr.CheckboxGroup(visible=False, choices=[], value=None) # hidden
27
  run_button = gr.Button(value="Convert")
28
  st_file = gr.Files(label="Output", interactive=False)
29
  st_md = gr.Markdown()
30
  delete_button = gr.Button(value="Delete Safetensors")
31
+ gr.DuplicateButton(value="Duplicate Space")
32
+ gr.Markdown(
33
+ f"""
34
+ - Thanks to [xi0v](https://huggingface.co/xi0v)
35
+ - [A CLI version of this tool is available here](https://huggingface.co/spaces/John6666/convert_repo_to_safetensors/tree/main/local).
36
+ """, elem_classes="footer")
37
 
38
  gr.on(
39
  triggers=[repo_id.submit, run_button.click],
convert_repo_to_safetensors_gr.py CHANGED
@@ -15,8 +15,9 @@ import os
15
  from pathlib import Path
16
  import shutil
17
  import gc
18
- from utils import get_token, set_token, is_repo_exists
19
-
 
20
 
21
  # =================#
22
  # UNet Conversion #
@@ -336,6 +337,7 @@ def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16",
336
  if dtype == "fp16": state_dict = {k: v.half() for k, v in state_dict.items()}
337
  elif dtype == "fp32": state_dict = {k: v.to(torch.float32) for k, v in state_dict.items()}
338
  elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
 
339
 
340
  save_file(state_dict, checkpoint_path)
341
 
@@ -387,17 +389,32 @@ def convert_repo_to_safetensors_multi(repo_id, hf_token, files, urls, dtype="fp1
387
  else: set_token(os.environ.get("HF_TOKEN"))
388
  if is_upload and newrepo_id and not hf_token: raise gr.Error("HF write token is required for this process.")
389
  if not newrepo_id: newrepo_id = os.environ.get("HF_OUTPUT_REPO")
390
- file = convert_repo_to_safetensors(repo_id, dtype)
 
 
 
 
 
 
 
 
 
 
391
  if not urls: urls = []
 
392
  url = ""
393
  if is_upload:
394
  url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
395
- if url: urls.append(url)
 
 
 
 
 
 
396
  md = ""
397
  for u in urls:
398
  md += f"[Download {str(u).split('/')[-1]}]({str(u)})<br>"
399
- if not files: files = []
400
- files.append(file)
401
  gc.collect()
402
  return gr.update(value=files), gr.update(value=urls, choices=urls), gr.update(value=md)
403
 
@@ -414,7 +431,7 @@ if __name__ == "__main__":
414
  parser = argparse.ArgumentParser()
415
 
416
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
417
- parser.add_argument("--dtype", default="fp16", type=str, choices=["fp16", "fp32", "bf16", "default"], help='Output data type. (Default: "fp16")')
418
 
419
  args = parser.parse_args()
420
  assert args.repo_id is not None, "Must provide a Repo ID!"
 
15
  from pathlib import Path
16
  import shutil
17
  import gc
18
+ from utils import get_token, set_token, is_repo_exists, get_model_type
19
+ from convert_repo_to_safetensors_sd_gr import convert_repo_to_safetensors as convert_repo_to_safetensors_sd
20
+ from convert_repo_to_safetensors_sdxl_lora_gr import convert_repo_to_safetensors_sdxl_lora
21
 
22
  # =================#
23
  # UNet Conversion #
 
337
  if dtype == "fp16": state_dict = {k: v.half() for k, v in state_dict.items()}
338
  elif dtype == "fp32": state_dict = {k: v.to(torch.float32) for k, v in state_dict.items()}
339
  elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
340
+ elif dtype == "fp8": state_dict = {k: v.to(torch.float8_e4m3fn) for k, v in state_dict.items()}
341
 
342
  save_file(state_dict, checkpoint_path)
343
 
 
389
  else: set_token(os.environ.get("HF_TOKEN"))
390
  if is_upload and newrepo_id and not hf_token: raise gr.Error("HF write token is required for this process.")
391
  if not newrepo_id: newrepo_id = os.environ.get("HF_OUTPUT_REPO")
392
+ model_type = get_model_type(repo_id)
393
+ if model_type == "SDXL":
394
+ gr.Info(f"Converting {model_type} model.")
395
+ file = convert_repo_to_safetensors(repo_id, dtype)
396
+ elif model_type == "SD 1.5":
397
+ gr.Info(f"Converting {model_type} model.")
398
+ file = convert_repo_to_safetensors_sd(repo_id, dtype)
399
+ elif model_type == "LoRA":
400
+ gr.Info(f"Converting {model_type}.")
401
+ file = convert_repo_to_safetensors_sdxl_lora(repo_id)
402
+ else: raise gr.Error(f"Unsupported model type: {model_type}")
403
  if not urls: urls = []
404
+ if not files: files = []
405
  url = ""
406
  if is_upload:
407
  url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
408
+ if url:
409
+ urls.append(url)
410
+ Path(file).unlink()
411
+ else: files.append(file)
412
+ else:
413
+ files.append(file)
414
+ progress(1, desc="Processing...")
415
  md = ""
416
  for u in urls:
417
  md += f"[Download {str(u).split('/')[-1]}]({str(u)})<br>"
 
 
418
  gc.collect()
419
  return gr.update(value=files), gr.update(value=urls, choices=urls), gr.update(value=md)
420
 
 
431
  parser = argparse.ArgumentParser()
432
 
433
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
434
+ parser.add_argument("--dtype", default="fp16", type=str, choices=["fp16", "fp32", "bf16", "fp8", "default"], help='Output data type. (Default: "fp16")')
435
 
436
  args = parser.parse_args()
437
  assert args.repo_id is not None, "Must provide a Repo ID!"
convert_repo_to_safetensors_sd_gr.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2
+ # *Only* converts the UNet, VAE, and Text Encoder.
3
+ # Does not convert optimizer state or any other thing.
4
+
5
+ import argparse
6
+ import os.path as osp
7
+ import re
8
+
9
+ import torch
10
+ from safetensors.torch import load_file, save_file
11
+ import gradio as gr
12
+
13
+ from huggingface_hub import HfApi, HfFolder, hf_hub_url, snapshot_download
14
+ import os
15
+ from pathlib import Path
16
+ import shutil
17
+ import gc
18
+ from utils import get_token, set_token, is_repo_exists
19
+
20
+ # =================#
21
+ # UNet Conversion #
22
+ # =================#
23
+
24
+ unet_conversion_map = [
25
+ # (stable-diffusion, HF Diffusers)
26
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
27
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
28
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
29
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
30
+ ("input_blocks.0.0.weight", "conv_in.weight"),
31
+ ("input_blocks.0.0.bias", "conv_in.bias"),
32
+ ("out.0.weight", "conv_norm_out.weight"),
33
+ ("out.0.bias", "conv_norm_out.bias"),
34
+ ("out.2.weight", "conv_out.weight"),
35
+ ("out.2.bias", "conv_out.bias"),
36
+ ]
37
+
38
+ unet_conversion_map_resnet = [
39
+ # (stable-diffusion, HF Diffusers)
40
+ ("in_layers.0", "norm1"),
41
+ ("in_layers.2", "conv1"),
42
+ ("out_layers.0", "norm2"),
43
+ ("out_layers.3", "conv2"),
44
+ ("emb_layers.1", "time_emb_proj"),
45
+ ("skip_connection", "conv_shortcut"),
46
+ ]
47
+
48
+ unet_conversion_map_layer = []
49
+ # hardcoded number of downblocks and resnets/attentions...
50
+ # would need smarter logic for other networks.
51
+ for i in range(4):
52
+ # loop over downblocks/upblocks
53
+
54
+ for j in range(2):
55
+ # loop over resnets/attentions for downblocks
56
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
57
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
58
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
59
+
60
+ if i < 3:
61
+ # no attention layers in down_blocks.3
62
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
63
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
64
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
65
+
66
+ for j in range(3):
67
+ # loop over resnets/attentions for upblocks
68
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
69
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
70
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
71
+
72
+ if i > 0:
73
+ # no attention layers in up_blocks.0
74
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
75
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
76
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
77
+
78
+ if i < 3:
79
+ # no downsample in down_blocks.3
80
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
81
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
82
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
83
+
84
+ # no upsample in up_blocks.3
85
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
86
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
87
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
88
+
89
+ hf_mid_atn_prefix = "mid_block.attentions.0."
90
+ sd_mid_atn_prefix = "middle_block.1."
91
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
92
+
93
+ for j in range(2):
94
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
95
+ sd_mid_res_prefix = f"middle_block.{2*j}."
96
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
97
+
98
+
99
+ def convert_unet_state_dict(unet_state_dict):
100
+ # buyer beware: this is a *brittle* function,
101
+ # and correct output requires that all of these pieces interact in
102
+ # the exact order in which I have arranged them.
103
+ mapping = {k: k for k in unet_state_dict.keys()}
104
+ for sd_name, hf_name in unet_conversion_map:
105
+ mapping[hf_name] = sd_name
106
+ for k, v in mapping.items():
107
+ if "resnets" in k:
108
+ for sd_part, hf_part in unet_conversion_map_resnet:
109
+ v = v.replace(hf_part, sd_part)
110
+ mapping[k] = v
111
+ for k, v in mapping.items():
112
+ for sd_part, hf_part in unet_conversion_map_layer:
113
+ v = v.replace(hf_part, sd_part)
114
+ mapping[k] = v
115
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
116
+ return new_state_dict
117
+
118
+
119
+ # ================#
120
+ # VAE Conversion #
121
+ # ================#
122
+
123
+ vae_conversion_map = [
124
+ # (stable-diffusion, HF Diffusers)
125
+ ("nin_shortcut", "conv_shortcut"),
126
+ ("norm_out", "conv_norm_out"),
127
+ ("mid.attn_1.", "mid_block.attentions.0."),
128
+ ]
129
+
130
+ for i in range(4):
131
+ # down_blocks have two resnets
132
+ for j in range(2):
133
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
134
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
135
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
136
+
137
+ if i < 3:
138
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
139
+ sd_downsample_prefix = f"down.{i}.downsample."
140
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
141
+
142
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
143
+ sd_upsample_prefix = f"up.{3-i}.upsample."
144
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
145
+
146
+ # up_blocks have three resnets
147
+ # also, up blocks in hf are numbered in reverse from sd
148
+ for j in range(3):
149
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
150
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
151
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
152
+
153
+ # this part accounts for mid blocks in both the encoder and the decoder
154
+ for i in range(2):
155
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
156
+ sd_mid_res_prefix = f"mid.block_{i+1}."
157
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
158
+
159
+
160
+ vae_conversion_map_attn = [
161
+ # (stable-diffusion, HF Diffusers)
162
+ ("norm.", "group_norm."),
163
+ ("q.", "query."),
164
+ ("k.", "key."),
165
+ ("v.", "value."),
166
+ ("proj_out.", "proj_attn."),
167
+ ]
168
+
169
+ # This is probably not the most ideal solution, but it does work.
170
+ vae_extra_conversion_map = [
171
+ ("to_q", "q"),
172
+ ("to_k", "k"),
173
+ ("to_v", "v"),
174
+ ("to_out.0", "proj_out"),
175
+ ]
176
+
177
+
178
+ def reshape_weight_for_sd(w):
179
+ # convert HF linear weights to SD conv2d weights
180
+ if not w.ndim == 1:
181
+ return w.reshape(*w.shape, 1, 1)
182
+ else:
183
+ return w
184
+
185
+
186
+ def convert_vae_state_dict(vae_state_dict):
187
+ mapping = {k: k for k in vae_state_dict.keys()}
188
+ for k, v in mapping.items():
189
+ for sd_part, hf_part in vae_conversion_map:
190
+ v = v.replace(hf_part, sd_part)
191
+ mapping[k] = v
192
+ for k, v in mapping.items():
193
+ if "attentions" in k:
194
+ for sd_part, hf_part in vae_conversion_map_attn:
195
+ v = v.replace(hf_part, sd_part)
196
+ mapping[k] = v
197
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
198
+ weights_to_convert = ["q", "k", "v", "proj_out"]
199
+ keys_to_rename = {}
200
+ for k, v in new_state_dict.items():
201
+ for weight_name in weights_to_convert:
202
+ if f"mid.attn_1.{weight_name}.weight" in k:
203
+ print(f"Reshaping {k} for SD format")
204
+ new_state_dict[k] = reshape_weight_for_sd(v)
205
+ for weight_name, real_weight_name in vae_extra_conversion_map:
206
+ if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
207
+ keys_to_rename[k] = k.replace(weight_name, real_weight_name)
208
+ for k, v in keys_to_rename.items():
209
+ if k in new_state_dict:
210
+ print(f"Renaming {k} to {v}")
211
+ new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
212
+ del new_state_dict[k]
213
+ return new_state_dict
214
+
215
+
216
+ # =========================#
217
+ # Text Encoder Conversion #
218
+ # =========================#
219
+
220
+
221
+ textenc_conversion_lst = [
222
+ # (stable-diffusion, HF Diffusers)
223
+ ("resblocks.", "text_model.encoder.layers."),
224
+ ("ln_1", "layer_norm1"),
225
+ ("ln_2", "layer_norm2"),
226
+ (".c_fc.", ".fc1."),
227
+ (".c_proj.", ".fc2."),
228
+ (".attn", ".self_attn"),
229
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
230
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
231
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
232
+ ]
233
+ protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
234
+ textenc_pattern = re.compile("|".join(protected.keys()))
235
+
236
+ # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
237
+ code2idx = {"q": 0, "k": 1, "v": 2}
238
+
239
+
240
+ def convert_text_enc_state_dict_v20(text_enc_dict):
241
+ new_state_dict = {}
242
+ capture_qkv_weight = {}
243
+ capture_qkv_bias = {}
244
+ for k, v in text_enc_dict.items():
245
+ if (
246
+ k.endswith(".self_attn.q_proj.weight")
247
+ or k.endswith(".self_attn.k_proj.weight")
248
+ or k.endswith(".self_attn.v_proj.weight")
249
+ ):
250
+ k_pre = k[: -len(".q_proj.weight")]
251
+ k_code = k[-len("q_proj.weight")]
252
+ if k_pre not in capture_qkv_weight:
253
+ capture_qkv_weight[k_pre] = [None, None, None]
254
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
255
+ continue
256
+
257
+ if (
258
+ k.endswith(".self_attn.q_proj.bias")
259
+ or k.endswith(".self_attn.k_proj.bias")
260
+ or k.endswith(".self_attn.v_proj.bias")
261
+ ):
262
+ k_pre = k[: -len(".q_proj.bias")]
263
+ k_code = k[-len("q_proj.bias")]
264
+ if k_pre not in capture_qkv_bias:
265
+ capture_qkv_bias[k_pre] = [None, None, None]
266
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
267
+ continue
268
+
269
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
270
+ new_state_dict[relabelled_key] = v
271
+
272
+ for k_pre, tensors in capture_qkv_weight.items():
273
+ if None in tensors:
274
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
275
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
276
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
277
+
278
+ for k_pre, tensors in capture_qkv_bias.items():
279
+ if None in tensors:
280
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
281
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
282
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
283
+
284
+ return new_state_dict
285
+
286
+
287
+ def convert_text_enc_state_dict(text_enc_dict):
288
+ return text_enc_dict
289
+
290
+
291
+ def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16", progress=gr.Progress(track_tqdm=True)):
292
+ # Path for safetensors
293
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
294
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
295
+ text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
296
+
297
+ # Load models from safetensors if it exists, if it doesn't pytorch
298
+ if osp.exists(unet_path):
299
+ unet_state_dict = load_file(unet_path, device="cpu")
300
+ else:
301
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
302
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
303
+
304
+ if osp.exists(vae_path):
305
+ vae_state_dict = load_file(vae_path, device="cpu")
306
+ else:
307
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
308
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
309
+
310
+ if osp.exists(text_enc_path):
311
+ text_enc_dict = load_file(text_enc_path, device="cpu")
312
+ else:
313
+ text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
314
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
315
+
316
+ # Convert the UNet model
317
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
318
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
319
+
320
+ # Convert the VAE model
321
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
322
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
323
+
324
+ # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
325
+ is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
326
+
327
+ if is_v20_model:
328
+ # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
329
+ text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
330
+ text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
331
+ text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
332
+ else:
333
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
334
+ text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
335
+
336
+ # Put together new checkpoint
337
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
338
+
339
+ if dtype == "fp16": state_dict = {k: v.half() for k, v in state_dict.items()}
340
+ elif dtype == "fp32": state_dict = {k: v.to(torch.float32) for k, v in state_dict.items()}
341
+ elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
342
+ elif dtype == "fp8": state_dict = {k: v.to(torch.float8_e4m3fn) for k, v in state_dict.items()}
343
+
344
+ save_file(state_dict, checkpoint_path)
345
+
346
+
347
+ # https://huggingface.co/docs/huggingface_hub/v0.25.1/en/package_reference/file_download#huggingface_hub.snapshot_download
348
+ def download_repo(repo_id, dir_path):
349
+ hf_token = get_token()
350
+ try:
351
+ snapshot_download(repo_id=repo_id, local_dir=dir_path, token=hf_token, allow_patterns=["*.safetensors", "*.bin"],
352
+ ignore_patterns=["*.fp16.*", "/*.safetensors", "/*.bin"])
353
+ except Exception as e:
354
+ print(f"Error: Failed to download {repo_id}. {e}")
355
+ gr.Warning(f"Error: Failed to download {repo_id}. {e}")
356
+ return
357
+
358
+
359
+ def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
360
+ output_filename = Path(filename).name
361
+ hf_token = get_token()
362
+ api = HfApi(token=hf_token)
363
+ try:
364
+ if not is_repo_exists(repo_id, repo_type): api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
365
+ progress(0, desc="Start uploading...")
366
+ api.upload_file(path_or_fileobj=filename, path_in_repo=output_filename, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
367
+ progress(1, desc="Uploaded.")
368
+ url = hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=output_filename)
369
+ except Exception as e:
370
+ print(f"Error: Failed to upload to {repo_id}. {e}")
371
+ gr.Warning(f"Error: Failed to upload to {repo_id}. {e}")
372
+ return None
373
+ return url
374
+
375
+
376
+ def convert_repo_to_safetensors(repo_id, dtype="fp16", progress=gr.Progress(track_tqdm=True)):
377
+ download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
378
+ output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
379
+ progress(0, desc="Start downloading...")
380
+ download_repo(repo_id, download_dir)
381
+ progress(0, desc="Start converting...")
382
+ convert_diffusers_to_safetensors(download_dir, output_filename, dtype)
383
+ progress(1, desc="Converted.")
384
+ shutil.rmtree(download_dir)
385
+ return output_filename
386
+
387
+
388
+ def convert_repo_to_safetensors_multi_sd(repo_id, hf_token, files, urls, dtype="fp16", is_upload=False,
389
+ newrepo_id="", repo_type="model", is_private=True, progress=gr.Progress(track_tqdm=True)):
390
+ if hf_token: set_token(hf_token)
391
+ else: set_token(os.environ.get("HF_TOKEN"))
392
+ if is_upload and newrepo_id and not hf_token: raise gr.Error("HF write token is required for this process.")
393
+ if not newrepo_id: newrepo_id = os.environ.get("HF_OUTPUT_REPO")
394
+ file = convert_repo_to_safetensors(repo_id, dtype)
395
+ if not urls: urls = []
396
+ url = ""
397
+ if is_upload:
398
+ url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
399
+ if url: urls.append(url)
400
+ progress(1, desc="Processing...")
401
+ md = ""
402
+ for u in urls:
403
+ md += f"[Download {str(u).split('/')[-1]}]({str(u)})<br>"
404
+ if not files: files = []
405
+ files.append(file)
406
+ return gr.update(value=files), gr.update(value=urls, choices=urls), gr.update(value=md)
407
+
408
+
409
+ def clear_safetensors():
410
+ for p in Path('.').glob('*.safetensors'):
411
+ p.unlink()
412
+ print("Deleted.")
413
+ gc.collect()
414
+ return gr.update(value=[])
415
+
416
+
417
+ if __name__ == "__main__":
418
+ parser = argparse.ArgumentParser()
419
+
420
+ parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
421
+ parser.add_argument("--dtype", default="fp16", type=str, choices=["fp16", "fp32", "bf16", "fp8", "default"], help='Output data type. (Default: "fp16")')
422
+
423
+ args = parser.parse_args()
424
+ assert args.repo_id is not None, "Must provide a Repo ID!"
425
+
426
+ convert_repo_to_safetensors(args.repo_id, args.dtype)
convert_repo_to_safetensors_sdxl_lora_gr.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Script for converting a Hugging Face Diffusers trained SDXL LoRAs to Kohya format
2
+ # This means that you can input your diffusers-trained LoRAs and
3
+ # Get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next and others.
4
+
5
+ # To get started you can find some cool `diffusers` trained LoRAs such as this cute Corgy
6
+ # https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file
7
+ # and run the script:
8
+ # python convert_diffusers_sdxl_lora_to_webui.py --input_lora pytorch_lora_weights.safetensors --output_lora corgy.safetensors
9
+ # now you can use corgy.safetensors in your WebUI of choice!
10
+
11
+ # To train your own, here are some diffusers training scripts and utils that you can use and then convert:
12
+ # LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease
13
+ # Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer:
14
+ # - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
15
+ # - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb
16
+ # Canonical diffusers training scripts:
17
+ # - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
18
+ # - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb
19
+
20
+ import argparse
21
+ import os
22
+
23
+ from safetensors.torch import load_file, save_file
24
+ from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
25
+ from pathlib import Path
26
+ import gradio as gr
27
+
28
+ from huggingface_hub import hf_hub_download, HfApi
29
+ from huggingface_hub import HfApi, HfFolder, hf_hub_url, snapshot_download
30
+ import os
31
+ from pathlib import Path
32
+ import shutil
33
+ import gc
34
+ from utils import get_token, set_token, is_repo_exists, get_model_type
35
+
36
+ def convert_and_save(input_lora, output_lora=None):
37
+ if output_lora is None:
38
+ base_name = os.path.splitext(input_lora)[0]
39
+ output_lora = f"{base_name}_webui.safetensors"
40
+
41
+ diffusers_state_dict = load_file(input_lora)
42
+ try:
43
+ peft_state_dict = convert_all_state_dict_to_peft(diffusers_state_dict)
44
+ kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
45
+ except Exception: # skipped
46
+ kohya_state_dict = diffusers_state_dict
47
+ save_file(kohya_state_dict, output_lora)
48
+
49
+
50
+ def download_repo_lora(repo_id, local_file, progress=gr.Progress(track_tqdm=True)):
51
+ hf_token = get_token()
52
+ lora_filename = "pytorch_lora_weights.safetensors"
53
+ lora_path = Path(lora_filename)
54
+ api = HfApi(token=hf_token)
55
+ try:
56
+ if not api.file_exists(repo_id=repo_id, filename=lora_filename, token=hf_token):
57
+ print(f"Error: This repo isn't diffusers LoRA repo: {repo_id}.")
58
+ return None
59
+ if lora_path.exists():
60
+ print(f"Error: Download file already exists: {lora_filename}.")
61
+ return None
62
+ hf_hub_download(repo_id=repo_id, filename=lora_filename, local_dir=".")
63
+ if lora_path.exists(): lora_path.rename(Path(local_file))
64
+ except Exception as e:
65
+ print(f"Error: Failed to download from {repo_id}. {e}")
66
+ return local_file
67
+
68
+
69
+ def convert_repo_to_safetensors_sdxl_lora(repo_id, progress=gr.Progress(track_tqdm=True)):
70
+ download_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}_diffusers.safetensors"
71
+ output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}_webui.safetensors"
72
+ progress(0, desc="Start downloading...")
73
+ download_repo_lora(repo_id, download_filename)
74
+ progress(0, desc="Start converting...")
75
+ convert_and_save(download_filename, output_filename)
76
+ progress(1, desc="Converted.")
77
+ Path(download_filename).unlink()
78
+ return output_filename
79
+
80
+
81
+ def convert_repo_to_safetensors_sdxl_lora_multi(repo_id, files, progress=gr.Progress(track_tqdm=True)):
82
+ file = convert_repo_to_safetensors_sdxl_lora(repo_id)
83
+ if not files: files = []
84
+ files.append(file)
85
+ return gr.update(value=files)
86
+
87
+
88
+ if __name__ == "__main__":
89
+ parser = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format from Repo.")
90
+ parser.add_argument("--repo_id", type=str, required=True, help="URL to the Repo of input LoRA model in the diffusers format.")
91
+
92
+ args = parser.parse_args()
93
+
94
+ convert_repo_to_safetensors_sdxl_lora(args.repo_id)
95
+
96
+
97
+ # Usage: python convert_repo_to_safetensors_sdxl_lora.py --repo_id nroggendorff/zelda-lora
local/convert_repo_to_safetensors_sd.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2
+ # *Only* converts the UNet, VAE, and Text Encoder.
3
+ # Does not convert optimizer state or any other thing.
4
+
5
+ import argparse
6
+ import os.path as osp
7
+ import re
8
+
9
+ import torch
10
+ from safetensors.torch import load_file, save_file
11
+
12
+
13
+ # =================#
14
+ # UNet Conversion #
15
+ # =================#
16
+
17
+ unet_conversion_map = [
18
+ # (stable-diffusion, HF Diffusers)
19
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
20
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
21
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
22
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
23
+ ("input_blocks.0.0.weight", "conv_in.weight"),
24
+ ("input_blocks.0.0.bias", "conv_in.bias"),
25
+ ("out.0.weight", "conv_norm_out.weight"),
26
+ ("out.0.bias", "conv_norm_out.bias"),
27
+ ("out.2.weight", "conv_out.weight"),
28
+ ("out.2.bias", "conv_out.bias"),
29
+ ]
30
+
31
+ unet_conversion_map_resnet = [
32
+ # (stable-diffusion, HF Diffusers)
33
+ ("in_layers.0", "norm1"),
34
+ ("in_layers.2", "conv1"),
35
+ ("out_layers.0", "norm2"),
36
+ ("out_layers.3", "conv2"),
37
+ ("emb_layers.1", "time_emb_proj"),
38
+ ("skip_connection", "conv_shortcut"),
39
+ ]
40
+
41
+ unet_conversion_map_layer = []
42
+ # hardcoded number of downblocks and resnets/attentions...
43
+ # would need smarter logic for other networks.
44
+ for i in range(4):
45
+ # loop over downblocks/upblocks
46
+
47
+ for j in range(2):
48
+ # loop over resnets/attentions for downblocks
49
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
50
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
51
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
52
+
53
+ if i < 3:
54
+ # no attention layers in down_blocks.3
55
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
56
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
57
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
58
+
59
+ for j in range(3):
60
+ # loop over resnets/attentions for upblocks
61
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
62
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
63
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
64
+
65
+ if i > 0:
66
+ # no attention layers in up_blocks.0
67
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
68
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
69
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
70
+
71
+ if i < 3:
72
+ # no downsample in down_blocks.3
73
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
74
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
75
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
76
+
77
+ # no upsample in up_blocks.3
78
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
79
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
80
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
81
+
82
+ hf_mid_atn_prefix = "mid_block.attentions.0."
83
+ sd_mid_atn_prefix = "middle_block.1."
84
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
85
+
86
+ for j in range(2):
87
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
88
+ sd_mid_res_prefix = f"middle_block.{2*j}."
89
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
90
+
91
+
92
+ def convert_unet_state_dict(unet_state_dict):
93
+ # buyer beware: this is a *brittle* function,
94
+ # and correct output requires that all of these pieces interact in
95
+ # the exact order in which I have arranged them.
96
+ mapping = {k: k for k in unet_state_dict.keys()}
97
+ for sd_name, hf_name in unet_conversion_map:
98
+ mapping[hf_name] = sd_name
99
+ for k, v in mapping.items():
100
+ if "resnets" in k:
101
+ for sd_part, hf_part in unet_conversion_map_resnet:
102
+ v = v.replace(hf_part, sd_part)
103
+ mapping[k] = v
104
+ for k, v in mapping.items():
105
+ for sd_part, hf_part in unet_conversion_map_layer:
106
+ v = v.replace(hf_part, sd_part)
107
+ mapping[k] = v
108
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
109
+ return new_state_dict
110
+
111
+
112
+ # ================#
113
+ # VAE Conversion #
114
+ # ================#
115
+
116
+ vae_conversion_map = [
117
+ # (stable-diffusion, HF Diffusers)
118
+ ("nin_shortcut", "conv_shortcut"),
119
+ ("norm_out", "conv_norm_out"),
120
+ ("mid.attn_1.", "mid_block.attentions.0."),
121
+ ]
122
+
123
+ for i in range(4):
124
+ # down_blocks have two resnets
125
+ for j in range(2):
126
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
127
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
128
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
129
+
130
+ if i < 3:
131
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
132
+ sd_downsample_prefix = f"down.{i}.downsample."
133
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
134
+
135
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
136
+ sd_upsample_prefix = f"up.{3-i}.upsample."
137
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
138
+
139
+ # up_blocks have three resnets
140
+ # also, up blocks in hf are numbered in reverse from sd
141
+ for j in range(3):
142
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
143
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
144
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
145
+
146
+ # this part accounts for mid blocks in both the encoder and the decoder
147
+ for i in range(2):
148
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
149
+ sd_mid_res_prefix = f"mid.block_{i+1}."
150
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
151
+
152
+
153
+ vae_conversion_map_attn = [
154
+ # (stable-diffusion, HF Diffusers)
155
+ ("norm.", "group_norm."),
156
+ ("q.", "query."),
157
+ ("k.", "key."),
158
+ ("v.", "value."),
159
+ ("proj_out.", "proj_attn."),
160
+ ]
161
+
162
+ # This is probably not the most ideal solution, but it does work.
163
+ vae_extra_conversion_map = [
164
+ ("to_q", "q"),
165
+ ("to_k", "k"),
166
+ ("to_v", "v"),
167
+ ("to_out.0", "proj_out"),
168
+ ]
169
+
170
+
171
+ def reshape_weight_for_sd(w):
172
+ # convert HF linear weights to SD conv2d weights
173
+ if not w.ndim == 1:
174
+ return w.reshape(*w.shape, 1, 1)
175
+ else:
176
+ return w
177
+
178
+
179
+ def convert_vae_state_dict(vae_state_dict):
180
+ mapping = {k: k for k in vae_state_dict.keys()}
181
+ for k, v in mapping.items():
182
+ for sd_part, hf_part in vae_conversion_map:
183
+ v = v.replace(hf_part, sd_part)
184
+ mapping[k] = v
185
+ for k, v in mapping.items():
186
+ if "attentions" in k:
187
+ for sd_part, hf_part in vae_conversion_map_attn:
188
+ v = v.replace(hf_part, sd_part)
189
+ mapping[k] = v
190
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
191
+ weights_to_convert = ["q", "k", "v", "proj_out"]
192
+ keys_to_rename = {}
193
+ for k, v in new_state_dict.items():
194
+ for weight_name in weights_to_convert:
195
+ if f"mid.attn_1.{weight_name}.weight" in k:
196
+ print(f"Reshaping {k} for SD format")
197
+ new_state_dict[k] = reshape_weight_for_sd(v)
198
+ for weight_name, real_weight_name in vae_extra_conversion_map:
199
+ if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
200
+ keys_to_rename[k] = k.replace(weight_name, real_weight_name)
201
+ for k, v in keys_to_rename.items():
202
+ if k in new_state_dict:
203
+ print(f"Renaming {k} to {v}")
204
+ new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
205
+ del new_state_dict[k]
206
+ return new_state_dict
207
+
208
+
209
+ # =========================#
210
+ # Text Encoder Conversion #
211
+ # =========================#
212
+
213
+
214
+ textenc_conversion_lst = [
215
+ # (stable-diffusion, HF Diffusers)
216
+ ("resblocks.", "text_model.encoder.layers."),
217
+ ("ln_1", "layer_norm1"),
218
+ ("ln_2", "layer_norm2"),
219
+ (".c_fc.", ".fc1."),
220
+ (".c_proj.", ".fc2."),
221
+ (".attn", ".self_attn"),
222
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
223
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
224
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
225
+ ]
226
+ protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
227
+ textenc_pattern = re.compile("|".join(protected.keys()))
228
+
229
+ # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
230
+ code2idx = {"q": 0, "k": 1, "v": 2}
231
+
232
+
233
+ def convert_text_enc_state_dict_v20(text_enc_dict):
234
+ new_state_dict = {}
235
+ capture_qkv_weight = {}
236
+ capture_qkv_bias = {}
237
+ for k, v in text_enc_dict.items():
238
+ if (
239
+ k.endswith(".self_attn.q_proj.weight")
240
+ or k.endswith(".self_attn.k_proj.weight")
241
+ or k.endswith(".self_attn.v_proj.weight")
242
+ ):
243
+ k_pre = k[: -len(".q_proj.weight")]
244
+ k_code = k[-len("q_proj.weight")]
245
+ if k_pre not in capture_qkv_weight:
246
+ capture_qkv_weight[k_pre] = [None, None, None]
247
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
248
+ continue
249
+
250
+ if (
251
+ k.endswith(".self_attn.q_proj.bias")
252
+ or k.endswith(".self_attn.k_proj.bias")
253
+ or k.endswith(".self_attn.v_proj.bias")
254
+ ):
255
+ k_pre = k[: -len(".q_proj.bias")]
256
+ k_code = k[-len("q_proj.bias")]
257
+ if k_pre not in capture_qkv_bias:
258
+ capture_qkv_bias[k_pre] = [None, None, None]
259
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
260
+ continue
261
+
262
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
263
+ new_state_dict[relabelled_key] = v
264
+
265
+ for k_pre, tensors in capture_qkv_weight.items():
266
+ if None in tensors:
267
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
268
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
269
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
270
+
271
+ for k_pre, tensors in capture_qkv_bias.items():
272
+ if None in tensors:
273
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
274
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
275
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
276
+
277
+ return new_state_dict
278
+
279
+
280
+ def convert_text_enc_state_dict(text_enc_dict):
281
+ return text_enc_dict
282
+
283
+
284
+ def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16"):
285
+ # Path for safetensors
286
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
287
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
288
+ text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
289
+
290
+ # Load models from safetensors if it exists, if it doesn't pytorch
291
+ if osp.exists(unet_path):
292
+ unet_state_dict = load_file(unet_path, device="cpu")
293
+ else:
294
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
295
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
296
+
297
+ if osp.exists(vae_path):
298
+ vae_state_dict = load_file(vae_path, device="cpu")
299
+ else:
300
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
301
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
302
+
303
+ if osp.exists(text_enc_path):
304
+ text_enc_dict = load_file(text_enc_path, device="cpu")
305
+ else:
306
+ text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
307
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
308
+
309
+ # Convert the UNet model
310
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
311
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
312
+
313
+ # Convert the VAE model
314
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
315
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
316
+
317
+ # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
318
+ is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
319
+
320
+ if is_v20_model:
321
+ # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
322
+ text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
323
+ text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
324
+ text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
325
+ else:
326
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
327
+ text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
328
+
329
+ # Put together new checkpoint
330
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
331
+
332
+ if dtype == "fp16": state_dict = {k: v.half() for k, v in state_dict.items()}
333
+ elif dtype == "fp32": state_dict = {k: v.to(torch.float32) for k, v in state_dict.items()}
334
+ elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
335
+
336
+ save_file(state_dict, checkpoint_path)
337
+
338
+
339
+ def download_repo(repo_id, dir_path):
340
+ from huggingface_hub import snapshot_download
341
+ try:
342
+ snapshot_download(repo_id=repo_id, local_dir=dir_path)
343
+ except Exception as e:
344
+ print(f"Error: Failed to download {repo_id}. ")
345
+ return
346
+
347
+
348
+ def convert_repo_to_safetensors(repo_id, dtype="fp16"):
349
+ download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
350
+ output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
351
+ download_repo(repo_id, download_dir)
352
+ convert_diffusers_to_safetensors(download_dir, output_filename, dtype)
353
+ return output_filename
354
+
355
+
356
+ if __name__ == "__main__":
357
+ parser = argparse.ArgumentParser()
358
+
359
+ parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
360
+ parser.add_argument("--dtype", default="fp16", type=str, choices=["fp16", "fp32", "bf16", "default"], help='Output data type. (Default: "fp16")')
361
+
362
+ args = parser.parse_args()
363
+ assert args.repo_id is not None, "Must provide a Repo ID!"
364
+
365
+ convert_repo_to_safetensors(args.repo_id, args.dtype)
local/convert_repo_to_safetensors_sdxl_lora.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Script for converting a Hugging Face Diffusers trained SDXL LoRAs to Kohya format
2
+ # This means that you can input your diffusers-trained LoRAs and
3
+ # Get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next and others.
4
+
5
+ # To get started you can find some cool `diffusers` trained LoRAs such as this cute Corgy
6
+ # https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file
7
+ # and run the script:
8
+ # python convert_diffusers_sdxl_lora_to_webui.py --input_lora pytorch_lora_weights.safetensors --output_lora corgy.safetensors
9
+ # now you can use corgy.safetensors in your WebUI of choice!
10
+
11
+ # To train your own, here are some diffusers training scripts and utils that you can use and then convert:
12
+ # LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease
13
+ # Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer:
14
+ # - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
15
+ # - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb
16
+ # Canonical diffusers training scripts:
17
+ # - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
18
+ # - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb
19
+
20
+ import argparse
21
+ import os
22
+
23
+ from safetensors.torch import load_file, save_file
24
+ from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
25
+ from pathlib import Path
26
+
27
+ def convert_and_save(input_lora, output_lora=None):
28
+ if output_lora is None:
29
+ base_name = os.path.splitext(input_lora)[0]
30
+ output_lora = f"{base_name}_webui.safetensors"
31
+
32
+ diffusers_state_dict = load_file(input_lora)
33
+ try:
34
+ peft_state_dict = convert_all_state_dict_to_peft(diffusers_state_dict)
35
+ kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
36
+ except Exception: # skipped
37
+ kohya_state_dict = diffusers_state_dict
38
+ save_file(kohya_state_dict, output_lora)
39
+
40
+
41
+ def download_repo_lora(repo_id, local_file):
42
+ from huggingface_hub import hf_hub_download, HfApi
43
+ lora_filename = "pytorch_lora_weights.safetensors"
44
+ lora_path = Path(lora_filename)
45
+ api = HfApi()
46
+ try:
47
+ if not api.file_exists(repo_id=repo_id, filename=lora_filename):
48
+ print(f"Error: This repo isn't diffusers LoRA repo: {repo_id}. ")
49
+ return None
50
+ if lora_path.exists():
51
+ print(f"Error: Download file already exists: {lora_filename}. ")
52
+ return None
53
+ hf_hub_download(repo_id=repo_id, filename=lora_filename, local_dir=".")
54
+ if lora_path.exists(): lora_path.rename(Path(local_file))
55
+ except Exception as e:
56
+ print(f"Error: Failed to download from {repo_id}. {e}")
57
+ return local_file
58
+
59
+
60
+ def convert_repo_to_safetensors_sdxl_lora(repo_id):
61
+ download_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}_diffusers.safetensors"
62
+ output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}_webui.safetensors"
63
+ download_repo_lora(repo_id, download_filename)
64
+ convert_and_save(download_filename, output_filename)
65
+ return output_filename
66
+
67
+
68
+ if __name__ == "__main__":
69
+ parser = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format from Repo.")
70
+ parser.add_argument("--repo_id", type=str, required=True, help="URL to the Repo of input LoRA model in the diffusers format.")
71
+
72
+ args = parser.parse_args()
73
+
74
+ convert_repo_to_safetensors_sdxl_lora(args.repo_id)
75
+
76
+
77
+ # Usage: python convert_repo_to_safetensors_sdxl_lora.py --repo_id nroggendorff/zelda-lora
local/requirements.txt CHANGED
@@ -1,3 +1,7 @@
1
  torch
2
  safetensors
3
- huggingface-hub
 
 
 
 
 
1
  torch
2
  safetensors
3
+ huggingface-hub
4
+ accelerate
5
+ diffusers
6
+ transformers
7
+ peft
requirements.txt CHANGED
@@ -1,3 +1,7 @@
1
  torch
2
  safetensors
3
- huggingface-hub
 
 
 
 
 
1
  torch
2
  safetensors
3
+ huggingface-hub
4
+ accelerate
5
+ diffusers
6
+ transformers
7
+ peft
utils.py CHANGED
@@ -34,6 +34,31 @@ def is_repo_exists(repo_id: str, repo_type: str="model"):
34
  return True # for safe
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def list_sub(a, b):
38
  return [e for e in a if e not in b]
39
 
 
34
  return True # for safe
35
 
36
 
37
+ MODEL_TYPE_CLASS = {
38
+ "diffusers:StableDiffusionPipeline": "SD 1.5",
39
+ "diffusers:StableDiffusionXLPipeline": "SDXL",
40
+ "diffusers:FluxPipeline": "FLUX",
41
+ }
42
+
43
+
44
+ def get_model_type(repo_id: str):
45
+ hf_token = get_token()
46
+ api = HfApi(token=hf_token)
47
+ lora_filename = "pytorch_lora_weights.safetensors"
48
+ diffusers_filename = "model_index.json"
49
+ default = "SDXL"
50
+ try:
51
+ if api.file_exists(repo_id=repo_id, filename=lora_filename, token=hf_token): return "LoRA"
52
+ if not api.file_exists(repo_id=repo_id, filename=diffusers_filename, token=hf_token): return "None"
53
+ model = api.model_info(repo_id=repo_id, token=hf_token)
54
+ tags = model.tags
55
+ for tag in tags:
56
+ if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
57
+ except Exception:
58
+ return default
59
+ return default
60
+
61
+
62
  def list_sub(a, b):
63
  return [e for e in a if e not in b]
64