import contextlib
import gc
import json
import os
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List, Optional

import torch.multiprocessing as mp
import numpy as np
from PIL import Image, ImageOps
import torch
from peft import PeftModel
import torch.nn.functional as F
import accelerate
import diffusers
from diffusers import FluxPipeline
from diffusers.utils.torch_utils import is_compiled_module
import transformers
from tqdm import tqdm
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from diffusers.schedulers.scheduling_flow_match_euler_discrete import (
    FlowMatchEulerDiscreteScheduler,
)
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps

from dreamfuse.models.dreamfuse_flux.transformer import (
    FluxTransformer2DModel,
    FluxTransformerBlock,
    FluxSingleTransformerBlock,
)
from dreamfuse.trains.utils.inference_utils import (
    compute_text_embeddings,
    prepare_latents,
    _unpack_latents,
    _pack_latents,
    _prepare_image_ids,
    encode_images_cond,
    get_mask_affine,
    warp_affine_tensor,
)


def seed_everything(seed):
    """Seed torch (CPU and CUDA), Python's `random` module and NumPy with `seed`."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


@dataclass
class InferenceConfig:
    """Settings for DreamFuse inference: model paths, sampling, runtime and I/O options."""

    # Model paths
    flux_model_id: str = 'black-forest-labs/FLUX.1-dev'
    lora_id: str = ''
    model_choice: str = 'dev'  # "dev" makes model_generate feed a guidance tensor to the model
    # Model configs
    lora_rank: int = 16
    max_sequence_length: int = 256
    guidance_scale: float = 3.5
    num_inference_steps: int = 28
    mask_ids: int = 16
    mask_in_chans: int = 128
    mask_out_chans: int = 3072
    # NOTE(review): not annotated, so this is a plain class attribute rather than a
    # dataclass field; it cannot be set through the constructor.
    inference_scale = 1024
    # Training configs
    gradient_checkpointing: bool = False
    mix_attention_double: bool = True
    mix_attention_single: bool = True
    # Image processing
    image_ids_offset: List[int] = field(default_factory=lambda: [0, 0, 0])
    image_tags: List[int] = field(default_factory=lambda: [0, 1, 2])
    context_tags: Optional[List[int]] = None
    # Runtime configs
    device: str = "cuda:0"  # if torch.cuda.is_available() else "cpu"
    dtype: torch.dtype = torch.bfloat16
    seed: int = 1234
    debug: bool = True
    # I/O configs
    valid_output_dir: str = "./inference_output"
    valid_roots: List[str] = field(default_factory=lambda: [
        "./",
    ])
    valid_jsons: List[str] = field(default_factory=lambda: [
        "./examples/data_dreamfuse.json",
    ])
    ref_prompts: str = ""
    truecfg: bool = False
    text_strength: int = 5
    # multi gpu
    sub_idx: int = 0
    total_num: int = 1


def adjust_fg_to_bg(image: Image.Image, mask: Image.Image, target_size: tuple) -> tuple[Image.Image, Image.Image]:
    """Fit `image` and `mask` into `target_size` = (width, height).

    The pair is downscaled (never upscaled) to fit, then centre-padded to exactly
    `target_size`: white for the image, 0 for the mask. When the padding is odd,
    the extra pixel goes to the right/bottom edge.

    Returns:
        The padded (image, mask) pair.
    """
    width, height = image.size
    target_w, target_h = target_size
    scale = min(target_w / width, target_h / height)
    if scale < 1:
        new_w = int(width * scale)
        new_h = int(height * scale)
        image = image.resize((new_w, new_h))
        mask = mask.resize((new_w, new_h))
        width, height = new_w, new_h
    pad_w = target_w - width
    pad_h = target_h - height
    padding = (
        pad_w // 2,  # left
        pad_h // 2,  # top
        (pad_w + 1) // 2,  # right
        (pad_h + 1) // 2,  # bottom
    )
    image = ImageOps.expand(image, border=padding, fill=(255, 255, 255))
    mask = ImageOps.expand(mask, border=padding, fill=0)
    return image, mask


# Candidate generation sizes, as [width, height] pairs, keyed by bucket size.
_BUCKETS = {
    512: [[256, 768], [320, 768], [320, 704], [384, 640], [448, 576], [512, 512],
          [576, 448], [640, 384], [704, 320], [768, 320], [768, 256]],
    768: [[384, 1152], [480, 1152], [480, 1056], [576, 960], [672, 864], [768, 768],
          [864, 672], [960, 576], [1056, 480], [1152, 480], [1152, 384]],
    1024: [[512, 1536], [640, 1536], [640, 1408], [768, 1280], [896, 1152], [1024, 1024],
           [1152, 896], [1280, 768], [1408, 640], [1536, 640], [1536, 512]],
}


def find_nearest_bucket_size(input_width, input_height, mode="x64", bucket_size=1024):
    """Finds the nearest bucket size for the given input size.

    Args:
        input_width, input_height: Size of the source image in pixels.
        mode: "x64" picks the entry of the `bucket_size` table whose aspect ratio
            is closest to the input's. "x8" rescales the input to about 1024*1024
            pixels, keeping its aspect ratio, and floors each side to a multiple of 8.
        bucket_size: Key into the bucket table (512, 768 or 1024). The table is
            looked up for both modes, so an unknown key raises KeyError.
            NOTE(review): the "x8" branch ignores `bucket_size` for its pixel budget.

    Returns:
        (gen_width, gen_height)

    Raises:
        ValueError: if `mode` is not "x64" or "x8".
    """
    buckets = _BUCKETS[bucket_size]
    if mode not in ("x64", "x8"):
        raise ValueError(f"mode must be 'x64' or 'x8', got {mode!r}")

    if mode == "x64":
        asp = input_width / input_height
        aspect_ratios = [w / h for (w, h) in buckets]
        diff = [abs(ar - asp) for ar in aspect_ratios]
        bucket_id = int(np.argmin(diff))
        gen_width, gen_height = buckets[bucket_id]
    else:  # "x8"
        max_pixels = 1024 * 1024
        ratio = (max_pixels / (input_width * input_height)) ** 0.5
        gen_width, gen_height = round(input_width * ratio), round(input_height * ratio)
        gen_width = gen_width - gen_width % 8
        gen_height = gen_height - gen_height % 8
    return (gen_width, gen_height)


def make_image_grid(images, rows, cols, size=None):
    """Paste `rows * cols` images, row-major, into one RGB grid.

    If `size` is given, every image is first resized to (size[0], size[1]).
    Cell size is taken from the first (possibly resized) image.
    """
    assert len(images) == rows * cols

    if size is not None:
        images = [img.resize((size[0], size[1])) for img in images]

    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(images):
        grid.paste(img.convert("RGB"), box=(i % cols * w, i // cols * h))
    return grid


class DreamFuseInference:
    """Loads FLUX + DreamFuse LoRA weights and composites a foreground onto a background."""

    def __init__(self, config: InferenceConfig):
        self.config = config
        self.device = torch.device(config.device)
        # Process-wide setting: allow TF32 for CUDA matmuls.
        torch.backends.cuda.matmul.allow_tf32 = True
        seed_everything(config.seed)
        self._init_models()

    def _init_models(self):
        """Load tokenizers, text encoders, VAE, transformer (+ LoRA) and scheduler; freeze them."""
        # Initialize tokenizers
        self.tokenizer_one = transformers.CLIPTokenizer.from_pretrained(
            self.config.flux_model_id, subfolder="tokenizer"
        )
        self.tokenizer_two = transformers.T5TokenizerFast.from_pretrained(
            self.config.flux_model_id, subfolder="tokenizer_2"
        )

        # Initialize text encoders
        self.text_encoder_one = transformers.CLIPTextModel.from_pretrained(
            self.config.flux_model_id, subfolder="text_encoder"
        ).to(device=self.device, dtype=self.config.dtype)
        self.text_encoder_two = transformers.T5EncoderModel.from_pretrained(
            self.config.flux_model_id, subfolder="text_encoder_2"
        ).to(device=self.device, dtype=self.config.dtype)

        # Initialize VAE
        self.vae = diffusers.AutoencoderKL.from_pretrained(
            self.config.flux_model_id, subfolder="vae"
        ).to(device=self.device, dtype=self.config.dtype)

        # Initialize denoising model
        self.denoise_model = FluxTransformer2DModel.from_pretrained(
            self.config.flux_model_id, subfolder="transformer"
        ).to(device=self.device, dtype=self.config.dtype)

        # Tag embedding tables need one row per tag id, hence max(tag) + 1.
        if self.config.image_tags is not None or self.config.context_tags is not None:
            num_image_tag_embeddings = max(self.config.image_tags) + 1 if self.config.image_tags is not None else 0
            num_context_tag_embeddings = max(self.config.context_tags) + 1 if self.config.context_tags is not None else 0
            self.denoise_model.set_tag_embeddings(
                num_image_tag_embeddings=num_image_tag_embeddings,
                num_context_tag_embeddings=num_context_tag_embeddings,
            )

        # Add LoRA
        self.denoise_model = PeftModel.from_pretrained(
            self.denoise_model,
            self.config.lora_id,
            adapter_weights=[1.0],
            device_map={"": self.device},
        )

        # Initialize scheduler
        self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            self.config.flux_model_id, subfolder="scheduler"
        )

        # Set models to eval mode
        for model in [self.text_encoder_one, self.text_encoder_two, self.vae, self.denoise_model]:
            model.eval()
            model.requires_grad_(False)

    def _compute_text_embeddings(self, prompt):
        """Return (prompt_embeds, pooled_prompt_embeds, text_ids) for `prompt`."""
        return compute_text_embeddings(
            self.config,
            prompt,
            [self.text_encoder_one, self.text_encoder_two],
            [self.tokenizer_one, self.tokenizer_two],
            self.device,
        )

    def resize_to_fit_within(self, reference_image, target_image):
        """Scale `target_image` (aspect preserved, LANCZOS) to the largest size that fits inside `reference_image`."""
        ref_width, ref_height = reference_image.size
        target_width, target_height = target_image.size

        scale_width = ref_width / target_width
        scale_height = ref_height / target_height
        scale = min(scale_width, scale_height)

        new_width = int(target_width * scale)
        new_height = int(target_height * scale)

        resized_image = target_image.resize((new_width, new_height), Image.LANCZOS)
        return resized_image

    def pad_or_crop(self, img, target_size, fill_color=(255, 255, 255)):
        """Centre-crop and/or centre-pad `img` to exactly `target_size` = (width, height)."""
        iw, ih = img.size
        tw, th = target_size

        # Crop box: along each axis where the source is larger than the target,
        # keep the centred region; otherwise keep the full extent.
        left = (iw - tw) // 2 if iw >= tw else 0
        top = (ih - th) // 2 if ih >= th else 0
        cropped = img.crop((left, top, left + min(iw, tw), top + min(ih, th)))

        # Create a canvas of the target size and paste the cropped image centred on it.
        new_img = Image.new(img.mode, target_size, fill_color)
        offset = ((tw - cropped.width) // 2, (th - cropped.height) // 2)
        new_img.paste(cropped, offset)
        return new_img

    def transform_foreground_original(self, original_fg, original_bg, transformation_info, canvas_size=400):
        """Map a drag/scale edit made in a `canvas_size` preview back onto full-resolution images.

        Args:
            original_fg: Foreground image. It is pasted onto RGBA canvases, so it
                is presumably RGBA.
            original_bg: Background image; only its size is used.
            transformation_info: Dict whose keys are read with defaults: drag_left,
                drag_top, scale_ratio, data_original_width, data_original_height.
            canvas_size: Longest side of the preview canvas, in pixels.

        Returns:
            (final_output, fit_fg), both the background's size:
            final_output is the moved/scaled foreground on a transparent RGBA canvas;
            fit_fg is the untransformed foreground fitted into and padded to that size.
        """
        drag_left = float(transformation_info.get("drag_left", 0))
        drag_top = float(transformation_info.get("drag_top", 0))
        scale_ratio = float(transformation_info.get("scale_ratio", 1))
        data_orig_width = float(transformation_info.get("data_original_width", canvas_size))
        data_orig_height = float(transformation_info.get("data_original_height", canvas_size))

        scale_ori_fg = canvas_size / max(original_fg.width, original_fg.height)
        scale_ori_bg = canvas_size / max(original_bg.width, original_bg.height)

        # Default centred position in the unscaled preview, i.e. where the
        # foreground sits before it has been dragged.
        default_left = (canvas_size - data_orig_width) / 2.0
        default_top = (canvas_size - data_orig_height) / 2.0

        # Offset produced by the drag, in preview pixels, converted to
        # original-foreground pixels.
        offset_preview_x = drag_left - default_left
        offset_preview_y = drag_top - default_top
        offset_ori_x = offset_preview_x / scale_ori_fg
        offset_ori_y = offset_preview_y / scale_ori_fg

        new_width = int(original_fg.width * scale_ratio)
        new_height = int(original_fg.height * scale_ratio)
        scale_fg = original_fg.resize((new_width, new_height))

        output = Image.new("RGBA", (original_fg.width, original_fg.height), (255, 255, 255, 0))
        output.paste(scale_fg, (int(offset_ori_x), int(offset_ori_y)))

        # Rescale the foreground-space canvas into background pixel units.
        new_width_fgbg = original_fg.width * scale_ori_fg / scale_ori_bg
        new_height_fgbg = original_fg.height * scale_ori_fg / scale_ori_bg
        scale_fgbg = output.resize((int(new_width_fgbg), int(new_height_fgbg)))

        final_output = Image.new("RGBA", (original_bg.width, original_bg.height), (255, 255, 255, 0))
        scale_fgbg = self.pad_or_crop(scale_fgbg, (original_bg.width, original_bg.height), (255, 255, 255, 0))
        final_output.paste(scale_fgbg, (0, 0))

        fit_fg = self.resize_to_fit_within(original_bg, original_fg)
        fit_fg = self.pad_or_crop(fit_fg, original_bg.size, (255, 255, 255, 0))

        return final_output, fit_fg

    @torch.inference_mode()
    def gradio_generate(self, background_img, foreground_img, transformation_info, seed, prompt, enable_gui,
                        cfg=3.5, size_select="1024", text_strength=1, truecfg=False):
        """Gradio entry point: composite `foreground_img` onto `background_img`.

        Args:
            background_img: Background PIL image.
            foreground_img: Foreground PIL image. Its alpha channel is used as the
                mask; images without alpha are treated as fully opaque.
            transformation_info: JSON string of drag/scale info. Invalid or non-object
                JSON is treated as "no transformation".
            seed: Sampling seed.
            prompt: Text prompt.
            enable_gui: If True, the foreground's position ids are warped by the
                affine map between the original and moved masks. Otherwise the
                fixed image-id offsets [0, 1, 0] are used.
            cfg, size_select, text_strength, truecfg: Forwarded to `model_generate`.
                `size_select` is converted with int().

        Returns:
            The generated RGB image, resized to `background_img.size`.
        """
        try:
            trans = json.loads(transformation_info)
        except (TypeError, ValueError):  # None / non-string input, or malformed JSON
            trans = {}
        if not isinstance(trans, dict):
            trans = {}

        size_select = int(size_select)

        # The 4-way channel splits below require an alpha channel.
        if foreground_img.mode != "RGBA":
            foreground_img = foreground_img.convert("RGBA")

        fg_img_scale, fg_img = self.transform_foreground_original(foreground_img, background_img, trans)
        # Alpha of the moved/scaled foreground = mask at its target location.
        new_a = fg_img_scale.getchannel("A").convert("L")

        r, g, b, ori_a = fg_img.split()
        fg_rgb = Image.merge("RGB", (r, g, b))
        ori_a = ori_a.convert("L")
        # Whiten every pixel outside the original foreground mask.
        fg_rgb.paste((255, 255, 255), mask=ImageOps.invert(ori_a))

        images = self.model_generate(
            fg_rgb.copy(),
            background_img.copy(),
            ori_a,
            new_a,
            enable_mask_affine=enable_gui,
            prompt=prompt,
            offset_cond=[0, 1, 0] if not enable_gui else None,
            seed=seed,
            cfg=cfg,
            size_select=size_select,
            text_strength=text_strength,
            truecfg=truecfg,
        )
        result = Image.fromarray(images[0], "RGB")
        return result.resize(background_img.size)

    @torch.inference_mode()
    def model_generate(self, fg_image, bg_image, ori_fg_mask, new_fg_mask, enable_mask_affine=True, prompt="",
                       offset_cond=None, seed=None, cfg=3.5, size_select=1024, text_strength=1, truecfg=False,
                       true_cfg_scale=5.0):
        """Run the DreamFuse denoising loop conditioned on a foreground and a background image.

        Args:
            fg_image: Foreground PIL image, already whitened outside its mask.
            bg_image: Background PIL image. It determines the generation bucket.
            ori_fg_mask: Mask of the foreground at its original placement.
            new_fg_mask: Mask of the foreground at its target placement.
            enable_mask_affine: If True, estimate an affine map between the masks
                and warp the foreground's position ids with it.
            prompt: Text prompt.
            offset_cond: Per-image width offsets for the position ids; defaults to
                `config.image_ids_offset`. The first entry is skipped and the rest
                apply, in order, to [fg_image, bg_image].
            seed: Generator seed; defaults to `config.seed`.
            cfg: Guidance value passed to the model when `config.model_choice == "dev"`.
            size_select: Bucket size for `find_nearest_bucket_size`.
            text_strength: How many times the prompt token sequence (and its ids)
                is tiled along the sequence axis.
            truecfg: From the second step onward, also predict with guidance 1 and
                extrapolate: `neg + true_cfg_scale * (pos - neg)`. Has no effect
                when the model takes no guidance input.
            true_cfg_scale: Extrapolation factor for `truecfg`.

        Returns:
            uint8 numpy array laid out (batch, height, width, channels).
        """
        batch_size = 1

        # Pad the foreground to the background's size, then resize both to the bucket.
        fg_image, ori_fg_mask = adjust_fg_to_bg(fg_image, ori_fg_mask, bg_image.size)
        gen_size = find_nearest_bucket_size(bg_image.size[0], bg_image.size[1], bucket_size=size_select)
        fg_image = fg_image.resize(gen_size)
        bg_image = bg_image.resize(gen_size)

        mask_affine = None
        if enable_mask_affine:
            ori_fg_mask = ori_fg_mask.resize(gen_size)
            new_fg_mask = new_fg_mask.resize(gen_size)
            mask_affine = get_mask_affine(new_fg_mask, ori_fg_mask)

        # Text embeddings; the token sequence is tiled `text_strength` times.
        prompt_embeds, pooled_prompt_embeds, text_ids = self._compute_text_embeddings(prompt)
        prompt_embeds = prompt_embeds.repeat(1, text_strength, 1)
        text_ids = text_ids.repeat(text_strength, 1)

        if self.config.model_choice == "dev":
            guidance = torch.full([1], cfg, device=self.device, dtype=torch.float32).expand(batch_size)
            guidance_neg = torch.ones_like(guidance)
        else:
            guidance = None
            guidance_neg = None

        if seed is None:
            seed = self.config.seed
        generator = torch.Generator(device=self.device).manual_seed(seed)

        cond_input = self._build_cond_input([fg_image, bg_image], offset_cond, mask_affine)

        # Prepare initial latents
        width, height = gen_size
        num_channels_latents = self.denoise_model.config.in_channels // 4
        latents, latent_image_ids = self._prepare_latents(
            batch_size, num_channels_latents, height, width, generator
        )

        # Setup timesteps (shift depends on the packed sequence length)
        sigmas = np.linspace(1.0, 1 / self.config.num_inference_steps, self.config.num_inference_steps)
        mu = calculate_shift(
            latents.shape[1],
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, _ = retrieve_timesteps(
            self.scheduler,
            self.config.num_inference_steps,
            self.device,
            sigmas=sigmas,
            mu=mu,
        )

        text_kwargs = dict(
            cond_input=cond_input,
            pooled_prompt_embeds=pooled_prompt_embeds,
            prompt_embeds=prompt_embeds,
            text_ids=text_ids,
            latent_image_ids=latent_image_ids,
            batch_size=batch_size,
        )

        # Denoising loop
        for step_idx, t in enumerate(timesteps):
            # The model receives the scheduler timestep divided by 1000.
            timestep = t.expand(latents.shape[0]).to(latents.dtype) / 1000
            noise_pred = self._predict_noise(latents, timestep, guidance, **text_kwargs)

            if truecfg and step_idx >= 1:
                # Second pass with guidance 1 (previously the positive guidance was
                # passed here, making this branch a no-op).
                noise_pred_neg = self._predict_noise(latents, timestep, guidance_neg, **text_kwargs)
                noise_pred = noise_pred_neg + true_cfg_scale * (noise_pred - noise_pred_neg)

            # Compute previous noisy sample
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        # Decode latents
        latents = self._unpack_latents(latents, height, width)
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        images = self.vae.decode(latents, return_dict=False)[0]

        # Map [-1, 1] to uint8 and move channels last.
        images = images.add(1).mul(127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
        return images

    def _build_cond_input(self, images, offset_cond, mask_affine):
        """Encode the condition images; return their packed latents and stacked position ids."""
        condition_image_latents = self._encode_images(images)
        latent_h = condition_image_latents.shape[2]
        latent_w = condition_image_latents.shape[3]

        if offset_cond is None:
            offset_cond = self.config.image_ids_offset
        # Skip the first entry; the rest pair with the condition images in order.
        cond_latent_image_ids = [
            self._prepare_image_ids(latent_h // 2, latent_w // 2, offset_w=offset_ * latent_w // 2)
            for offset_ in offset_cond[1:]
        ]

        if mask_affine is not None:
            affine_H, affine_W = latent_h // 2, latent_w // 2
            cond_latent_image_ids_fg = cond_latent_image_ids[0].reshape(affine_H, affine_W, 3).clone()
            # Warp the foreground's position ids to its target placement.
            cond_latent_image_ids[0] = warp_affine_tensor(
                cond_latent_image_ids_fg,
                mask_affine,
                output_size=(affine_H, affine_W),
                scale_factor=1 / 16,
                device=self.device,
            )

        return {
            "image_latents": self._pack_latents(condition_image_latents),
            "image_ids": torch.stack(cond_latent_image_ids),
        }

    def _predict_noise(self, latents, timestep, guidance, *, cond_input, pooled_prompt_embeds, prompt_embeds,
                       text_ids, latent_image_ids, batch_size):
        """One autocast'd forward pass of the (LoRA-wrapped) transformer; returns the noise prediction."""
        with torch.autocast(enabled=True, device_type=self.device.type, dtype=self.config.dtype):
            return self.denoise_model(
                hidden_states=latents,
                cond_input=cond_input,
                timestep=timestep,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                data_num_per_group=batch_size,
                image_tags=self.config.image_tags,
                context_tags=self.config.context_tags,
                max_sequence_length=self.config.max_sequence_length,
                mix_attention_double=self.config.mix_attention_double,
                mix_attention_single=self.config.mix_attention_single,
                joint_attention_kwargs=None,
                return_dict=False,
            )[0]

    def _vae_scale_factor(self):
        """Spatial downsampling factor of the VAE: 2 ** (number of blocks - 1)."""
        return 2 ** (len(self.vae.config.block_out_channels) - 1)

    def _encode_images(self, images):
        return encode_images_cond(self.vae, [images], self.device)

    def _prepare_image_ids(self, h, w, offset_w=0):
        return _prepare_image_ids(h, w, offset_w=offset_w).to(self.device)

    def _pack_latents(self, latents):
        b, c, h, w = latents.shape
        return _pack_latents(latents, b, c, h, w)

    def _unpack_latents(self, latents, height, width):
        return _unpack_latents(latents, height, width, self._vae_scale_factor())

    def _prepare_latents(self, batch_size, num_channels_latents, height, width, generator):
        latents, latent_image_ids = prepare_latents(
            batch_size=batch_size,
            num_channels_latents=num_channels_latents,
            vae_downsample_factor=self._vae_scale_factor(),
            height=height,
            width=width,
            dtype=self.config.dtype,
            device=self.device,
            generator=generator,
            offset=None,
        )
        return latents, latent_image_ids