# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import re
from dataclasses import dataclass

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from safetensors.torch import load_file as load_sft

from .model import Flux, FluxParams
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
from .modules.conditioner import HFEmbedder

from uno.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor


def load_model(ckpt, device='cpu'):
    """Load a raw state dict from ``ckpt``.

    Paths ending in ``safetensors`` are read tensor-by-tensor with
    ``safe_open`` onto ``device``; any other path is passed to ``torch.load``
    with ``map_location=device``.
    """
    if ckpt.endswith('safetensors'):
        pl_sd = {}
        with safe_open(ckpt, framework="pt", device=device) as f:
            for k in f.keys():
                pl_sd[k] = f.get_tensor(k)
    else:
        pl_sd = torch.load(ckpt, map_location=device)
    return pl_sd


def load_safetensors(path):
    """Read every tensor of the safetensors file at ``path`` onto the CPU."""
    with safe_open(path, framework="pt", device="cpu") as f:
        return {key: f.get_tensor(key) for key in f.keys()}


def get_lora_rank(checkpoint):
    """Return the LoRA rank inferred from a state dict.

    The rank is taken as dim 0 of the first tensor whose key ends with
    ``.down.weight``. Returns ``None`` when no such key exists.
    """
    for k in checkpoint.keys():
        if k.endswith(".down.weight"):
            return checkpoint[k].shape[0]
    return None


def load_checkpoint(local_path, repo_id, name):
    """Load a state dict from a local file, or else from a Hugging Face repo.

    Args:
        local_path: local checkpoint path; takes precedence when not None.
            Paths containing ``.safetensors`` use ``load_safetensors``,
            anything else ``torch.load`` onto the CPU.
        repo_id: HF repo id, used only when ``local_path`` is None.
        name: file name inside ``repo_id``.

    Raises:
        ValueError: if neither ``local_path`` nor both ``repo_id`` and
            ``name`` are given.
    """
    if local_path is not None:
        if '.safetensors' in local_path:
            print(f"Loading .safetensors checkpoint from {local_path}")
            checkpoint = load_safetensors(local_path)
        else:
            print(f"Loading checkpoint from {local_path}")
            checkpoint = torch.load(local_path, map_location='cpu')
    elif repo_id is not None and name is not None:
        print(f"Loading checkpoint {name} from repo id {repo_id}")
        checkpoint = load_from_repo_id(repo_id, name)
    else:
        raise ValueError(
            "LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
        )
    return checkpoint


def c_crop(image):
    """Center-crop ``image`` (an object with ``.size`` and ``.crop``, e.g. a PIL
    image) to a square whose side is ``min(width, height)``."""
    width, height = image.size
    new_size = min(width, height)
    # Integer offsets: a float box with ".5" coordinates is rounded
    # half-to-even by Pillow, which can yield a non-square crop when the
    # size difference is odd.
    left = (width - new_size) // 2
    top = (height - new_size) // 2
    return image.crop((left, top, left + new_size, top + new_size))


def pad64(x):
    """Return how much must be added to ``x`` to reach the next multiple of 64
    (0 when ``x`` is already a multiple)."""
    return int(np.ceil(float(x) / 64.0) * 64 - x)


