import gradio as gr
import spaces
from PIL import Image
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTextModelWithProjection,
)
from diffusers import DDPMScheduler, AutoencoderKL
from typing import List
import torch
import os
from transformers import AutoTokenizer
import numpy as np
from utils_mask import get_mask_location
from torchvision import transforms
import apply_net
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
from torchvision.transforms.functional import to_pil_image


def pil_to_binary_mask(pil_image, threshold=0):
    """Convert a PIL image to a binary (0/255) mask by thresholding its grayscale values."""
    np_image = np.array(pil_image)
    grayscale_image = Image.fromarray(np_image).convert("L")
    binary_mask = np.array(grayscale_image) > threshold
    mask = binary_mask.astype(np.uint8) * 255
    return Image.fromarray(mask)


base_path = 'yisol/IDM-VTON'
example_path = os.path.join(os.path.dirname(__file__), 'example')

# Load the try-on UNet, tokenizers, scheduler, text/image encoders and VAE from the IDM-VTON checkpoint.
unet = UNet2DConditionModel.from_pretrained(
    base_path,
    subfolder="unet",
    torch_dtype=torch.float16,
)
unet.requires_grad_(False)

tokenizer_one = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer",
    revision=None,
    use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer_2",
    revision=None,
    use_fast=False,
)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")

text_encoder_one = CLIPTextModel.from_pretrained(
    base_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    base_path,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    base_path,
    subfolder="image_encoder",
    torch_dtype=torch.float16,
)
vae = AutoencoderKL.from_pretrained(
    base_path,  # "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="vae",
    torch_dtype=torch.float16,
)

# Reference ("garment") UNet encoder.
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
    base_path,
    subfolder="unet_encoder",
    torch_dtype=torch.float16,
)

parsing_model = Parsing(0)
openpose_model = OpenPose(0)

UNet_Encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)

tensor_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
pipe.unet_encoder = UNet_Encoder
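# Optional memory savings on smaller GPUs (left commented out; untested here).
# AutoencoderKL.enable_slicing() is a standard diffusers call; whether the custom
# TryonPipeline also exposes enable_xformers_memory_efficient_attention() depends on
# how closely it follows the stock SDXL inpaint pipeline, so treat it as an assumption.
# vae.enable_slicing()
# pipe.enable_xformers_memory_efficient_attention()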
@spaces.GPU
def start_tryon(img, garm_img, garment_des, cloth_type, is_checked, is_checked_crop, denoise_steps, seed):
    """Run one try-on pass. `img` is the gr.ImageEditor dict for the person, `garm_img` the garment image."""
    print(img)
    device = "cuda"

    openpose_model.preprocessor.body_estimation.model.to(device)
    pipe.to(device)
    pipe.unet_encoder.to(device)

    garm_img = garm_img.convert("RGB").resize((768, 1024))
    human_img_orig = img["background"].convert("RGB")

    if is_checked_crop:
        # Center-crop the person image to a 3:4 aspect ratio before resizing.
        width, height = human_img_orig.size
        target_width = int(min(width, height * (3 / 4)))
        target_height = int(min(height, width * (4 / 3)))
        left = (width - target_width) / 2
        top = (height - target_height) / 2
        right = (width + target_width) / 2
        bottom = (height + target_height) / 2
        cropped_img = human_img_orig.crop((left, top, right, bottom))
        crop_size = cropped_img.size
        human_img = cropped_img.resize((768, 1024))
    else:
        human_img = human_img_orig.resize((768, 1024))

    if is_checked:
        # Auto-generate the inpainting mask from OpenPose keypoints and human parsing.
        keypoints = openpose_model(human_img.resize((384, 512)))
        model_parse, _ = parsing_model(human_img.resize((384, 512)))
        mask, mask_gray = get_mask_location('hd', cloth_type, model_parse, keypoints)
        mask = mask.resize((768, 1024))
    else:
        # Use the mask the user drew in the ImageEditor (first drawn layer).
        mask = pil_to_binary_mask(img['layers'][0].convert("RGB").resize((768, 1024)))
        # mask = transforms.ToTensor()(mask)
        # mask = mask.unsqueeze(0)
    mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transform(human_img)
    mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

    # DensePose conditioning image.
    human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")

    args = apply_net.create_argument_parser().parse_args(
        ('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml',
         './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v',
         '--opts', 'MODEL.DEVICE', 'cuda')
    )
    # verbosity = getattr(args, "verbosity", None)
    pose_img = args.func(args, human_img_arg)
    pose_img = pose_img[:, :, ::-1]
    pose_img = Image.fromarray(pose_img).resize((768, 1024))

    with torch.no_grad():
        # Extract the images
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                prompt = "model is wearing " + garment_des
                negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
                with torch.inference_mode():
                    (
                        prompt_embeds,
                        negative_prompt_embeds,
                        pooled_prompt_embeds,
                        negative_pooled_prompt_embeds,
                    ) = pipe.encode_prompt(
                        prompt,
                        num_images_per_prompt=1,
                        do_classifier_free_guidance=True,
                        negative_prompt=negative_prompt,
                    )

                    prompt = "a photo of " + garment_des
                    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
                    if not isinstance(prompt, List):
                        prompt = [prompt] * 1
                    if not isinstance(negative_prompt, List):
                        negative_prompt = [negative_prompt] * 1
                    with torch.inference_mode():
                        (
                            prompt_embeds_c,
                            _,
                            _,
                            _,
                        ) = pipe.encode_prompt(
                            prompt,
                            num_images_per_prompt=1,
                            do_classifier_free_guidance=False,
                            negative_prompt=negative_prompt,
                        )

                    pose_img = tensor_transform(pose_img).unsqueeze(0).to(device, torch.float16)
                    garm_tensor = tensor_transform(garm_img).unsqueeze(0).to(device, torch.float16)
                    generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
                    images = pipe(
                        prompt_embeds=prompt_embeds.to(device, torch.float16),
                        negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
                        pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
                        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
                        num_inference_steps=denoise_steps,
                        generator=generator,
                        strength=1.0,
                        pose_img=pose_img.to(device, torch.float16),
                        text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
                        cloth=garm_tensor.to(device, torch.float16),
                        mask_image=mask,
                        image=human_img,
                        height=1024,
                        width=768,
                        ip_adapter_image=garm_img.resize((768, 1024)),
                        guidance_scale=2.0,
                    )[0]

    if is_checked_crop:
        # Paste the generated crop back into the original, uncropped image.
        out_img = images[0].resize(crop_size)
        human_img_orig.paste(out_img, (int(left), int(top)))
        return garm_img, human_img_orig, mask_gray
    else:
        return garm_img, images[0], mask_gray
    # return images[0], mask_gray
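# Direct (non-UI) usage sketch, kept commented out so it does not run at import time.
# The image paths and description strings below are placeholders, not files shipped
# with this repo; start_tryon expects the person image wrapped in an ImageEditor-style dict.
#
# if __name__ == "__main__":
#     person = {"background": Image.open("person.jpg"), "layers": None, "composite": None}
#     shirt = Image.open("shirt.jpg")
#     midway, result, mask_preview = start_tryon(
#         person, shirt, "Short sleeve round neck t-shirt", "upper_body",
#         is_checked=True, is_checked_crop=False, denoise_steps=26, seed=42,
#     )
#     result.save("tryon_result.png")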
def main_(imgs, topwear_img, topwear_des, bottomwear_img, bottomwear_des, dress_img, dress_des, is_checked, is_checked_crop, denoise_steps, seed):
    """Dispatch to start_tryon based on which garment images were provided.

    A dress takes priority over topwear/bottomwear; if both topwear and bottomwear are
    given, the topwear result is fed back in as the person image for a second, lower-body pass.
    """
    if dress_img is not None:
        return start_tryon(imgs, dress_img, dress_des, "dresses", is_checked, is_checked_crop, denoise_steps, seed)
    elif topwear_img is not None and bottomwear_img is None:
        return start_tryon(imgs, topwear_img, topwear_des, "upper_body", is_checked, is_checked_crop, denoise_steps, seed)
    elif topwear_img is None and bottomwear_img is not None:
        return start_tryon(imgs, bottomwear_img, bottomwear_des, "lower_body", is_checked, is_checked_crop, denoise_steps, seed)
    elif topwear_img is not None and bottomwear_img is not None:
        _, half_img, half_mask = start_tryon(imgs, topwear_img, topwear_des, "upper_body", is_checked, is_checked_crop, denoise_steps, seed)
        half_dict = {}
        half_dict['background'], half_dict['layers'], half_dict['composite'] = half_img, None, None
        return start_tryon(half_dict, bottomwear_img, bottomwear_des, "lower_body", is_checked, is_checked_crop, denoise_steps, seed)


garm_list = os.listdir(os.path.join(example_path, "cloth"))
garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]

human_list = os.listdir(os.path.join(example_path, "human"))
human_list_path = [os.path.join(example_path, "human", human) for human in human_list]

human_ex_list = []
for ex_human in human_list_path:
    ex_dict = {}
    ex_dict['background'] = ex_human
    ex_dict['layers'] = None
    ex_dict['composite'] = None
    human_ex_list.append(ex_dict)


def get_examples_lst(folder_name, file_list):
    """Pair each example image path with a description derived from its file name."""
    example_list = []
    for i in file_list:
        example_list.append([os.path.join(example_path, folder_name, i), i.split(".")[0]])
    return example_list


topwear_list = os.listdir(os.path.join(example_path, "topwear"))
# topwear_list_path = [os.path.join(example_path, "topwear", topwear) for topwear in topwear_list]
topwear_ex_list = get_examples_lst("topwear", topwear_list)

bottomwear_list = os.listdir(os.path.join(example_path, "bottomwear"))
# bottomwear_list_path = [os.path.join(example_path, "bottomwear", bottomwear) for bottomwear in bottomwear_list]
bottomwear_ex_list = get_examples_lst("bottomwear", bottomwear_list)

dress_list = os.listdir(os.path.join(example_path, "dresses"))
# dress_list_path = [os.path.join(example_path, "dresses", dress) for dress in dress_list]
dress_ex_list = get_examples_lst("dresses", dress_list)
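# The Examples blocks below assume this layout next to the script (inferred from the
# listdir calls above; file names double as garment descriptions, so the names shown
# here are only illustrative):
#
#   example/
#     human/        person photos
#     cloth/        garments (not wired into the UI below)
#     topwear/      e.g. "Short Sleeve Round Neck T-shirt.jpg"
#     bottomwear/   e.g. "Cargo pants.jpg"
#     dresses/      e.g. "Jumper suit.jpg"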

image_blocks = gr.Blocks(theme="Nymbo/Alyx_Theme").queue()
with image_blocks as demo:
    gr.HTML("<center><h1>Virtual Try-On</h1></center>")
    gr.HTML("<center><p>Upload an image of a person and images of the clothes✨</p></center>")
    gr.Markdown(
        """NOTES:
- Upload/choose any 'Human' you want to try the clothes on, then upload/choose any 'Topwear' and 'Bottomwear' to run the virtual try-on.
- If a 'Dress' is uploaded/chosen, 'Topwear' and 'Bottomwear', if any, are ignored.
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            inp_img = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
            with gr.Row():
                is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)", value=True)
            with gr.Row():
                is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing", value=False)
            example = gr.Examples(inputs=inp_img, examples_per_page=10, examples=human_ex_list)

        with gr.Column(scale=1):
            with gr.Row():
                dress_image = gr.Image(label="Dress", sources='upload', type="pil")
            with gr.Row():
                dress_desc = gr.Textbox(placeholder="Description of garment ex) Jumper suit", show_label=False, elem_id="prompt")
            example = gr.Examples(inputs=[dress_image, dress_desc], examples_per_page=4, examples=dress_ex_list)

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row():
                topwear_image = gr.Image(label="Topwear", sources='upload', type="pil")
            with gr.Row():
                topwear_desc = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
            example = gr.Examples(inputs=[topwear_image, topwear_desc], examples_per_page=4, examples=topwear_ex_list)

        with gr.Column(scale=1):
            with gr.Row():
                bottomwear_image = gr.Image(label="Bottomwear", sources='upload', type="pil")
            with gr.Row():
                bottomwear_desc = gr.Textbox(placeholder="Description of garment ex) Cargo pants", show_label=False, elem_id="prompt")
            example = gr.Examples(inputs=[bottomwear_image, bottomwear_desc], examples_per_page=4, examples=bottomwear_ex_list)

    with gr.Column():
        with gr.Accordion(label="Advanced Settings", open=False):
            with gr.Row():
                denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=26, step=1)
                seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
        try_button = gr.Button(value="Try-on", variant='primary')

    with gr.Row():
        with gr.Column():
            # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
            image_out = gr.Image(label="Output", elem_id="output-img", show_share_button=False)
            with gr.Accordion("Debug Info", open=False):
                image_in = gr.Image(label="Midway Image", elem_id="midway-img", show_share_button=False)
                masked_img = gr.Image(label="Masked image output", elem_id="masked-img", show_share_button=False)

    try_button.click(
        fn=main_,
        inputs=[inp_img, topwear_image, topwear_desc, bottomwear_image, bottomwear_desc, dress_image, dress_desc, is_checked, is_checked_crop, denoise_steps, seed],
        outputs=[image_in, image_out, masked_img],
        api_name='tryon',
    )

image_blocks.launch()