# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Literal

import torch
from einops import rearrange
from PIL import ExifTags, Image
import torchvision.transforms.functional as TVF

from uno.flux.modules.layers import (
    DoubleStreamBlockLoraProcessor,
    DoubleStreamBlockProcessor,
    SingleStreamBlockLoraProcessor,
    SingleStreamBlockProcessor,
)
from uno.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
from uno.flux.util import (
    get_lora_rank,
    load_ae,
    load_checkpoint,
    load_clip,
    load_flow_model,
    load_flow_model_only_lora,
    load_flow_model_quintized,
    load_t5,
)


def find_nearest_scale(image_h, image_w, predefined_scales):
    """
    Find the predefined scale whose aspect ratio is closest to the input image's.

    :param image_h: image height
    :param image_w: image width
    :param predefined_scales: list of predefined scales [(h1, w1), (h2, w2), ...]
    :return: the nearest predefined scale (h, w)
    """
    # Aspect ratio of the input image
    image_ratio = image_h / image_w

    # Track the smallest ratio difference and the corresponding scale
    min_diff = float('inf')
    nearest_scale = None

    # Walk through all predefined scales and keep the one whose aspect ratio
    # is closest to the input image's
    for scale_h, scale_w in predefined_scales:
        predefined_ratio = scale_h / scale_w
        diff = abs(predefined_ratio - image_ratio)

        if diff < min_diff:
            min_diff = diff
            nearest_scale = (scale_h, scale_w)

    return nearest_scale
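
# Illustrative usage sketch (not from the original file; the bucket list below is
# hypothetical). find_nearest_scale compares aspect ratios only, so a 1024x768
# (h, w) image maps to the (768, 512) bucket, whose ratio 1.5 is closest to ~1.33:
#
#   >>> find_nearest_scale(1024, 768, [(512, 512), (768, 512), (512, 768)])
#   (768, 512)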


def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
    # Original width and height of the reference image
    image_w, image_h = raw_image.size

    # Scale the long side to `long_size` and the short side proportionally
    if image_w >= image_h:
        new_w = long_size
        new_h = int((long_size / image_w) * image_h)
    else:
        new_h = long_size
        new_w = int((long_size / image_h) * image_w)

    # Resize while keeping the aspect ratio
    raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
    target_w = new_w // 16 * 16
    target_h = new_h // 16 * 16

    # Compute the crop box for a center crop to multiples of 16
    left = (new_w - target_w) // 2
    top = (new_h - target_h) // 2
    right = left + target_w
    bottom = top + target_h

    # Center-crop, then convert to RGB
    raw_image = raw_image.crop((left, top, right, bottom))
    raw_image = raw_image.convert("RGB")
    return raw_image
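
# Illustrative usage sketch (not from the original file): a 1000x600 (w, h)
# reference is first resized to 512x307, then center-cropped to the nearest
# multiples of 16, giving a 512x304 RGB image:
#
#   >>> preprocess_ref(Image.new("RGB", (1000, 600)), long_size=512).size
#   (512, 304)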


class UNOPipeline:
    def __init__(
        self,
        model_type: str,
        device: torch.device,
        offload: bool = False,
        only_lora: bool = False,
        lora_rank: int = 16
    ):
        self.device = device
        self.offload = offload
        self.model_type = model_type

        self.clip = load_clip(self.device)
        self.t5 = load_t5(self.device, max_length=512)
        self.ae = load_ae(model_type, device="cpu" if offload else self.device)
        if "fp8" in model_type:
            self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device)
        elif only_lora:
            self.model = load_flow_model_only_lora(
                model_type, device="cpu" if offload else self.device, lora_rank=lora_rank
            )
        else:
            self.model = load_flow_model(model_type, device="cpu" if offload else self.device)

    def load_ckpt(self, ckpt_path):
        if ckpt_path is not None:
            from safetensors.torch import load_file as load_sft
            print("Loading checkpoint to replace old keys")
            # load_sft doesn't support torch.device
            if ckpt_path.endswith('safetensors'):
                sd = load_sft(ckpt_path, device='cpu')
                missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
            else:
                dit_state = torch.load(ckpt_path, map_location='cpu')
                sd = {}
                for k in dit_state.keys():
                    sd[k.replace('module.', '')] = dit_state[k]
                missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
            self.model.to(str(self.device))
            print(f"missing keys: {missing}\nunexpected keys: {unexpected}")

    def set_lora(self, local_path: str = None, repo_id: str = None,
                 name: str = None, lora_weight: float = 0.7):
        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.update_model_with_lora(checkpoint, lora_weight)

    def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
        # Expects `self.hf_lora_collection` and `self.lora_types_to_names` to be
        # defined on the instance (they map a LoRA type to a file in a HF repo).
        checkpoint = load_checkpoint(
            None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
        )
        self.update_model_with_lora(checkpoint, lora_weight)
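
    # Illustrative usage sketch (not from the original file; the `pipeline`
    # instance, checkpoint path, and weight are hypothetical):
    #
    #   pipeline.set_lora(local_path="ckpts/style_lora.safetensors", lora_weight=0.7)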

    def update_model_with_lora(self, checkpoint, lora_weight):
        rank = get_lora_rank(checkpoint)
        lora_attn_procs = {}

        for name, _ in self.model.attn_processors.items():
            lora_state_dict = {}
            for k in checkpoint.keys():
                if name in k:
                    lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight

            if len(lora_state_dict):
                # The checkpoint has LoRA weights for this block: attach a LoRA processor
                if name.startswith("single_blocks"):
                    lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank)
                else:
                    lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
                lora_attn_procs[name].load_state_dict(lora_state_dict)
                lora_attn_procs[name].to(self.device)
            else:
                # No LoRA weights for this block: fall back to the plain processor
                if name.startswith("single_blocks"):
                    lora_attn_procs[name] = SingleStreamBlockProcessor()
                else:
                    lora_attn_procs[name] = DoubleStreamBlockProcessor()

        self.model.set_attn_processor(lora_attn_procs)

    def __call__(
        self,
        prompt: str,
        width: int = 512,
        height: int = 512,
        guidance: float = 4,
        num_steps: int = 50,
        seed: int = 123456789,
        **kwargs
    ):
        # Round the requested size down to multiples of 16
        width = 16 * (width // 16)
        height = 16 * (height // 16)

        return self.forward(
            prompt,
            width,
            height,
            guidance,
            num_steps,
            seed,
            **kwargs
        )

    def gradio_generate(
        self,
        prompt: str,
        width: int,
        height: int,
        guidance: float,
        num_steps: int,
        seed: int,
        image_prompt1: Image.Image,
        image_prompt2: Image.Image,
        image_prompt3: Image.Image,
        image_prompt4: Image.Image,
    ):
        # Collect the non-empty reference images; use a smaller long side when
        # more than one reference is given
        ref_imgs = [image_prompt1, image_prompt2, image_prompt3, image_prompt4]
        ref_imgs = [img for img in ref_imgs if isinstance(img, Image.Image)]
        ref_long_side = 512 if len(ref_imgs) <= 1 else 320
        ref_imgs = [preprocess_ref(img, ref_long_side) for img in ref_imgs]

        seed = seed if seed != -1 else torch.randint(0, 10 ** 8, (1,)).item()

        img = self(prompt=prompt, width=width, height=height, guidance=guidance,
                   num_steps=num_steps, seed=seed, ref_imgs=ref_imgs)

        filename = f"output/gradio/{seed}_{prompt[:20]}.png"
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        # Record the generation settings in the EXIF metadata of the saved image
        exif_data = Image.Exif()
        exif_data[ExifTags.Base.Make] = "UNO"
        exif_data[ExifTags.Base.Model] = self.model_type
        info = f"{prompt=}, {seed=}, {width=}, {height=}, {guidance=}, {num_steps=}"
        exif_data[ExifTags.Base.ImageDescription] = info
        img.save(filename, format="png", exif=exif_data)

        return img, filename

    @torch.inference_mode()
    def forward(
        self,
        prompt: str,
        width: int,
        height: int,
        guidance: float,
        num_steps: int,
        seed: int,
        ref_imgs: list[Image.Image] | None = None,
        pe: Literal['d', 'h', 'w', 'o'] = 'd',
    ):
        x = get_noise(
            1, height, width, device=self.device,
            dtype=torch.bfloat16, seed=seed
        )
        timesteps = get_schedule(
            num_steps,
            (width // 8) * (height // 8) // (16 * 16),
            shift=True,
        )

        # Encode the reference images into latents with the autoencoder
        if self.offload:
            self.ae.encoder = self.ae.encoder.to(self.device)
        x_1_refs = [
            self.ae.encode(
                (TVF.to_tensor(ref_img) * 2.0 - 1.0)
                .unsqueeze(0).to(self.device, torch.float32)
            ).to(torch.bfloat16)
            for ref_img in ref_imgs
        ]

        if self.offload:
            self.offload_model_to_cpu(self.ae.encoder)
            self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)

        # Build the text and reference conditioning for the flow model
        inp_cond = prepare_multi_ip(
            t5=self.t5, clip=self.clip,
            img=x,
            prompt=prompt, ref_imgs=x_1_refs, pe=pe
        )

        if self.offload:
            self.offload_model_to_cpu(self.t5, self.clip)
            self.model = self.model.to(self.device)

        # Run the denoising loop
        x = denoise(
            self.model,
            **inp_cond,
            timesteps=timesteps,
            guidance=guidance,
        )

        if self.offload:
            self.offload_model_to_cpu(self.model)
            self.ae.decoder.to(x.device)

        # Unpack the latents and decode them back to pixel space
        x = unpack(x.float(), height, width)
        x = self.ae.decode(x)
        self.offload_model_to_cpu(self.ae.decoder)

        x1 = x.clamp(-1, 1)
        x1 = rearrange(x1[-1], "c h w -> h w c")
        output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
        return output_img

    def offload_model_to_cpu(self, *models):
        if not self.offload:
            return
        for model in models:
            model.cpu()
            torch.cuda.empty_cache()
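

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a CUDA
    # device, that the "flux-dev" weights resolved by load_flow_model are
    # available, and that "assets/ref.png" exists (hypothetical path).
    pipeline = UNOPipeline("flux-dev", torch.device("cuda"), offload=False)

    reference = preprocess_ref(Image.open("assets/ref.png"), long_size=512)
    image = pipeline(
        prompt="a product photo of the reference object on a marble table",
        width=512,
        height=512,
        guidance=4,
        num_steps=25,
        seed=42,
        ref_imgs=[reference],
    )

    os.makedirs("output", exist_ok=True)
    image.save("output/example.png")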