# Virtual_Try-On / app.py
import gradio as gr
import spaces
from PIL import Image
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection, CLIPTextModel, CLIPTextModelWithProjection,)
from diffusers import DDPMScheduler,AutoencoderKL
from typing import List
import torch
import os
from transformers import AutoTokenizer
import numpy as np
from utils_mask import get_mask_location
from torchvision import transforms
import apply_net
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
from torchvision.transforms.functional import to_pil_image
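# Helper: threshold a PIL image (e.g. a user-drawn mask layer) into a strict 0/255 binary mask.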
def pil_to_binary_mask(pil_image, threshold=0):
    np_image = np.array(pil_image)
    grayscale_image = Image.fromarray(np_image).convert("L")
    binary_mask = np.array(grayscale_image) > threshold
    mask = binary_mask.astype(np.uint8) * 255
    return Image.fromarray(mask)
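# Load the IDM-VTON model components from the Hugging Face Hub in fp16 and freeze them for inference.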
base_path = 'yisol/IDM-VTON'
example_path = os.path.join(os.path.dirname(__file__), 'example')
unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16,)
unet.requires_grad_(False)
tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", revision=None, use_fast=False,)
tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", revision=None, use_fast=False,)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16,)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16,)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16,)
vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16,)
# "stabilityai/stable-diffusion-xl-base-1.0",
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16,)
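# Human-parsing and OpenPose preprocessors used for auto-masking, initialised with device index 0.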
parsing_model = Parsing(0)
openpose_model = OpenPose(0)
UNet_Encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
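# Transform that converts PIL images to tensors normalised to [-1, 1].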
tensor_transfrom = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
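# Assemble the SDXL-based try-on inpainting pipeline and attach the garment-encoder UNet.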
pipe = TryonPipeline.from_pretrained(
base_path,
unet=unet,
vae=vae,
feature_extractor= CLIPImageProcessor(),
text_encoder = text_encoder_one,
text_encoder_2 = text_encoder_two,
tokenizer = tokenizer_one,
tokenizer_2 = tokenizer_two,
scheduler = noise_scheduler,
image_encoder=image_encoder,
torch_dtype=torch.float16,
)
pipe.unet_encoder = UNet_Encoder
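# Run a single try-on pass: build a mask over the target garment region, estimate pose
# (OpenPose + DensePose), encode the garment and prompts, then inpaint with the try-on pipeline.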
@spaces.GPU
def start_tryon(img,garm_img,garment_des,cloth_type,is_checked,is_checked_crop,denoise_steps,seed):
print(img)
device = "cuda"
openpose_model.preprocessor.body_estimation.model.to(device)
pipe.to(device)
pipe.unet_encoder.to(device)
garm_img= garm_img.convert("RGB").resize((768,1024))
human_img_orig = img["background"].convert("RGB")
if is_checked_crop:
width, height = human_img_orig.size
target_width = int(min(width, height * (3 / 4)))
target_height = int(min(height, width * (4 / 3)))
left = (width - target_width) / 2
top = (height - target_height) / 2
right = (width + target_width) / 2
bottom = (height + target_height) / 2
cropped_img = human_img_orig.crop((left, top, right, bottom))
crop_size = cropped_img.size
human_img = cropped_img.resize((768,1024))
else:
human_img = human_img_orig.resize((768,1024))
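    # Build the inpainting mask: auto-generate it from human parsing + OpenPose keypoints, or use the user-drawn mask layer.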
if is_checked:
keypoints = openpose_model(human_img.resize((384,512)))
model_parse, _ = parsing_model(human_img.resize((384,512)))
mask, mask_gray = get_mask_location('hd', cloth_type, model_parse, keypoints)
mask = mask.resize((768,1024))
else:
        mask = pil_to_binary_mask(img['layers'][0].convert("RGB").resize((768, 1024)))
mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
mask_gray = to_pil_image((mask_gray+1.0)/2.0)
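    # Run DensePose via detectron2's apply_net to render a dense pose map used to condition the pipeline.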
human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
# verbosity = getattr(args, "verbosity", None)
pose_img = args.func(args,human_img_arg)
pose_img = pose_img[:,:,::-1]
pose_img = Image.fromarray(pose_img).resize((768,1024))
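    # Encode the prompts and run the try-on diffusion pass in fp16 under autocast.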
with torch.no_grad():
# Extract the images
with torch.cuda.amp.autocast():
with torch.no_grad():
prompt = "model is wearing " + garment_des
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
with torch.inference_mode():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
prompt,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt = "a photo of " + garment_des
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * 1
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * 1
with torch.inference_mode():
(prompt_embeds_c,_,_,_,) = pipe.encode_prompt(prompt,num_images_per_prompt=1,do_classifier_free_guidance=False,negative_prompt=negative_prompt,)
pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
images = pipe(
prompt_embeds=prompt_embeds.to(device,torch.float16),
negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,torch.float16),
num_inference_steps=denoise_steps,
generator=generator,
strength = 1.0,
pose_img = pose_img.to(device,torch.float16),
text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
cloth = garm_tensor.to(device,torch.float16),
mask_image=mask,
image=human_img,
height=1024,
width=768,
ip_adapter_image = garm_img.resize((768,1024)),
guidance_scale=2.0,
)[0]
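    # If auto-crop was used, paste the generated region back into the original photo before returning.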
    if is_checked_crop:
        out_img = images[0].resize(crop_size)
        human_img_orig.paste(out_img, (int(left), int(top)))
        return garm_img, human_img_orig, mask_gray
    else:
        return garm_img, images[0], mask_gray
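# Decide which garment(s) to try on: a dress overrides topwear/bottomwear; if both topwear and
# bottomwear are given, the topwear is applied first and its result is fed back in for the bottomwear.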
def main_(imgs, topwear_img, topwear_des, bottomwear_img, bottomwear_des, dress_img, dress_des, is_checked, is_checked_crop, denoise_steps, seed):
    if dress_img is not None:
        return start_tryon(imgs, dress_img, dress_des, "dresses", is_checked, is_checked_crop, denoise_steps, seed)
    elif topwear_img is not None and bottomwear_img is None:
        return start_tryon(imgs, topwear_img, topwear_des, "upper_body", is_checked, is_checked_crop, denoise_steps, seed)
    elif topwear_img is None and bottomwear_img is not None:
        return start_tryon(imgs, bottomwear_img, bottomwear_des, "lower_body", is_checked, is_checked_crop, denoise_steps, seed)
    elif topwear_img is not None and bottomwear_img is not None:
        # Apply the topwear first, then run a second pass on the intermediate result for the bottomwear.
        _, half_img, _ = start_tryon(imgs, topwear_img, topwear_des, "upper_body", is_checked, is_checked_crop, denoise_steps, seed)
        half_dict = {'background': half_img, 'layers': None, 'composite': None}
        # The intermediate image has no user-drawn mask layer, so force auto-masking on the second pass.
        _, final_img, final_mask = start_tryon(half_dict, bottomwear_img, bottomwear_des, "lower_body", True, is_checked_crop, denoise_steps, seed)
        return half_img, final_img, final_mask
    raise gr.Error("Please upload or choose at least one garment (Topwear, Bottomwear, or Dress).")
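# Collect the bundled example images used to populate the Examples galleries.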
garm_list = os.listdir(os.path.join(example_path,"cloth"))
garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
human_list = os.listdir(os.path.join(example_path,"human"))
human_list_path = [os.path.join(example_path,"human",human) for human in human_list]
human_ex_list = []
for ex_human in human_list_path:
ex_dict= {}
ex_dict['background'] = ex_human
ex_dict['layers'] = None
ex_dict['composite'] = None
human_ex_list.append(ex_dict)
def get_examples_lst(folder_name, file_list):
example_list = []
for i in file_list:
example_list.append([os.path.join(example_path,folder_name,i),i.split(".")[0]])
return example_list
topwear_list = os.listdir(os.path.join(example_path, "topwear"))
topwear_ex_list = get_examples_lst("topwear", topwear_list)
bottomwear_list = os.listdir(os.path.join(example_path, "bottomwear"))
bottomwear_ex_list = get_examples_lst("bottomwear", bottomwear_list)
dress_list = os.listdir(os.path.join(example_path, "dresses"))
dress_ex_list = get_examples_lst("dresses", dress_list)
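# Build the Gradio interface.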
image_blocks = gr.Blocks(theme="Nymbo/Alyx_Theme").queue()
with image_blocks as demo:
gr.HTML("<center><h1>Virtual Try-On</h1></center>")
gr.HTML("<center><p>Upload an image of a person and images of the clothes✨</p></center>")
gr.Markdown("""NOTES:
- Upload/choose any 'Human' you want to try the clothes on, and then upload/choose any 'Topwear' and 'Bottomwear' to try the virtual try-on.
- If 'Dress' is uploaded/chosen, 'Topwear' and 'Bottomwear', if any, are ignored.
""")
with gr.Row():
with gr.Column(scale=1):
inp_img = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
with gr.Row():
is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)",value=True)
with gr.Row():
is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing",value=False)
example = gr.Examples(inputs=inp_img, examples_per_page=10, examples=human_ex_list)
with gr.Column(scale=1):
with gr.Row():
dress_image = gr.Image(label="Dress", sources='upload', type="pil")
with gr.Row():
                dress_desc = gr.Textbox(placeholder="Description of garment, e.g. Jumpsuit", show_label=False, elem_id="dress_prompt")
example = gr.Examples(inputs=[dress_image,dress_desc], examples_per_page=4, examples=dress_ex_list)
with gr.Row():
with gr.Column(scale=1):
with gr.Row():
topwear_image = gr.Image(label="Topwear", sources='upload', type="pil")
with gr.Row():
                topwear_desc = gr.Textbox(placeholder="Description of garment, e.g. Short-sleeve round-neck T-shirt", show_label=False, elem_id="topwear_prompt")
example = gr.Examples(inputs=[topwear_image,topwear_desc], examples_per_page=4,examples=topwear_ex_list)
with gr.Column(scale=1):
with gr.Row():
bottomwear_image = gr.Image(label="Bottomwear", sources='upload', type="pil")
with gr.Row():
                bottomwear_desc = gr.Textbox(placeholder="Description of garment, e.g. Cargo pants", show_label=False, elem_id="bottomwear_prompt")
example = gr.Examples(inputs=[bottomwear_image,bottomwear_desc], examples_per_page=4, examples=bottomwear_ex_list)
with gr.Column():
with gr.Accordion(label="Advanced Settings", open=False):
with gr.Row():
                denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=26, step=1, precision=0)
                seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42, precision=0)
try_button = gr.Button(value="Try-on",variant='primary')
with gr.Row():
with gr.Column():
# image_out = gr.Image(label="Output", elem_id="output-img", height=400)
image_out = gr.Image(label="Output", elem_id="output-img",show_share_button=False)
with gr.Accordion("Debug Info", open=False):
image_in = gr.Image(label="Midway Image", elem_id="midway-img",show_share_button=False)
masked_img = gr.Image(label="Masked image output", elem_id="masked-img",show_share_button=False)
try_button.click(fn=main_, inputs=[inp_img,topwear_image,topwear_desc,bottomwear_image,bottomwear_desc,dress_image,dress_desc,is_checked,is_checked_crop,denoise_steps,seed],
outputs=[image_in, image_out, masked_img], api_name='tryon')
image_blocks.launch()