# UNO-FLUX: uno/flux/util.py
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import re
from dataclasses import dataclass

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from safetensors.torch import load_file as load_sft

from .model import Flux, FluxParams
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
from .modules.conditioner import HFEmbedder
from .modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor
def load_model(ckpt, device='cpu'):
    """Load a state dict from either a .safetensors file or a torch checkpoint."""
    if ckpt.endswith('safetensors'):
        pl_sd = {}
        with safe_open(ckpt, framework="pt", device=device) as f:
            for k in f.keys():
                pl_sd[k] = f.get_tensor(k)
    else:
        pl_sd = torch.load(ckpt, map_location=device)
    return pl_sd
def load_safetensors(path):
    """Read every tensor in a .safetensors file into a CPU state dict."""
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    return tensors
def get_lora_rank(checkpoint):
    """Infer the LoRA rank from the first down-projection weight in the state dict."""
    for k in checkpoint.keys():
        if k.endswith(".down.weight"):
            # the down weight has shape (rank, hidden_size)
            return checkpoint[k].shape[0]
def load_checkpoint(local_path, repo_id, name):
if local_path is not None:
if '.safetensors' in local_path:
print(f"Loading .safetensors checkpoint from {local_path}")
checkpoint = load_safetensors(local_path)
else:
print(f"Loading checkpoint from {local_path}")
checkpoint = torch.load(local_path, map_location='cpu')
elif repo_id is not None and name is not None:
print(f"Loading checkpoint {name} from repo id {repo_id}")
checkpoint = load_from_repo_id(repo_id, name)
else:
        raise ValueError(
            "LOADING ERROR: you must specify either a local_path, or both repo_id and name to download from the HF Hub"
        )
return checkpoint
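# Example (hypothetical usage, only a sketch assuming network access to the HF
# Hub): fetch the UNO LoRA weights and infer their rank.
#   ckpt = load_checkpoint(None, "bytedance-research/UNO", "dit_lora.safetensors")
#   rank = get_lora_rank(ckpt)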
def c_crop(image):
    """Center-crop a PIL image to a square with side min(width, height)."""
    width, height = image.size
    new_size = min(width, height)
    left = (width - new_size) / 2
    top = (height - new_size) / 2
    right = (width + new_size) / 2
    bottom = (height + new_size) / 2
    return image.crop((left, top, right, bottom))
def pad64(x):
    """Return the padding needed to round x up to the next multiple of 64."""
    return int(np.ceil(float(x) / 64.0) * 64 - x)
def HWC3(x):
    """Convert a uint8 image array to 3-channel HWC format.

    Grayscale input is replicated across channels; RGBA is composited over white.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        # alpha-blend over a white background
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y
@dataclass
class ModelSpec:
params: FluxParams
ae_params: AutoEncoderParams
ckpt_path: str | None
ae_path: str | None
repo_id: str | None
repo_flow: str | None
repo_ae: str | None
repo_id_ae: str | None
configs = {
"flux-dev": ModelSpec(
repo_id="black-forest-labs/FLUX.1-dev",
repo_id_ae="black-forest-labs/FLUX.1-dev",
repo_flow="flux1-dev.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_DEV"),
params=FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-dev-fp8": ModelSpec(
repo_id="XLabs-AI/flux-dev-fp8",
repo_id_ae="black-forest-labs/FLUX.1-dev",
repo_flow="flux-dev-fp8.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_DEV_FP8"),
params=FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-schnell": ModelSpec(
repo_id="black-forest-labs/FLUX.1-schnell",
repo_id_ae="black-forest-labs/FLUX.1-dev",
repo_flow="flux1-schnell.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_SCHNELL"),
params=FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
}
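# Example (hypothetical usage): specs are plain dataclasses, so fields can be
# read directly, e.g.:
#   spec = configs["flux-dev"]
#   spec.params.hidden_size   # 3072
#   spec.repo_flow            # "flux1-dev.safetensors"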
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
if len(missing) > 0 and len(unexpected) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
print("\n" + "-" * 79 + "\n")
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
elif len(missing) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
elif len(unexpected) > 0:
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
def load_from_repo_id(repo_id, checkpoint_name):
ckpt_path = hf_hub_download(repo_id, checkpoint_name)
sd = load_sft(ckpt_path, device='cpu')
return sd
def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
# Loading Flux
print("Init model")
ckpt_path = configs[name].ckpt_path
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_flow is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
with torch.device("meta" if ckpt_path is not None else device):
model = Flux(configs[name].params).to(torch.bfloat16)
if ckpt_path is not None:
print("Loading checkpoint")
# load_sft doesn't support torch.device
sd = load_model(ckpt_path, device=str(device))
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
print_load_warning(missing, unexpected)
return model
def load_flow_model_only_lora(
    name: str,
    device: str | torch.device = "cuda",
    hf_download: bool = True,
    lora_rank: int = 16,
):
    # Loading Flux
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        # map legacy ".sft" filenames onto ".safetensors" before downloading
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
    if hf_download:
        try:
            lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
        except Exception:
            # fall back to a local path if the Hub download fails
            lora_ckpt_path = os.environ.get("LORA", None)
    else:
        lora_ckpt_path = os.environ.get("LORA", None)
    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(configs[name].params)
        model = set_lora(model, lora_rank, device="meta" if lora_ckpt_path is not None else device)

    if ckpt_path is not None:
        if lora_ckpt_path is None:
            raise ValueError(
                "LOADING ERROR: no LoRA checkpoint found; set the LORA environment variable or enable hf_download"
            )
        print("Loading lora")
        lora_sd = (
            load_sft(lora_ckpt_path, device=str(device))
            if lora_ckpt_path.endswith("safetensors")
            else torch.load(lora_ckpt_path, map_location='cpu')
        )
        print("Loading main checkpoint")
        # load_sft doesn't support torch.device
        if ckpt_path.endswith('safetensors'):
            sd = load_sft(ckpt_path, device=str(device))
        else:
            dit_state = torch.load(ckpt_path, map_location='cpu')
            # strip a possible DataParallel/DDP 'module.' prefix from the keys
            sd = {k.replace('module.', ''): v for k, v in dit_state.items()}
        sd.update(lora_sd)
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        model.to(str(device))
        print_load_warning(missing, unexpected)
    return model
def set_lora(
model: Flux,
lora_rank: int,
double_blocks_indices: list[int] | None = None,
single_blocks_indices: list[int] | None = None,
device: str | torch.device = "cpu",
) -> Flux:
double_blocks_indices = list(range(model.params.depth)) if double_blocks_indices is None else double_blocks_indices
single_blocks_indices = list(range(model.params.depth_single_blocks)) if single_blocks_indices is None \
else single_blocks_indices
lora_attn_procs = {}
with torch.device(device):
for name, attn_processor in model.attn_processors.items():
match = re.search(r'\.(\d+)\.', name)
if match:
layer_index = int(match.group(1))
if name.startswith("double_blocks") and layer_index in double_blocks_indices:
lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
elif name.startswith("single_blocks") and layer_index in single_blocks_indices:
lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
else:
lora_attn_procs[name] = attn_processor
model.set_attn_processor(lora_attn_procs)
return model
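# Example (hypothetical usage): restrict rank-16 LoRA processors to the first
# four double-stream blocks, leaving all single-stream processors untouched:
#   model = set_lora(model, lora_rank=16, double_blocks_indices=[0, 1, 2, 3],
#                    single_blocks_indices=[])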
def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    # Loading a pre-quantized (fp8) Flux checkpoint. NOTE: the "quintized"
    # spelling is a long-standing typo kept so existing callers don't break.
    from optimum.quanto import requantize
    print("Init model")
ckpt_path = configs[name].ckpt_path
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_flow is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')
model = Flux(configs[name].params).to(torch.bfloat16)
print("Loading checkpoint")
# load_sft doesn't support torch.device
sd = load_sft(ckpt_path, device='cpu')
with open(json_path, "r") as f:
quantization_map = json.load(f)
print("Start a quantization process...")
requantize(model, sd, quantization_map, device=device)
print("Model is quantized!")
return model
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    # a max_length of 64, 128, 256 or 512 should work (if your sequences are short enough)
version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders")
return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(device)
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
version = os.environ.get("CLIP", "openai/clip-vit-large-patch14")
return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device)
def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
ckpt_path = configs[name].ae_path
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_ae is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
# Loading the autoencoder
print("Init AE")
with torch.device("meta" if ckpt_path is not None else device):
ae = AutoEncoder(configs[name].ae_params)
if ckpt_path is not None:
sd = load_sft(ckpt_path, device=str(device))
missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
print_load_warning(missing, unexpected)
return ae
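
if __name__ == "__main__":
    # Minimal smoke test, assuming either local checkpoints (via the FLUX_DEV,
    # AE, T5 and CLIP environment variables) or network access to the HF Hub.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    t5 = load_t5(device, max_length=512)
    clip = load_clip(device)
    model = load_flow_model("flux-dev", device=device)
    ae = load_ae("flux-dev", device=device)
    print("t5, clip, flux and ae all loaded successfully")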