def HWC3(x):
    """Convert a uint8 image array to 3 channels (H, W, 3).

    Accepts 2-D (H, W) arrays or 3-D arrays with 1, 3 or 4 channels.
    Single-channel input is replicated; 4-channel input is alpha-composited
    over a white (255) background. Invalid input fails an ``assert``.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


@dataclass
class ModelSpec:
    # Transformer hyper-parameters.
    params: FluxParams
    # Autoencoder hyper-parameters.
    ae_params: AutoEncoderParams
    # Local transformer checkpoint (from an env var); None -> download.
    ckpt_path: str | None
    # Local autoencoder checkpoint (from an env var); None -> download.
    ae_path: str | None
    # HF repo holding the transformer weights.
    repo_id: str | None
    # Transformer weight file name inside ``repo_id``.
    repo_flow: str | None
    # Autoencoder weight file name inside ``repo_id_ae``.
    repo_ae: str | None
    # HF repo holding the autoencoder weights.
    repo_id_ae: str | None


configs = {
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-fp8": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV_FP8"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Print the missing / unexpected keys reported by ``load_state_dict``,
    separated by a rule when both lists are non-empty."""
    if len(missing) > 0 and len(unexpected) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        print("\n" + "-" * 79 + "\n")
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    elif len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    elif len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def load_from_repo_id(repo_id, checkpoint_name):
    """Download ``checkpoint_name`` from HF repo ``repo_id`` and load it as a
    safetensors state dict on the CPU."""
    ckpt_path = hf_hub_download(repo_id, checkpoint_name)
    sd = load_sft(ckpt_path, device='cpu')
    return sd


def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    """Build the Flux transformer for config ``name`` and load its weights.

    The checkpoint path comes from the config (env var) or, when absent and
    ``hf_download`` is set, from the Hugging Face hub. When a checkpoint is
    available the model is built on the meta device and its parameters are
    assigned from the state dict; otherwise it is built on ``device`` with
    fresh (bf16) weights.
    """
    # Loading Flux
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)

    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(configs[name].params).to(torch.bfloat16)

    if ckpt_path is not None:
        print("Loading checkpoint")
        # load_sft doesn't support torch.device
        sd = load_model(ckpt_path, device=str(device))
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return model


def load_flow_model_only_lora(
    name: str,
    device: str | torch.device = "cuda",
    hf_download: bool = True,
    lora_rank: int = 16,
    use_fp8: bool = False
):
    """Build the Flux transformer with LoRA attention processors and load weights.

    The base checkpoint is resolved like ``load_flow_model``. The LoRA
    checkpoint is downloaded from ``bytedance-research/UNO`` when
    ``hf_download`` is set (falling back to the ``LORA`` env var if that
    fails), otherwise taken from ``LORA``.

    When a base checkpoint is resolved, base and LoRA tensors are merged and
    assigned into the model (``strict=False``). With ``use_fp8`` and a
    safetensors base checkpoint, base tensors are cast to
    ``torch.float8_e4m3fn``. If no LoRA checkpoint is found, the LoRA layers
    keep the weights ``set_lora`` created. If no base checkpoint is resolved,
    no weights are loaded.
    """
    # Loading Flux
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))

    if hf_download:
        try:
            lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
        except Exception:
            # Best effort: fall back to a user-supplied LoRA (e.g. offline use).
            lora_ckpt_path = os.environ.get("LORA", None)
    else:
        lora_ckpt_path = os.environ.get("LORA", None)

    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(configs[name].params)

    model = set_lora(model, lora_rank, device="meta" if lora_ckpt_path is not None else device)

    if ckpt_path is not None:
        if lora_ckpt_path is not None:
            print("Loading lora")
            if lora_ckpt_path.endswith("safetensors"):
                lora_sd = load_sft(lora_ckpt_path, device=str(device))
            else:
                # Same target device as the safetensors path, so LoRA weights
                # never stay on the CPU while the base weights live on `device`.
                lora_sd = torch.load(lora_ckpt_path, map_location=str(device))
        else:
            print("No LoRA checkpoint found; LoRA layers keep their initial weights")
            lora_sd = {}

        print("Loading main checkpoint")
        # load_sft doesn't support torch.device

        if ckpt_path.endswith('safetensors'):
            if use_fp8:
                print(
                    "####\n"
                    "We are in fp8 mode right now, since the fp8 checkpoint of XLabs-AI/flux-dev-fp8 seems broken\n"
                    "we convert the fp8 checkpoint on flight from bf16 checkpoint\n"
                    "If your storage is constrained, "
                    "you can save the fp8 checkpoint and replace the bf16 checkpoint by yourself\n"
                )
                sd = load_sft(ckpt_path, device="cpu")
                sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()}
            else:
                sd = load_sft(ckpt_path, device=str(device))

            sd.update(lora_sd)
            missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        else:
            dit_state = torch.load(ckpt_path, map_location='cpu')
            # Strip the 'module.' prefix, e.g. from DataParallel/DDP-saved checkpoints.
            sd = {k.replace('module.', ''): v for k, v in dit_state.items()}
            sd.update(lora_sd)
            missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
            model.to(str(device))
        print_load_warning(missing, unexpected)
    return model


