LoRa_Streamlit / ai-toolkit /scripts /convert_diffusers_to_comfy.py
ramimu's picture
Upload 586 files
1c72248 verified
#######################################################
# Convert Diffusers Flux/Flex to all in one ComfyUI safetensors file
# The VAE, T5 and clip will all be in the safetensors file
# T5 will always be 8bit with the all in one file
# You can save the transformer weights as bf16 or 8-bit with the --do_8_bit flag
#
# Download a reference model from Huggingface
# https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors
#
# Call like this for 8-bit transformer weights:
# python convert_flux_diffusers_to_orig.py /path/to/diffusers/checkpoint /path/to/flux1-dev-fp8.safetensors /output/path/my_finetune.safetensors --do_8_bit
#
# Call like this for bf16 transformer weights:
# python convert_flux_diffusers_to_orig.py /path/to/diffusers/checkpoint /path/to/flux1-dev-fp8.safetensors /output/path/my_finetune.safetensors
#
#######################################################
import argparse
from datetime import date
import json
import os
from pathlib import Path
import safetensors
import safetensors.torch
import torch
import tqdm
from collections import OrderedDict
parser = argparse.ArgumentParser()
parser.add_argument("diffusers_path", type=str,
help="Path to the original Flux diffusers folder.")
parser.add_argument("quantized_state_dict_path", type=str,
help="Path to the ComfyUI all in one template file.")
parser.add_argument("flux_path", type=str,
help="Output path for the Flux safetensors file.")
parser.add_argument("--do_8_bit", action="store_true",
help="Use 8-bit weights instead of bf16.")
args = parser.parse_args()
flux_path = Path(args.flux_path)
diffusers_path = Path(args.diffusers_path, "transformer")
quantized_state_dict_path = Path(args.quantized_state_dict_path)
do_8_bit = args.do_8_bit
if not os.path.exists(flux_path.parent):
os.makedirs(flux_path.parent)
if not diffusers_path.exists():
print(f"Error: Missing transformer folder: {diffusers_path}")
exit()
original_json_path = Path.joinpath(
diffusers_path, "diffusion_pytorch_model.safetensors.index.json")
if not original_json_path.exists():
print(f"Error: Missing transformer index json: {original_json_path}")
exit()
if not os.path.exists(quantized_state_dict_path):
print(
f"Error: Missing quantized state dict file: {args.quantized_state_dict_path}")
exit()
with open(original_json_path, "r", encoding="utf-8") as f:
original_json = json.load(f)
diffusers_map = {
"time_in.in_layer.weight": [
"time_text_embed.timestep_embedder.linear_1.weight",
],
"time_in.in_layer.bias": [
"time_text_embed.timestep_embedder.linear_1.bias",
],
"time_in.out_layer.weight": [
"time_text_embed.timestep_embedder.linear_2.weight",
],
"time_in.out_layer.bias": [
"time_text_embed.timestep_embedder.linear_2.bias",
],
"vector_in.in_layer.weight": [
"time_text_embed.text_embedder.linear_1.weight",
],
"vector_in.in_layer.bias": [
"time_text_embed.text_embedder.linear_1.bias",
],
"vector_in.out_layer.weight": [
"time_text_embed.text_embedder.linear_2.weight",
],
"vector_in.out_layer.bias": [
"time_text_embed.text_embedder.linear_2.bias",
],
"guidance_in.in_layer.weight": [
"time_text_embed.guidance_embedder.linear_1.weight",
],
"guidance_in.in_layer.bias": [
"time_text_embed.guidance_embedder.linear_1.bias",
],
"guidance_in.out_layer.weight": [
"time_text_embed.guidance_embedder.linear_2.weight",
],
"guidance_in.out_layer.bias": [
"time_text_embed.guidance_embedder.linear_2.bias",
],
"txt_in.weight": [
"context_embedder.weight",
],
"txt_in.bias": [
"context_embedder.bias",
],
"img_in.weight": [
"x_embedder.weight",
],
"img_in.bias": [
"x_embedder.bias",
],
"double_blocks.().img_mod.lin.weight": [
"norm1.linear.weight",
],
"double_blocks.().img_mod.lin.bias": [
"norm1.linear.bias",
],
"double_blocks.().txt_mod.lin.weight": [
"norm1_context.linear.weight",
],
"double_blocks.().txt_mod.lin.bias": [
"norm1_context.linear.bias",
],
"double_blocks.().img_attn.qkv.weight": [
"attn.to_q.weight",
"attn.to_k.weight",
"attn.to_v.weight",
],
"double_blocks.().img_attn.qkv.bias": [
"attn.to_q.bias",
"attn.to_k.bias",
"attn.to_v.bias",
],
"double_blocks.().txt_attn.qkv.weight": [
"attn.add_q_proj.weight",
"attn.add_k_proj.weight",
"attn.add_v_proj.weight",
],
"double_blocks.().txt_attn.qkv.bias": [
"attn.add_q_proj.bias",
"attn.add_k_proj.bias",
"attn.add_v_proj.bias",
],
"double_blocks.().img_attn.norm.query_norm.scale": [
"attn.norm_q.weight",
],
"double_blocks.().img_attn.norm.key_norm.scale": [
"attn.norm_k.weight",
],
"double_blocks.().txt_attn.norm.query_norm.scale": [
"attn.norm_added_q.weight",
],
"double_blocks.().txt_attn.norm.key_norm.scale": [
"attn.norm_added_k.weight",
],
"double_blocks.().img_mlp.0.weight": [
"ff.net.0.proj.weight",
],
"double_blocks.().img_mlp.0.bias": [
"ff.net.0.proj.bias",
],
"double_blocks.().img_mlp.2.weight": [
"ff.net.2.weight",
],
"double_blocks.().img_mlp.2.bias": [
"ff.net.2.bias",
],
"double_blocks.().txt_mlp.0.weight": [
"ff_context.net.0.proj.weight",
],
"double_blocks.().txt_mlp.0.bias": [
"ff_context.net.0.proj.bias",
],
"double_blocks.().txt_mlp.2.weight": [
"ff_context.net.2.weight",
],
"double_blocks.().txt_mlp.2.bias": [
"ff_context.net.2.bias",
],
"double_blocks.().img_attn.proj.weight": [
"attn.to_out.0.weight",
],
"double_blocks.().img_attn.proj.bias": [
"attn.to_out.0.bias",
],
"double_blocks.().txt_attn.proj.weight": [
"attn.to_add_out.weight",
],
"double_blocks.().txt_attn.proj.bias": [
"attn.to_add_out.bias",
],
"single_blocks.().modulation.lin.weight": [
"norm.linear.weight",
],
"single_blocks.().modulation.lin.bias": [
"norm.linear.bias",
],
"single_blocks.().linear1.weight": [
"attn.to_q.weight",
"attn.to_k.weight",
"attn.to_v.weight",
"proj_mlp.weight",
],
"single_blocks.().linear1.bias": [
"attn.to_q.bias",
"attn.to_k.bias",
"attn.to_v.bias",
"proj_mlp.bias",
],
"single_blocks.().linear2.weight": [
"proj_out.weight",
],
"single_blocks.().norm.query_norm.scale": [
"attn.norm_q.weight",
],
"single_blocks.().norm.key_norm.scale": [
"attn.norm_k.weight",
],
"single_blocks.().linear2.weight": [
"proj_out.weight",
],
"single_blocks.().linear2.bias": [
"proj_out.bias",
],
"final_layer.linear.weight": [
"proj_out.weight",
],
"final_layer.linear.bias": [
"proj_out.bias",
],
"final_layer.adaLN_modulation.1.weight": [
"norm_out.linear.weight",
],
"final_layer.adaLN_modulation.1.bias": [
"norm_out.linear.bias",
],
}
def is_in_diffusers_map(k):
for values in diffusers_map.values():
for value in values:
if k.endswith(value):
return True
return False
diffusers = {k: Path.joinpath(diffusers_path, v)
for k, v in original_json["weight_map"].items() if is_in_diffusers_map(k)}
original_safetensors = set(diffusers.values())
# determine the number of transformer blocks
transformer_blocks = 0
single_transformer_blocks = 0
for key in diffusers.keys():
print(key)
if key.startswith("transformer_blocks."):
print(key)
block = int(key.split(".")[1])
if block >= transformer_blocks:
transformer_blocks = block + 1
elif key.startswith("single_transformer_blocks."):
block = int(key.split(".")[1])
if block >= single_transformer_blocks:
single_transformer_blocks = block + 1
print(f"Transformer blocks: {transformer_blocks}")
print(f"Single transformer blocks: {single_transformer_blocks}")
for file in original_safetensors:
if not file.exists():
print(f"Error: Missing transformer safetensors file: {file}")
exit()
original_safetensors = {f: safetensors.safe_open(
f, framework="pt", device="cpu") for f in original_safetensors}
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
flux_values = {}
for b in range(transformer_blocks):
for key, weights in diffusers_map.items():
if key.startswith("double_blocks."):
block_prefix = f"transformer_blocks.{b}."
found = True
for weight in weights:
if not (f"{block_prefix}{weight}" in diffusers):
found = False
if found:
flux_values[key.replace("()", f"{b}")] = [
f"{block_prefix}{weight}" for weight in weights]
for b in range(single_transformer_blocks):
for key, weights in diffusers_map.items():
if key.startswith("single_blocks."):
block_prefix = f"single_transformer_blocks.{b}."
found = True
for weight in weights:
if not (f"{block_prefix}{weight}" in diffusers):
found = False
if found:
flux_values[key.replace("()", f"{b}")] = [
f"{block_prefix}{weight}" for weight in weights]
for key, weights in diffusers_map.items():
if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")):
found = True
for weight in weights:
if not (f"{weight}" in diffusers):
found = False
if found:
flux_values[key] = [f"{weight}" for weight in weights]
flux = {}
for key, values in tqdm.tqdm(flux_values.items()):
if len(values) == 1:
flux[key] = original_safetensors[diffusers[values[0]]
].get_tensor(values[0]).to("cpu")
else:
flux[key] = torch.cat(
[
original_safetensors[diffusers[value]
].get_tensor(value).to("cpu")
for value in values
]
)
if "norm_out.linear.weight" in diffusers:
flux["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(
original_safetensors[diffusers["norm_out.linear.weight"]].get_tensor(
"norm_out.linear.weight").to("cpu")
)
if "norm_out.linear.bias" in diffusers:
flux["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(
original_safetensors[diffusers["norm_out.linear.bias"]].get_tensor(
"norm_out.linear.bias").to("cpu")
)
def stochastic_round_to(tensor, dtype=torch.float8_e4m3fn):
# Define the float8 range
min_val = torch.finfo(dtype).min
max_val = torch.finfo(dtype).max
# Clip values to float8 range
tensor = torch.clamp(tensor, min_val, max_val)
# Convert to float32 for calculations
tensor = tensor.float()
# Get the nearest representable float8 values
lower = torch.floor(tensor * 256) / 256
upper = torch.ceil(tensor * 256) / 256
# Calculate the probability of rounding up
prob = (tensor - lower) / (upper - lower)
# Generate random values for stochastic rounding
rand = torch.rand_like(tensor)
# Perform stochastic rounding
rounded = torch.where(rand < prob, upper, lower)
# Convert back to float8
return rounded.to(dtype)
# set all the keys to bf16
for key in flux.keys():
if do_8_bit:
flux[key] = stochastic_round_to(
flux[key], torch.float8_e4m3fn).to('cpu')
else:
flux[key] = flux[key].clone().to('cpu', torch.bfloat16)
# load the quantized state dict
quantized_state_dict = safetensors.torch.load_file(quantized_state_dict_path)
transformer_pre = "model.diffusion_model."
did_print = False
# remove old parts
for key in list(quantized_state_dict.keys()):
if key.startswith(transformer_pre):
if not did_print:
# print("dtype: ", quantized_state_dict[key].dtype)
did_print = True
del quantized_state_dict[key]
# add the new parts
for key, value in flux.items():
quantized_state_dict[transformer_pre + key] = value
meta = OrderedDict()
meta['format'] = 'pt'
# date format like 2024-08-01 YYYY-MM-DD
meta['modelspec.date'] = date.today().strftime("%Y-%m-%d")
meta['modelspec.title'] = "Flex.1-alpha"
meta['modelspec.author'] = "Ostris, LLC"
meta['modelspec.license'] = "Apache-2.0"
meta['modelspec.implementation'] = "https://github.com/black-forest-labs/flux"
meta['modelspec.architecture'] = "Flex.1-alpha"
meta['modelspec.description'] = "Flex.1-alpha"
os.makedirs(os.path.dirname(flux_path), exist_ok=True)
print(f"Saving to {flux_path}")
safetensors.torch.save_file(quantized_state_dict, flux_path, metadata=meta)
print("Done.")