import gc
import os
from typing import List
import contextlib
import torch.multiprocessing as mp
from dataclasses import dataclass, field
from collections import defaultdict
import random
import numpy as np
from PIL import Image, ImageOps
import json
import torch
from peft import PeftModel
import torch.nn.functional as F
import accelerate
import diffusers
from diffusers import FluxPipeline
from diffusers.utils.torch_utils import is_compiled_module
import transformers
from tqdm import tqdm
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from dreamfuse.models.dreamfuse_flux.transformer import (
    FluxTransformer2DModel,
    FluxTransformerBlock,
    FluxSingleTransformerBlock,
)
from diffusers.schedulers.scheduling_flow_match_euler_discrete import (
    FlowMatchEulerDiscreteScheduler,
)
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from dreamfuse.trains.utils.inference_utils import (
    compute_text_embeddings,
    prepare_latents,
    _unpack_latents,
    _pack_latents,
    _prepare_image_ids,
    encode_images_cond,
    get_mask_affine,
    warp_affine_tensor,
)


def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


@dataclass
class InferenceConfig:
    # Model paths
    flux_model_id: str = 'black-forest-labs/FLUX.1-dev'
    lora_id: str = ''
    model_choice: str = 'dev'

    # Model configs
    lora_rank: int = 16
    max_sequence_length: int = 256
    guidance_scale: float = 3.5
    num_inference_steps: int = 28
    mask_ids: int = 16
    mask_in_chans: int = 128
    mask_out_chans: int = 3072
    inference_scale: int = 1024

    # Training configs
    gradient_checkpointing: bool = False
    mix_attention_double: bool = True
    mix_attention_single: bool = True

    # Image processing
    image_ids_offset: List[int] = field(default_factory=lambda: [0, 0, 0])
    image_tags: List[int] = field(default_factory=lambda: [0, 1, 2])
    context_tags: List[int] = None

    # Runtime configs
    device: str = "cuda:0"  # if torch.cuda.is_available() else "cpu"
    dtype: torch.dtype = torch.bfloat16
    seed: int = 1234
    debug: bool = True

    # I/O configs
    valid_output_dir: str = "./inference_output"
    valid_roots: List[str] = field(default_factory=lambda: [
        "./",
    ])
    valid_jsons: List[str] = field(default_factory=lambda: [
        "./examples/data_dreamfuse.json",
    ])
    ref_prompts: str = ""
    truecfg: bool = False
    text_strength: int = 5

    # Multi-GPU sharding
    sub_idx: int = 0
    total_num: int = 1
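

# Note: InferenceConfig is a plain dataclass, so callers can override fields after construction
# (for example, pointing lora_id at a trained DreamFuse LoRA adapter directory). lora_id must be
# set to a valid adapter path before DreamFuseInference is instantiated, since PeftModel loads
# from it at init time.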


def adjust_fg_to_bg(image: Image.Image, mask: Image.Image, target_size: tuple) -> tuple[Image.Image, Image.Image]:
    """Downscale the foreground image and its mask to fit within target_size, then pad them to exactly target_size."""
    width, height = image.size
    target_w, target_h = target_size
    scale = min(target_w / width, target_h / height)
    if scale < 1:
        new_w = int(width * scale)
        new_h = int(height * scale)
        image = image.resize((new_w, new_h))
        mask = mask.resize((new_w, new_h))
        width, height = new_w, new_h

    pad_w = target_w - width
    pad_h = target_h - height
    padding = (
        pad_w // 2,        # left
        pad_h // 2,        # top
        (pad_w + 1) // 2,  # right
        (pad_h + 1) // 2,  # bottom
    )
    image = ImageOps.expand(image, border=padding, fill=(255, 255, 255))
    mask = ImageOps.expand(mask, border=padding, fill=0)
    return image, mask


def find_nearest_bucket_size(input_width, input_height, mode="x64", bucket_size=1024):
    """
    Finds the nearest bucket size for the given input size.
    """
    buckets = {
        512: [[256, 768], [320, 768], [320, 704], [384, 640], [448, 576], [512, 512],
              [576, 448], [640, 384], [704, 320], [768, 320], [768, 256]],
        768: [[384, 1152], [480, 1152], [480, 1056], [576, 960], [672, 864], [768, 768],
              [864, 672], [960, 576], [1056, 480], [1152, 480], [1152, 384]],
        1024: [[512, 1536], [640, 1536], [640, 1408], [768, 1280], [896, 1152], [1024, 1024],
               [1152, 896], [1280, 768], [1408, 640], [1536, 640], [1536, 512]],
    }
    buckets = buckets[bucket_size]
    aspect_ratios = [w / h for (w, h) in buckets]

    assert mode in ["x64", "x8"]
    if mode == "x64":
        asp = input_width / input_height
        diff = [abs(ar - asp) for ar in aspect_ratios]
        bucket_id = int(np.argmin(diff))
        gen_width, gen_height = buckets[bucket_id]
    elif mode == "x8":
        max_pixels = 1024 * 1024
        ratio = (max_pixels / (input_width * input_height)) ** 0.5
        gen_width, gen_height = round(input_width * ratio), round(input_height * ratio)
        gen_width = gen_width - gen_width % 8
        gen_height = gen_height - gen_height % 8
    else:
        raise NotImplementedError
    return (gen_width, gen_height)
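

# Worked example of the "x64" bucketing above, taken directly from the 1024 bucket table:
# a 1920x1080 background has aspect ratio ~1.78, and the closest bucket ratio is
# 1280/768 ~= 1.67, so generation runs at 1280x768 and gradio_generate resizes the result
# back to the original background size afterwards.
#
#     find_nearest_bucket_size(1920, 1080, mode="x64", bucket_size=1024)  # -> (1280, 768)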


def make_image_grid(images, rows, cols, size=None):
    assert len(images) == rows * cols
    if size is not None:
        images = [img.resize((size[0], size[1])) for img in images]
    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(images):
        grid.paste(img.convert("RGB"), box=(i % cols * w, i // cols * h))
    return grid


class DreamFuseInference:
    def __init__(self, config: InferenceConfig):
        self.config = config
        self.device = torch.device(config.device)
        torch.backends.cuda.matmul.allow_tf32 = True
        seed_everything(config.seed)
        self._init_models()

    def _init_models(self):
        # Initialize tokenizers
        self.tokenizer_one = transformers.CLIPTokenizer.from_pretrained(
            self.config.flux_model_id, subfolder="tokenizer"
        )
        self.tokenizer_two = transformers.T5TokenizerFast.from_pretrained(
            self.config.flux_model_id, subfolder="tokenizer_2"
        )

        # Initialize text encoders
        self.text_encoder_one = transformers.CLIPTextModel.from_pretrained(
            self.config.flux_model_id, subfolder="text_encoder"
        ).to(device=self.device, dtype=self.config.dtype)
        self.text_encoder_two = transformers.T5EncoderModel.from_pretrained(
            self.config.flux_model_id, subfolder="text_encoder_2"
        ).to(device=self.device, dtype=self.config.dtype)

        # Initialize VAE
        self.vae = diffusers.AutoencoderKL.from_pretrained(
            self.config.flux_model_id, subfolder="vae"
        ).to(device=self.device, dtype=self.config.dtype)

        # Initialize denoising model
        self.denoise_model = FluxTransformer2DModel.from_pretrained(
            self.config.flux_model_id, subfolder="transformer"
        ).to(device=self.device, dtype=self.config.dtype)

        if self.config.image_tags is not None or self.config.context_tags is not None:
            num_image_tag_embeddings = max(self.config.image_tags) + 1 if self.config.image_tags is not None else 0
            num_context_tag_embeddings = max(self.config.context_tags) + 1 if self.config.context_tags is not None else 0
            self.denoise_model.set_tag_embeddings(
                num_image_tag_embeddings=num_image_tag_embeddings,
                num_context_tag_embeddings=num_context_tag_embeddings,
            )

        # Add LoRA
        self.denoise_model = PeftModel.from_pretrained(
            self.denoise_model,
            self.config.lora_id,
            adapter_weights=[1.0],
            device_map={"": self.device}
        )

        # Initialize scheduler
        self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            self.config.flux_model_id, subfolder="scheduler"
        )

        # Set models to eval mode
        for model in [self.text_encoder_one, self.text_encoder_two, self.vae, self.denoise_model]:
            model.eval()
            model.requires_grad_(False)

    def _compute_text_embeddings(self, prompt):
        return compute_text_embeddings(
            self.config,
            prompt,
            [self.text_encoder_one, self.text_encoder_two],
            [self.tokenizer_one, self.tokenizer_two],
            self.device
        )

    def resize_to_fit_within(self, reference_image, target_image):
        ref_width, ref_height = reference_image.size
        target_width, target_height = target_image.size
        scale_width = ref_width / target_width
        scale_height = ref_height / target_height
        scale = min(scale_width, scale_height)
        new_width = int(target_width * scale)
        new_height = int(target_height * scale)
        resized_image = target_image.resize((new_width, new_height), Image.LANCZOS)
        return resized_image

    def pad_or_crop(self, img, target_size, fill_color=(255, 255, 255)):
        iw, ih = img.size
        tw, th = target_size
        # Crop region: if the image is larger than the target size, keep the center; otherwise keep everything.
        left = (iw - tw) // 2 if iw >= tw else 0
        top = (ih - th) // 2 if ih >= th else 0
        cropped = img.crop((left, top, left + min(iw, tw), top + min(ih, th)))
        # Create a new image at the target size and paste the cropped image centered.
        new_img = Image.new(img.mode, target_size, fill_color)
        offset = ((tw - cropped.width) // 2, (th - cropped.height) // 2)
        new_img.paste(cropped, offset)
        return new_img

    def transform_foreground_original(self, original_fg, original_bg, transformation_info, canvas_size=400):
        drag_left = float(transformation_info.get("drag_left", 0))
        drag_top = float(transformation_info.get("drag_top", 0))
        scale_ratio = float(transformation_info.get("scale_ratio", 1))
        data_orig_width = float(transformation_info.get("data_original_width", canvas_size))
        data_orig_height = float(transformation_info.get("data_original_height", canvas_size))
        drag_width = float(transformation_info.get("drag_width", 0))
        drag_height = float(transformation_info.get("drag_height", 0))

        scale_ori_fg = canvas_size / max(original_fg.width, original_fg.height)
        scale_ori_bg = canvas_size / max(original_bg.width, original_bg.height)

        # Default centered position of the foreground in the preview canvas, i.e. where it sits before any drag.
        default_left = (canvas_size - data_orig_width) / 2.0
        default_top = (canvas_size - data_orig_height) / 2.0

        # Offset produced by the drag, measured in preview-canvas pixels, then mapped back to original-foreground pixels.
        offset_preview_x = drag_left - default_left
        offset_preview_y = drag_top - default_top
        offset_ori_x = offset_preview_x / scale_ori_fg
        offset_ori_y = offset_preview_y / scale_ori_fg

        new_width = int(original_fg.width * scale_ratio)
        new_height = int(original_fg.height * scale_ratio)
        scale_fg = original_fg.resize((new_width, new_height))

        output = Image.new("RGBA", (original_fg.width, original_fg.height), (255, 255, 255, 0))
        output.paste(scale_fg, (int(offset_ori_x), int(offset_ori_y)))

        new_width_fgbg = original_fg.width * scale_ori_fg / scale_ori_bg
        new_height_fgbg = original_fg.height * scale_ori_fg / scale_ori_bg
        scale_fgbg = output.resize((int(new_width_fgbg), int(new_height_fgbg)))

        final_output = Image.new("RGBA", (original_bg.width, original_bg.height), (255, 255, 255, 0))
        scale_fgbg = self.pad_or_crop(scale_fgbg, (original_bg.width, original_bg.height), (255, 255, 255, 0))
        final_output.paste(scale_fgbg, (0, 0))

        fit_fg = self.resize_to_fit_within(original_bg, original_fg)
        fit_fg = self.pad_or_crop(fit_fg, original_bg.size, (255, 255, 255, 0))
        return final_output, fit_fg
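
    # The Gradio front end serializes the user's drag/scale interaction as a JSON string with
    # keys such as "drag_left", "drag_top", "scale_ratio", "data_original_width" and
    # "data_original_height" (consumed by transform_foreground_original above); gradio_generate
    # parses that string and falls back to an empty dict when it is missing or malformed.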

    def gradio_generate(self, background_img, foreground_img, transformation_info, seed, prompt, enable_gui, cfg=3.5, size_select="1024", text_strength=1, truecfg=False):
        try:
            trans = json.loads(transformation_info)
        except Exception:
            trans = {}
        size_select = int(size_select)
        # if size_select == 1024 and prompt != "": text_strength = 5
        # if size_select == 768 and prompt != "": text_strength = 3

        r, g, b, ori_a = foreground_img.split()
        fg_img_scale, fg_img = self.transform_foreground_original(foreground_img, background_img, trans)

        new_r, new_g, new_b, new_a = fg_img_scale.split()
        foreground_img_scale = Image.merge("RGB", (new_r, new_g, new_b))

        r, g, b, ori_a = fg_img.split()
        foreground_img = Image.merge("RGB", (r, g, b))
        foreground_img_save = foreground_img.copy()

        ori_a = ori_a.convert("L")
        new_a = new_a.convert("L")
        # White out everything outside the foreground alpha so only the object itself is conditioned on.
        foreground_img.paste((255, 255, 255), mask=ImageOps.invert(ori_a))

        images = self.model_generate(foreground_img.copy(), background_img.copy(),
                                     ori_a, new_a,
                                     enable_mask_affine=enable_gui,
                                     prompt=prompt,
                                     offset_cond=[0, 1, 0] if not enable_gui else None,
                                     seed=seed,
                                     cfg=cfg,
                                     size_select=size_select,
                                     text_strength=text_strength,
                                     truecfg=truecfg)
        images = Image.fromarray(images[0], "RGB")
        images = images.resize(background_img.size)
        # images.thumbnail((640, 640), Image.LANCZOS)
        return images
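
    # model_generate runs the full DreamFuse inference path: fit the foreground to the background,
    # snap to a bucket resolution, VAE-encode both images as condition latents, build per-image
    # position ids (optionally warping the foreground ids with the affine transform recovered from
    # the two masks), denoise with the LoRA-adapted FLUX transformer under a flow-matching Euler
    # schedule, and finally decode back to uint8 RGB arrays.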

    def model_generate(self, fg_image, bg_image, ori_fg_mask, new_fg_mask, enable_mask_affine=True, prompt="", offset_cond=None, seed=None, cfg=3.5, size_select=1024, text_strength=1, truecfg=False):
        batch_size = 1

        # Prepare images: adjust fg to bg size, then snap to a bucket resolution
        fg_image, ori_fg_mask = adjust_fg_to_bg(fg_image, ori_fg_mask, bg_image.size)
        bucket_size = find_nearest_bucket_size(bg_image.size[0], bg_image.size[1], bucket_size=size_select)
        fg_image = fg_image.resize(bucket_size)
        bg_image = bg_image.resize(bucket_size)

        mask_affine = None
        if enable_mask_affine:
            ori_fg_mask = ori_fg_mask.resize(bucket_size)
            new_fg_mask = new_fg_mask.resize(bucket_size)
            mask_affine = get_mask_affine(new_fg_mask, ori_fg_mask)

        # Get embeddings
        prompt_embeds, pooled_prompt_embeds, text_ids = self._compute_text_embeddings(prompt)
        prompt_embeds = prompt_embeds.repeat(1, text_strength, 1)
        text_ids = text_ids.repeat(text_strength, 1)

        # Prepare guidance (the dev checkpoint is guidance-distilled and takes the scale as an input)
        if self.config.model_choice == "dev":
            guidance = torch.full([1], cfg, device=self.device, dtype=torch.float32)
            guidance = guidance.expand(batch_size)
        else:
            guidance = None

        # Prepare generator
        if seed is None:
            seed = self.config.seed
        generator = torch.Generator(device=self.device).manual_seed(seed)

        # Prepare condition latents
        condition_image_latents = self._encode_images([fg_image, bg_image])

        if offset_cond is None:
            offset_cond = self.config.image_ids_offset
        offset_cond = offset_cond[1:]

        cond_latent_image_ids = []
        for offset_ in offset_cond:
            cond_latent_image_ids.append(
                self._prepare_image_ids(
                    condition_image_latents.shape[2] // 2,
                    condition_image_latents.shape[3] // 2,
                    offset_w=offset_ * condition_image_latents.shape[3] // 2
                )
            )

        if mask_affine is not None:
            affine_H, affine_W = condition_image_latents.shape[2] // 2, condition_image_latents.shape[3] // 2
            scale_factor = 1 / 16
            cond_latent_image_ids_fg = cond_latent_image_ids[0].reshape(affine_H, affine_W, 3).clone()
            # opt 1: warp the foreground position ids with the mask-derived affine transform
            cond_latent_image_ids[0] = warp_affine_tensor(
                cond_latent_image_ids_fg, mask_affine, output_size=(affine_H, affine_W),
                scale_factor=scale_factor, device=self.device,
            )
        cond_latent_image_ids = torch.stack(cond_latent_image_ids)

        # Pack condition latents
        cond_image_latents = self._pack_latents(condition_image_latents)
        cond_input = {
            "image_latents": cond_image_latents,
            "image_ids": cond_latent_image_ids,
        }

        # Prepare initial latents
        width, height = bucket_size
        num_channels_latents = self.denoise_model.config.in_channels // 4
        latents, latent_image_ids = self._prepare_latents(
            batch_size, num_channels_latents, height, width, generator
        )

        # Setup timesteps: the sigma schedule is shifted by mu, which calculate_shift derives
        # from the packed image sequence length (longer sequences get a larger shift).
        sigmas = np.linspace(1.0, 1 / self.config.num_inference_steps, self.config.num_inference_steps)
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            self.config.num_inference_steps,
            self.device,
            sigmas=sigmas,
            mu=mu,
        )
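
        # Each step below runs the LoRA-adapted FLUX transformer with the condition latents and
        # their position ids injected through cond_input; when truecfg is enabled, a second pass
        # with guidance scale 1 (from the second step onward) supplies a low-guidance prediction
        # that the main prediction is extrapolated against.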

        # Denoising loop
        for i, t in enumerate(timesteps):
            timestep = t.expand(latents.shape[0]).to(latents.dtype)
            with torch.autocast(enabled=True, device_type="cuda", dtype=self.config.dtype):
                noise_pred = self.denoise_model(
                    hidden_states=latents,
                    cond_input=cond_input,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    data_num_per_group=batch_size,
                    image_tags=self.config.image_tags,
                    context_tags=self.config.context_tags,
                    max_sequence_length=self.config.max_sequence_length,
                    mix_attention_double=self.config.mix_attention_double,
                    mix_attention_single=self.config.mix_attention_single,
                    joint_attention_kwargs=None,
                    return_dict=False,
                )[0]

                if truecfg and i >= 1:
                    # Second pass with a neutral guidance scale of 1 as the low-guidance branch.
                    guidance_neg = torch.full([1], 1, device=self.device, dtype=torch.float32)
                    guidance_neg = guidance_neg.expand(batch_size)
                    noise_pred_neg = self.denoise_model(
                        hidden_states=latents,
                        cond_input=cond_input,
                        timestep=timestep / 1000,
                        guidance=guidance_neg,
                        pooled_projections=pooled_prompt_embeds,
                        encoder_hidden_states=prompt_embeds,
                        txt_ids=text_ids,
                        img_ids=latent_image_ids,
                        data_num_per_group=batch_size,
                        image_tags=self.config.image_tags,
                        context_tags=self.config.context_tags,
                        max_sequence_length=self.config.max_sequence_length,
                        mix_attention_double=self.config.mix_attention_double,
                        mix_attention_single=self.config.mix_attention_single,
                        joint_attention_kwargs=None,
                        return_dict=False,
                    )[0]
                    # Extrapolate away from the low-guidance prediction with a fixed scale of 5.
                    noise_pred = noise_pred_neg + 5 * (noise_pred - noise_pred_neg)

            # Compute previous noisy sample
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        # Decode latents
        latents = self._unpack_latents(latents, height, width)
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        images = self.vae.decode(latents, return_dict=False)[0]

        # Post-process images
        images = images.add(1).mul(127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
        return images
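
    # Thin wrappers around the shared dreamfuse.trains.utils.inference_utils helpers, bound to
    # this pipeline's VAE configuration, dtype, and device.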
    def _encode_images(self, images):
        return encode_images_cond(self.vae, [images], self.device)

    def _prepare_image_ids(self, h, w, offset_w=0):
        return _prepare_image_ids(h, w, offset_w=offset_w).to(self.device)

    def _pack_latents(self, latents):
        b, c, h, w = latents.shape
        return _pack_latents(latents, b, c, h, w)

    def _unpack_latents(self, latents, height, width):
        vae_scale = 2 ** (len(self.vae.config.block_out_channels) - 1)
        return _unpack_latents(latents, height, width, vae_scale)

    def _prepare_latents(self, batch_size, num_channels_latents, height, width, generator):
        vae_scale = 2 ** (len(self.vae.config.block_out_channels) - 1)
        latents, latent_image_ids = prepare_latents(
            batch_size=batch_size,
            num_channels_latents=num_channels_latents,
            vae_downsample_factor=vae_scale,
            height=height,
            width=width,
            dtype=self.config.dtype,
            device=self.device,
            generator=generator,
            offset=None
        )
        return latents, latent_image_ids
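

if __name__ == "__main__":
    # Minimal usage sketch. The image paths and LoRA checkpoint below are placeholders
    # (not part of this repository); swap in real files before running.
    config = InferenceConfig()
    config.lora_id = "./checkpoints/dreamfuse_lora"  # hypothetical LoRA adapter directory

    pipeline = DreamFuseInference(config)

    background = Image.open("background.png").convert("RGB")
    foreground = Image.open("foreground.png").convert("RGBA")  # an alpha channel is required

    # Without GUI drag/scale information, pass an empty JSON object and disable the
    # mask-affine conditioning so the model places the object on its own.
    result = pipeline.gradio_generate(
        background,
        foreground,
        transformation_info="{}",
        seed=config.seed,
        prompt="",
        enable_gui=False,
        cfg=config.guidance_scale,
        size_select=str(config.inference_scale),
    )
    result.save("dreamfuse_result.png")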