def set_lora(
    model: Flux,
    lora_rank: int,
    double_blocks_indices: list[int] | None = None,
    single_blocks_indices: list[int] | None = None,
    device: str | torch.device = "cpu",
) -> Flux:
    """Replace attention processors of the selected blocks with LoRA processors.

    Processor names beginning with ``double_blocks`` / ``single_blocks`` and
    carrying a ``.<index>.`` component in the chosen index lists (default:
    all blocks) get a ``DoubleStreamBlockLoraProcessor`` /
    ``SingleStreamBlockLoraProcessor`` of width ``hidden_size`` and rank
    ``lora_rank``, created under ``torch.device(device)``. Every other
    processor is kept. Returns the same (mutated) model.
    """
    double_blocks_indices = list(range(model.params.depth)) if double_blocks_indices is None else double_blocks_indices
    single_blocks_indices = list(range(model.params.depth_single_blocks)) if single_blocks_indices is None \
        else single_blocks_indices

    layer_pattern = re.compile(r'\.(\d+)\.')
    lora_attn_procs = {}
    with torch.device(device):
        for name, attn_processor in model.attn_processors.items():
            match = layer_pattern.search(name)
            # Names without a numeric index keep their processor instead of
            # reusing a stale index from a previous iteration.
            layer_index = int(match.group(1)) if match else None

            if name.startswith("double_blocks") and layer_index in double_blocks_indices:
                lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
            elif name.startswith("single_blocks") and layer_index in single_blocks_indices:
                lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
            else:
                lora_attn_procs[name] = attn_processor
    model.set_attn_processor(lora_attn_procs)
    return model


def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    """Load the Flux transformer with every checkpoint tensor cast to fp8.

    Despite the name, no quanto requantization happens. The checkpoint is
    read on the CPU, each tensor is cast to ``torch.float8_e4m3fn`` on
    ``device``, and the result is assigned with a strict ``load_state_dict``.

    Raises:
        ValueError: if no checkpoint path can be resolved.
    """
    # Loading Flux
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
    if ckpt_path is None:
        raise ValueError(
            f"No checkpoint for '{name}': set the config's env var or enable hf_download"
        )
    model = Flux(configs[name].params).to(torch.bfloat16)
    print("Loading checkpoint")
    # load_sft doesn't support torch.device
    sd = load_sft(ckpt_path, device='cpu')
    sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()}
    model.load_state_dict(sd, assign=True)
    return model


def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    """Build the T5 text embedder (repo overridable via the ``T5`` env var) in bf16."""
    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders")
    return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(device)


def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    """Build the CLIP text embedder (repo overridable via ``CLIP``) in bf16, max_length 77."""
    version = os.environ.get("CLIP", "openai/clip-vit-large-patch14")
    return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device)


def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
    """Build the autoencoder for config ``name`` and load its weights.

    The checkpoint comes from the config's ``ae_path`` or, when absent and
    ``hf_download`` is set, from ``repo_id_ae`` / ``repo_ae`` on the hub.
    """
    ckpt_path = configs[name].ae_path
    if (
        ckpt_path is None
        and configs[name].repo_id_ae is not None  # the repo actually downloaded from below
        and configs[name].repo_ae is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)

    # Loading the autoencoder
    print("Init AE")
    with torch.device("meta" if ckpt_path is not None else device):
        ae = AutoEncoder(configs[name].ae_params)

    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return ae