# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import re
from dataclasses import dataclass

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from safetensors.torch import load_file as load_sft

from .model import Flux, FluxParams
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
from .modules.conditioner import HFEmbedder
from uno.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor


def load_model(ckpt, device='cpu'):
    """Load a state dict from either a .safetensors file or a regular torch checkpoint."""
    if ckpt.endswith('safetensors'):
        pl_sd = {}
        with safe_open(ckpt, framework="pt", device=device) as f:
            for k in f.keys():
                pl_sd[k] = f.get_tensor(k)
    else:
        pl_sd = torch.load(ckpt, map_location=device)
    return pl_sd


def load_safetensors(path):
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    return tensors


def get_lora_rank(checkpoint):
    # A LoRA "down" projection weight has shape (rank, in_features),
    # so the first dimension of any ".down.weight" tensor is the rank.
    for k in checkpoint.keys():
        if k.endswith(".down.weight"):
            return checkpoint[k].shape[0]


def load_checkpoint(local_path, repo_id, name):
    if local_path is not None:
        if '.safetensors' in local_path:
            print(f"Loading .safetensors checkpoint from {local_path}")
            checkpoint = load_safetensors(local_path)
        else:
            print(f"Loading checkpoint from {local_path}")
            checkpoint = torch.load(local_path, map_location='cpu')
    elif repo_id is not None and name is not None:
        print(f"Loading checkpoint {name} from repo id {repo_id}")
        checkpoint = load_from_repo_id(repo_id, name)
    else:
        raise ValueError(
            "LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
        )
    return checkpoint
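
# Illustrative usage of load_checkpoint (a sketch, not part of the original API;
# the repo id and file name below are the ones referenced later in this module):
#   sd = load_checkpoint("path/to/dit_lora.safetensors", None, None)
#   sd = load_checkpoint(None, "bytedance-research/UNO", "dit_lora.safetensors")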


def c_crop(image):
    # Center-crop a PIL image to a square with side min(width, height).
    width, height = image.size
    new_size = min(width, height)
    left = (width - new_size) / 2
    top = (height - new_size) / 2
    right = (width + new_size) / 2
    bottom = (height + new_size) / 2
    return image.crop((left, top, right, bottom))


def pad64(x):
    # Number of pixels needed to pad x up to the next multiple of 64.
    return int(np.ceil(float(x) / 64.0) * 64 - x)


def HWC3(x):
    # Ensure a uint8 image is 3-channel HWC: tile grayscale to RGB and
    # composite RGBA onto a white background.
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y
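
# Illustrative example (a sketch; the array below is made up for demonstration):
#   mask = np.zeros((64, 64), dtype=np.uint8)
#   rgb = HWC3(mask)  # -> shape (64, 64, 3)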


@dataclass
class ModelSpec:
    params: FluxParams
    ae_params: AutoEncoderParams
    ckpt_path: str | None
    ae_path: str | None
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None
    repo_id_ae: str | None


configs = {
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-fp8": ModelSpec(
        repo_id="XLabs-AI/flux-dev-fp8",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux-dev-fp8.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV_FP8"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}
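
# Resolution order used by the loaders below (noted here as a reading aid):
# ckpt_path / ae_path come from the FLUX_DEV / FLUX_DEV_FP8 / FLUX_SCHNELL / AE
# environment variables; when unset, repo_flow / repo_ae are downloaded from the
# Hugging Face Hub. A minimal sketch, assuming Hub access is available:
#   spec = configs["flux-dev"]
#   flow_path = spec.ckpt_path or hf_hub_download(spec.repo_id, spec.repo_flow)
#   ae_path = spec.ae_path or hf_hub_download(spec.repo_id_ae, spec.repo_ae)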


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    if len(missing) > 0 and len(unexpected) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        print("\n" + "-" * 79 + "\n")
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    elif len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    elif len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def load_from_repo_id(repo_id, checkpoint_name):
    ckpt_path = hf_hub_download(repo_id, checkpoint_name)
    sd = load_sft(ckpt_path, device='cpu')
    return sd


def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    # Loading Flux
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)

    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(configs[name].params).to(torch.bfloat16)

    if ckpt_path is not None:
        print("Loading checkpoint")
        # load_sft doesn't support torch.device
        sd = load_model(ckpt_path, device=str(device))
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return model


def load_flow_model_only_lora(
    name: str,
    device: str | torch.device = "cuda",
    hf_download: bool = True,
    lora_rank: int = 16,
):
    # Loading Flux
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))

    if hf_download:
        try:
            lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
        except Exception:
            lora_ckpt_path = os.environ.get("LORA", None)
    else:
        lora_ckpt_path = os.environ.get("LORA", None)

    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(configs[name].params)

    model = set_lora(model, lora_rank, device="meta" if lora_ckpt_path is not None else device)

    if ckpt_path is not None:
        print("Loading lora")
        lora_sd = load_sft(lora_ckpt_path, device=str(device)) if lora_ckpt_path.endswith("safetensors") \
            else torch.load(lora_ckpt_path, map_location='cpu')

        print("Loading main checkpoint")
        # load_sft doesn't support torch.device
        if ckpt_path.endswith('safetensors'):
            sd = load_sft(ckpt_path, device=str(device))
            sd.update(lora_sd)
            missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        else:
            dit_state = torch.load(ckpt_path, map_location='cpu')
            sd = {}
            for k in dit_state.keys():
                sd[k.replace('module.', '')] = dit_state[k]
            sd.update(lora_sd)
            missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
            model.to(str(device))
        print_load_warning(missing, unexpected)
    return model


def set_lora(
    model: Flux,
    lora_rank: int,
    double_blocks_indices: list[int] | None = None,
    single_blocks_indices: list[int] | None = None,
    device: str | torch.device = "cpu",
) -> Flux:
    double_blocks_indices = list(range(model.params.depth)) if double_blocks_indices is None else double_blocks_indices
    single_blocks_indices = list(range(model.params.depth_single_blocks)) if single_blocks_indices is None \
        else single_blocks_indices

    lora_attn_procs = {}
    with torch.device(device):
        for name, attn_processor in model.attn_processors.items():
            match = re.search(r'\.(\d+)\.', name)
            if match:
                layer_index = int(match.group(1))

            if name.startswith("double_blocks") and layer_index in double_blocks_indices:
                lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
            elif name.startswith("single_blocks") and layer_index in single_blocks_indices:
                lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
            else:
                lora_attn_procs[name] = attn_processor

    model.set_attn_processor(lora_attn_procs)
    return model
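
# Illustrative sketch (not part of the original module): restricting LoRA processors
# to a subset of blocks instead of attaching them everywhere.
#   model = set_lora(model, lora_rank=16, double_blocks_indices=[0, 1, 2], single_blocks_indices=[])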


def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    # Loading Flux
    from optimum.quanto import requantize

    print("Init model")
    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
    json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')

    model = Flux(configs[name].params).to(torch.bfloat16)

    print("Loading checkpoint")
    # load_sft doesn't support torch.device
    sd = load_sft(ckpt_path, device='cpu')
    with open(json_path, "r") as f:
        quantization_map = json.load(f)
    print("Start a quantization process...")
    requantize(model, sd, quantization_map, device=device)
    print("Model is quantized!")
    return model


def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders")
    return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(device)


def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    version = os.environ.get("CLIP", "openai/clip-vit-large-patch14")
    return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device)


def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
    ckpt_path = configs[name].ae_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_ae is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)

    # Loading the autoencoder
    print("Init AE")
    with torch.device("meta" if ckpt_path is not None else device):
        ae = AutoEncoder(configs[name].ae_params)

    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return ae
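

# Minimal end-to-end sketch (an illustration, not part of the original module):
# loading the transformer, autoencoder and text encoders used for FLUX inference.
# Assumes the weights are reachable via the HF Hub or the environment variables
# noted above, and that enough device memory is available.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    flux = load_flow_model("flux-dev", device=device)
    ae = load_ae("flux-dev", device=device)
    t5 = load_t5(device=device, max_length=512)
    clip = load_clip(device=device)
    print(type(flux).__name__, type(ae).__name__, type(t5).__name__, type(clip).__name__)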