# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import torch
import json
import os

import torchvision
from torchvision.utils import make_grid
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm
from PIL import Image

from models.text import TextModel
from models.vae import AutoencoderKL
from models.unet_2d_condition_custom import UNet2DConditionModel as UNet2DConditionModelDiffusers
from schedulers.ddim import DDIMScheduler
from schedulers.dpm_s import DPMSolverSingleStepScheduler
from schedulers.utils import get_betas
from inference_utils import find_phrase_positions_in_text, classifier_free_guidance_image_prompt_cascade
from mask_generation import mask_generation
from utils import instantiate_from_config

# Argument parser
parser = argparse.ArgumentParser()
parser.add_argument("--width", type=int, default=512)
parser.add_argument("--height", type=int, default=512)
parser.add_argument("--samples_per_prompt", type=int, required=True)
parser.add_argument("--nrow", type=int, default=4)
parser.add_argument("--sample_steps", type=int, required=True)
parser.add_argument("--schedule_type", type=str, default="squared_linear")  # default, "squared_linear"
parser.add_argument("--scheduler_type", type=str, default="dpm", choices=["ddim", "dpm"])  # default, "dpm"
parser.add_argument("--schedule_shift_snr", type=float, default=1)  # default, 1
parser.add_argument("--text_encoder_variant", type=str, nargs="+")
parser.add_argument("--vae_config", type=str, default="configs/vae.json")  # default
parser.add_argument("--vae_checkpoint", type=str, required=True)
parser.add_argument("--unet_config", type=str, required=True)
parser.add_argument("--unet_checkpoint", type=str, required=True)
parser.add_argument("--unet_checkpoint_base_model", type=str, default="")
parser.add_argument("--unet_prediction", type=str, choices=DDIMScheduler.prediction_types, default="epsilon")  # default, "epsilon"
parser.add_argument("--negative_prompt", type=str, default="prompts/validation_negative.txt")  # default
parser.add_argument("--compile", action="store_true", default=False)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--guidance_weight", type=float, default=7.5)
parser.add_argument("--seed", type=int, default=666)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--text_prompt", type=str, required=True)
parser.add_argument("--image_prompt_path", type=str, required=True)
parser.add_argument("--target_phrase", type=str, required=True)
parser.add_argument("--mask_scope", type=float, default=0.20)
parser.add_argument("--mask_strategy", type=str, nargs="+", default=["max_norm"])
parser.add_argument("--mask_reused_step", type=int, default=12)
args = parser.parse_args()

# Initialize unet model
with open(args.unet_config) as unet_config_file:
    unet_config = json.load(unet_config_file)

# Settings for image encoder
vision_model_config = unet_config.pop("vision_model_config", None)
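# NOTE (assumption): the nested pop here suggests the UNet config JSON nests the image-encoder
# settings one level deep, e.g. {"type": ..., "vision_model_config": {"vision_model_config": {...}}, ...};
# the inner dict is what instantiate_from_config() receives further down.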
args.vision_model_config = vision_model_config.pop("vision_model_config", None)
unet_type = unet_config.pop("type", None)
unet_model = UNet2DConditionModelDiffusers(**unet_config)
unet_model.eval().to(args.device)
unet_model.load_state_dict(torch.load(args.unet_checkpoint, map_location=args.device), strict=False)
print("loading unet model finished.")

if args.unet_checkpoint_base_model != "":
    if "safetensors" in args.unet_checkpoint_base_model:
        from safetensors import safe_open

        tensors = {}
        with safe_open(args.unet_checkpoint_base_model, framework="pt", device="cpu") as f:
            for k in f.keys():
                # Strip the "model.diffusion_model." prefix so keys match this UNet's state dict.
                new_k = k.replace("model.diffusion_model.", "")
                tensors[new_k] = f.get_tensor(k)
        unet_model.load_state_dict(tensors, strict=False)
    else:
        unet_model.load_state_dict(torch.load(args.unet_checkpoint_base_model, map_location=args.device), strict=False)
unet_model = torch.compile(unet_model, disable=not args.compile)
print("loading unet base model finished.")

# Initialize vae model
with open(args.vae_config) as vae_config_file:
    vae_config = json.load(vae_config_file)
vae_downsample_factor = 2 ** (len(vae_config["block_out_channels"]) - 1)  # 2 ** 3 = 8
vae_model = AutoencoderKL(**vae_config)
vae_model.eval().to(args.device)
vae_model.load_state_dict(torch.load(args.vae_checkpoint, map_location=args.device))
vae_decoder = torch.compile(lambda x: vae_model.decode(x / vae_model.scaling_factor).sample.clip(-1, 1), disable=not args.compile)
vae_encoder = torch.compile(lambda x: vae_model.encode(x).latent_dist.mode().mul_(vae_model.scaling_factor), disable=not args.compile)
print("loading vae finished.")

# Initialize ddim scheduler
ddim_train_steps = 1000
ddim_betas = get_betas(name=args.schedule_type, num_steps=ddim_train_steps, shift_snr=args.schedule_shift_snr, terminal_pure_noise=False)
scheduler_class = DPMSolverSingleStepScheduler if args.scheduler_type == "dpm" else DDIMScheduler
scheduler = scheduler_class(betas=ddim_betas, num_train_timesteps=ddim_train_steps, num_inference_timesteps=args.sample_steps, device=args.device)
infer_timesteps = scheduler.timesteps

# Initialize text model
text_model = TextModel(args.text_encoder_variant, ["penultimate_nonorm"])
text_model.eval().to(args.device)
print("loading text model finished.")

# Initialize image model.
vision_model = instantiate_from_config(args.vision_model_config)
vision_model = vision_model.eval().to(args.device)
print("loading image model finished.")

negative_prompt = ""
if args.negative_prompt:
    with open(args.negative_prompt) as f:
        negative_prompt = f.read().strip()

# SDXL-style size/crop conditioning (time_ids), repeated for every sample in the batch.
# NOTE: args.width/args.height are placed into the slots commented as height/width; with the
# square 512x512 default this is equivalent, but check the ordering for non-square outputs.
image_metadata_validate = torch.tensor(
    data=[
        args.width,   # original_height
        args.height,  # original_width
        0,            # coordinate top
        0,            # coordinate left
        args.width,   # target_height
        args.height,  # target_width
    ],
    device=args.device,
    dtype=torch.float32
).view(1, -1).repeat(args.samples_per_prompt, 1)

# Create output directories
os.makedirs(args.output_dir, exist_ok=True)
args.output_image_grid_dir = os.path.join(args.output_dir, "images_grid")
args.output_image_dir = os.path.join(args.output_dir, "images")
args.output_mask_grid_dir = os.path.join(args.output_dir, "masks_grid")
args.output_mask_dir = os.path.join(args.output_dir, "masks")
os.makedirs(args.output_image_grid_dir, exist_ok=True)
os.makedirs(args.output_image_dir, exist_ok=True)
os.makedirs(args.output_mask_grid_dir, exist_ok=True)
os.makedirs(args.output_mask_dir, exist_ok=True)

with torch.no_grad():
    # Prepare negative prompt.
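    # The negative text embedding is only needed when classifier-free guidance is active: with
    # guidance weight w == 1, standard CFG (pred = pred_uncond + w * (pred_cond - pred_uncond))
    # reduces to the conditional prediction alone. The cascade variant used in the sampling loop
    # below (classifier_free_guidance_image_prompt_cascade) combines text- and image-conditioned
    # predictions in a similar fashion.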
    if args.guidance_weight != 1:
        text_negative_output = text_model(negative_prompt)

    positive_prompt = args.text_prompt
    positive_prompt_image_path = args.image_prompt_path
    target_phrase = args.target_phrase

    # Mark the tokens of the target phrase inside the 77-token prompt encoding.
    target_token = torch.zeros(1, 77).to(args.device)
    positions = find_phrase_positions_in_text(positive_prompt, target_phrase)
    for position in positions:
        prompt_before = positive_prompt[:position]
        # NOTE: we do not need -1 here because the SDXL text encoder does not encode the trailing space.
        prompt_include = positive_prompt[:position + len(target_phrase)]
        print("prompt before: ", prompt_before, ", prompt_include: ", prompt_include)
        prompt_before_length = text_model.get_vaild_token_length(prompt_before) + 1
        prompt_include_length = text_model.get_vaild_token_length(prompt_include) + 1
        print("prompt_before_length: ", prompt_before_length, ", prompt_include_length: ", prompt_include_length)
        target_token[:, prompt_before_length:prompt_include_length] = 1

    # Text used for the progress bar
    pbar_text = positive_prompt[:40]

    # Compute text embeddings
    text_positive_output = text_model(positive_prompt)
    text_positive_embeddings = text_positive_output.embeddings.repeat_interleave(args.samples_per_prompt, dim=0)
    text_positive_pooled = text_positive_output.pooled[-1].repeat_interleave(args.samples_per_prompt, dim=0)
    if args.guidance_weight != 1:
        text_negative_embeddings = text_negative_output.embeddings.repeat_interleave(args.samples_per_prompt, dim=0)
        text_negative_pooled = text_negative_output.pooled[-1].repeat_interleave(args.samples_per_prompt, dim=0)

    # Compute image embeddings
    positive_image = Image.open(positive_prompt_image_path).convert("RGB")
    positive_image = torchvision.transforms.ToTensor()(positive_image)
    positive_image = positive_image.unsqueeze(0).repeat_interleave(args.samples_per_prompt, dim=0)
    positive_image = torch.nn.functional.interpolate(
        positive_image, size=(768, 768), mode="bilinear", align_corners=False
    )
    negative_image = torch.zeros_like(positive_image)
    print(positive_image.size(), negative_image.size())
    positive_image = positive_image.to(args.device)
    negative_image = negative_image.to(args.device)

    positive_image_dict = {"image_ref": positive_image}
    positive_image_output = vision_model(positive_image_dict, device=args.device)
    negative_image_dict = {"image_ref": negative_image}
    negative_image_output = vision_model(negative_image_dict, device=args.device)

    # Initialize latent with input latent + noise (i2i) / pure noise (t2i)
    latent = torch.randn(
        size=[
            args.samples_per_prompt,
            vae_config["latent_channels"],
            args.height // vae_downsample_factor,
            args.width // vae_downsample_factor
        ],
        device=args.device,
        generator=torch.Generator(args.device).manual_seed(args.seed),
    )

    target_h = (args.height // vae_downsample_factor) // 2
    target_w = (args.width // vae_downsample_factor) // 2

    # Reverse diffusion (sampling) process.
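    # Mask schedule: for the first `mask_reused_step` steps, a text-only UNet pass returns
    # text-to-image cross-attention maps, and mask_generation() aggregates the maps of the
    # target-phrase tokens into a 2D guidance mask (shaped by `mask_scope` and `mask_strategy`).
    # Later steps reuse the most recently computed mask. The mask is fed back to the UNet as
    # `vision_guided_mask`, which presumably limits where the reference-image features are applied.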
    text2image_crossmap_2d_all_timesteps_list = []
    current_step = 0
    for timestep in tqdm(iterable=infer_timesteps, desc=f"[{pbar_text}]", dynamic_ncols=True):
        if current_step < args.mask_reused_step:
            # Text-only pass to obtain fresh cross-attention maps for mask generation.
            pred_cond, pred_cond_dict = unet_model(
                sample=latent,
                timestep=timestep,
                encoder_hidden_states=text_positive_embeddings,
                encoder_attention_mask=None,
                added_cond_kwargs=dict(
                    text_embeds=text_positive_pooled,
                    time_ids=image_metadata_validate
                ),
                vision_input_dict=None,
                vision_guided_mask=None,
                return_as_origin=False,
                return_text2image_mask=True,
            )
            crossmap_2d_avg = mask_generation(
                crossmap_2d_list=pred_cond_dict["text2image_crossmap_2d"],
                selfmap_2d_list=pred_cond_dict.get("self_attention_map", []),
                target_token=target_token,
                mask_scope=args.mask_scope,
                mask_target_h=target_h,
                mask_target_w=target_w,
                mask_mode=args.mask_strategy,
            )
        else:
            # Reuse the previous step's mask.
            crossmap_2d_avg = text2image_crossmap_2d_all_timesteps_list[-1].squeeze(1)

        if crossmap_2d_avg.dim() == 5:
            # A 5D map means each layer uses a separate mask weight.
            text2image_crossmap_2d_all_timesteps_list.append(crossmap_2d_avg.mean(dim=2).unsqueeze(1))
        else:
            text2image_crossmap_2d_all_timesteps_list.append(crossmap_2d_avg.unsqueeze(1))

        pred_cond, pred_cond_dict = unet_model(
            sample=latent,
            timestep=timestep,
            encoder_hidden_states=text_positive_embeddings,
            encoder_attention_mask=None,
            added_cond_kwargs=dict(
                text_embeds=text_positive_pooled,
                time_ids=image_metadata_validate
            ),
            vision_input_dict=positive_image_output,
            vision_guided_mask=crossmap_2d_avg,
            return_as_origin=False,
            return_text2image_mask=True,
            multiple_reference_image=False
        )

        crossmap_2d_avg_neg = crossmap_2d_avg.mean(dim=1, keepdim=True)
        pred_negative, pred_negative_dict = unet_model(
            sample=latent,
            timestep=timestep,
            encoder_hidden_states=text_negative_embeddings,
            encoder_attention_mask=None,
            added_cond_kwargs=dict(
                text_embeds=text_negative_pooled,
                time_ids=image_metadata_validate
            ),
            vision_input_dict=negative_image_output,
            vision_guided_mask=crossmap_2d_avg,
            return_as_origin=False,
            return_text2image_mask=True,
            multiple_reference_image=False
        )

        pred = classifier_free_guidance_image_prompt_cascade(
            pred_t_cond=None,
            pred_ti_cond=pred_cond,
            pred_uncond=pred_negative,
            guidance_weight_t=args.guidance_weight,
            guidance_weight_i=args.guidance_weight,
            guidance_stdev_rescale_factor=0,
            cfg_rescale_mode="naive_global_direct"
        )

        step = scheduler.step(
            model_output=pred,
            model_output_type=args.unet_prediction,
            timestep=timestep,
            sample=latent)
        latent = step.prev_sample
        current_step += 1

    # Decode the final prediction and save the results.
    sample = vae_decoder(step.pred_original_sample)

    # Save each image individually.
    for sample_i in range(sample.size(0)):
        sample_i_image = torch.clamp(sample[sample_i] * 0.5 + 0.5, min=0, max=1).float()
        to_pil_image(sample_i_image).save(args.output_image_dir + "/output_{}.jpg".format(sample_i))

    # Save a grid of all samples.
    sample = make_grid(sample, normalize=True, value_range=(-1, 1), nrow=args.nrow).float()
    to_pil_image(sample).save(args.output_image_grid_dir + "/grid_image.jpg")
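# Example invocation (illustrative only: the script name, checkpoint/config paths, and prompt
# values below are placeholders, not files shipped with this code):
#
#   python inference.py \
#       --samples_per_prompt 4 \
#       --sample_steps 25 \
#       --text_encoder_variant <text_encoder_name> \
#       --vae_checkpoint checkpoints/vae.pt \
#       --unet_config configs/unet.json \
#       --unet_checkpoint checkpoints/unet.pt \
#       --output_dir outputs/demo \
#       --text_prompt "a toy robot on the beach" \
#       --image_prompt_path assets/robot.jpg \
#       --target_phrase "toy robot" \
#       --guidance_weight 7.5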