# virtual-try-on / app.py
import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
import cv2
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
# ----------------- Device -----------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ----------------- Load Human Parser Model -----------------
processor = AutoImageProcessor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
parser_model = SegformerForSemanticSegmentation.from_pretrained(
    "matei-dorian/segformer-b5-finetuned-human-parsing"
).to(device).eval()
# ----------------- UNet Generator Definition -----------------
class UNetGenerator(nn.Module):
    """U-Net generator with 4 down/4 up blocks and skip connections.

    Takes a 6-channel input (agnostic person RGB + cloth RGB) and
    produces a 3-channel try-on image.
    """
    def __init__(self, in_channels=6, out_channels=3):
        super(UNetGenerator, self).__init__()

        def block(in_c, out_c):
            # Downsampling block: strided conv halves the spatial resolution
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 4, 2, 1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True)
            )

        def up_block(in_c, out_c):
            # Upsampling block: transposed conv doubles the spatial resolution
            return nn.Sequential(
                nn.ConvTranspose2d(in_c, out_c, 4, 2, 1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True)
            )

        self.down1 = block(in_channels, 64)
        self.down2 = block(64, 128)
        self.down3 = block(128, 256)
        self.down4 = block(256, 512)
        self.up1 = up_block(512, 256)
        self.up2 = up_block(512, 128)  # in: up1 (256) + skip d3 (256)
        self.up3 = up_block(256, 64)   # in: up2 (128) + skip d2 (128)
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1),  # in: up3 (64) + skip d1 (64)
            nn.Tanh()  # output in [-1, 1]
        )

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        u1 = self.up1(d4)
        u2 = self.up2(torch.cat([u1, d3], dim=1))
        u3 = self.up3(torch.cat([u2, d2], dim=1))
        u4 = self.up4(torch.cat([u3, d1], dim=1))
        return u4
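# Shape sanity check (illustrative comment, not executed): with the 256x192
# inputs used below, a batch flows as
#   (1, 6, 256, 192) -> d1 (1, 64, 128, 96) -> d2 (1, 128, 64, 48)
#   -> d3 (1, 256, 32, 24) -> d4 (1, 512, 16, 12) -> output (1, 3, 256, 192)
# e.g. assert UNetGenerator()(torch.randn(1, 6, 256, 192)).shape == (1, 3, 256, 192)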
# ----------------- Image Transforms -----------------
img_transform = transforms.Compose([
    transforms.Resize((256, 192)),  # (height, width)
    transforms.ToTensor()  # pixels scaled to [0, 1]
])
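# Note (assumption, not from the original): ToTensor() yields inputs in [0, 1],
# while generate_tryon_output below rescales the generator's Tanh output from
# [-1, 1]. If the checkpoints were trained on inputs normalized to [-1, 1],
# a matching transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) would
# belong at the end of this Compose.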
# ----------------- Helper Functions -----------------
def get_segmentation(image: Image.Image):
    # Run the SegFormer human parser and return a per-pixel class-ID map
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = parser_model(**inputs)
    logits = outputs.logits
    predicted = torch.argmax(logits, dim=1)[0].cpu().numpy()
    return predicted
def generate_agnostic(image: Image.Image, segmentation):
    # Grey out clothing regions to produce a clothing-agnostic person image
    img_np = np.array(image.resize((192, 256)))  # PIL size is (width, height)
    agnostic_np = img_np.copy()
    # Nearest-neighbour resize keeps class IDs intact
    segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
    clothing_labels = [4, 5, 6, 7, 8, 16]  # class IDs treated as clothing in the parser's label map
    for label in clothing_labels:
        agnostic_np[segmentation_resized == label] = [128, 128, 128]  # mask with mid-grey
    return Image.fromarray(agnostic_np)
def load_model(model_type):
    # All three variants share the UNetGenerator architecture; only the
    # checkpoint weights differ
    model = UNetGenerator().to(device)
    if model_type == "UNet":
        checkpoint = torch.load("viton_unet_full_checkpoint.pth", map_location=device)
    elif model_type == "GAN":
        checkpoint = torch.load("viton_gan_generator_checkpoint.pth", map_location=device)
    elif model_type == "Diffusion":
        checkpoint = torch.load("viton_diffusion_generator_checkpoint.pth", map_location=device)
    else:
        raise ValueError("Invalid model type")
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model
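# Optional sketch (not part of the original app): virtual_tryon() below calls
# load_model() on every request, re-reading the checkpoint from disk each time.
# A minimal lru_cache wrapper would keep one loaded generator per model type:
from functools import lru_cache

@lru_cache(maxsize=3)
def load_model_cached(model_type):
    # Delegates to load_model(); cached by model_type string
    return load_model(model_type)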
def generate_tryon_output(agnostic_img, cloth_img, model):
    # Concatenate agnostic person + cloth along channels -> 6-channel input
    agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
    cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
    input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
    with torch.no_grad():
        output = model(input_tensor)
    output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
    output_img = (output_img + 1) / 2  # map Tanh output from [-1, 1] to [0, 1]
    output_img = (output_img * 255).astype(np.uint8)
    return Image.fromarray(output_img)
# ----------------- Inference Pipeline -----------------
def virtual_tryon(person_image, cloth_image, model_type):
    segmentation = get_segmentation(person_image)
    agnostic = generate_agnostic(person_image, segmentation)
    model = load_model(model_type)  # reloads weights on every request; see cached sketch above
    result = generate_tryon_output(agnostic, cloth_image, model)
    return agnostic, result
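# Example of driving the pipeline without the UI (hypothetical file names,
# assumes the checkpoints are present in the working directory):
#   person = Image.open("person.jpg").convert("RGB")
#   cloth = Image.open("cloth.jpg").convert("RGB")
#   agnostic, result = virtual_tryon(person, cloth, "UNet")
#   result.save("tryon_result.png")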
# ----------------- Gradio Interface -----------------
demo = gr.Interface(
    fn=virtual_tryon,
    inputs=[
        gr.Image(type="pil", label="Person Image"),
        gr.Image(type="pil", label="Cloth Image"),
        gr.Radio(choices=["UNet", "GAN", "Diffusion"], label="Model Type", value="UNet")
    ],
    outputs=[
        gr.Image(type="pil", label="Agnostic (Torso Masked)"),
        gr.Image(type="pil", label="Virtual Try-On Output")
    ],
    title="👕 Virtual Try-On App",
    description="Upload a person image and a clothing image, select a model (UNet, GAN, Diffusion), and try it on virtually."
)
if __name__ == "__main__":
    demo.launch(share=True)