"""Run DAI reflection removal over a directory of images and save the predictions
alongside input/result comparison grids."""

import argparse
import os
from glob import glob

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.utils import load_image, make_image_grid
from transformers import AutoTokenizer, CLIPTextModel

from DAI.controlnetvae import ControlNetVAEModel
from DAI.decoder import CustomAutoencoderKL
from DAI.pipeline_all import DAIPipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weight_dtype = torch.float32

model_dir = "./weights"
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1"
revision = None
variant = None

# Load the fine-tuned DAI components (ControlNet, UNet, and custom VAE decoder).
controlnet = ControlNetVAEModel.from_pretrained(
    model_dir + "/controlnet", torch_dtype=weight_dtype
).to(device)
unet = UNet2DConditionModel.from_pretrained(
    model_dir + "/unet", torch_dtype=weight_dtype
).to(device)
vae_2 = CustomAutoencoderKL.from_pretrained(
    model_dir + "/vae_2", torch_dtype=weight_dtype
).to(device)

# Load the remaining pipeline components from the Stable Diffusion 2.1 base model.
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae", revision=revision, variant=variant
).to(device)

text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, variant=variant
).to(device)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer",
    revision=revision,
    use_fast=False,
)

pipeline = DAIPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    controlnet=controlnet,
    safety_checker=None,
    scheduler=None,
    feature_extractor=None,
    t_start=0,
).to(device)

# Parse command line arguments.
parser = argparse.ArgumentParser(description="Run reflection removal on images.")
parser.add_argument("--input_dir", type=str, required=True, help="Directory for evaluation inputs.")
parser.add_argument("--result_dir", type=str, required=True, help="Directory for evaluation results.")
parser.add_argument("--concat_dir", type=str, required=True, help="Directory for concatenated evaluation results.")
args = parser.parse_args()

input_dir = args.input_dir
result_dir = args.result_dir
concat_dir = args.concat_dir

# Create the output directories.
os.makedirs(result_dir, exist_ok=True)
os.makedirs(concat_dir, exist_ok=True)

input_files = sorted(glob(os.path.join(input_dir, "*")))

for input_file in tqdm(input_files, desc="Processing images"):
    input_image = load_image(input_file)

    # Process at the native resolution (0) unless the image is small,
    # in which case fall back to the pipeline's default processing resolution.
    resolution = 0
    if max(input_image.size) < 768:
        resolution = None

    result_image = pipeline(
        image=torch.tensor(np.array(input_image)).permute(2, 0, 1).float().div(255).unsqueeze(0).to(device),
        prompt="remove glass reflection",
        vae_2=vae_2,
        processing_resolution=resolution,
    ).prediction[0]

    # Map the prediction from [-1, 1] to an 8-bit image.
    result_image = (result_image + 1) / 2
    result_image = result_image.clip(0.0, 1.0)
    result_image = (result_image * 255).astype(np.uint8)
    result_image = Image.fromarray(result_image)

    # Save the prediction and an input/result comparison grid.
    input_filename = os.path.basename(input_file)
    concat_image = make_image_grid([input_image, result_image], rows=1, cols=2)
    concat_image.save(os.path.join(concat_dir, input_filename))
    result_image.save(os.path.join(result_dir, input_filename))
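
# A minimal usage sketch. The script name and directory paths below are placeholders,
# not names taken from the original repository:
#
#   python eval_reflection_removal.py \
#       --input_dir ./inputs \
#       --result_dir ./results \
#       --concat_dir ./results_concat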