virtual-try-on / app.py
sengourav012's picture
Create app.py
b99dca9 verified
raw
history blame
4.4 kB
import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
import cv2
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ----------------- Load Human Parser Model from Hugging Face Hub -----------------
# The pre-processing pipeline and the segmentation network are both fetched
# from the same Hub repository.
processor = AutoImageProcessor.from_pretrained(
    "matei-dorian/segformer-b5-finetuned-human-parsing"
)
parser_model = SegformerForSemanticSegmentation.from_pretrained(
    "matei-dorian/segformer-b5-finetuned-human-parsing"
)
# The parser is only used for inference: move it to the chosen device and
# switch it to evaluation mode.
parser_model = parser_model.to(device)
parser_model = parser_model.eval()
# ----------------- UNet Generator Definition -----------------
class UNetGenerator(nn.Module):
    """Four-level encoder/decoder generator with skip connections.

    Encoder stages are ``Conv2d(kernel=4, stride=2, padding=1)`` followed by
    batch normalisation and ReLU.  Decoder stages are ``ConvTranspose2d``
    with the same kernel/stride/padding followed by batch normalisation and
    ReLU, except for the last one, which is followed by ``Tanh`` only.  From
    the second decoder stage on, the input is the previous decoder output
    concatenated channel-wise with the encoder output of matching depth.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the generated tensor.
    """

    def __init__(self, in_channels=6, out_channels=3):
        super().__init__()

        # Attribute names and the layer order inside every nn.Sequential
        # define the state_dict keys, so both are kept stable.
        self.down1 = self._encoder_stage(in_channels, 64)
        self.down2 = self._encoder_stage(64, 128)
        self.down3 = self._encoder_stage(128, 256)
        self.down4 = self._encoder_stage(256, 512)

        # Decoder inputs after up1 are twice as wide as the previous decoder
        # output because of the concatenated skip connection.
        self.up1 = self._decoder_stage(512, 256)
        self.up2 = self._decoder_stage(512, 128)
        self.up3 = self._decoder_stage(256, 64)
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
            nn.Tanh(),
        )

    @staticmethod
    def _encoder_stage(in_c, out_c):
        """Build one stride-2 convolution stage (conv, batch norm, ReLU)."""
        layers = [
            nn.Conv2d(in_c, out_c, 4, 2, 1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    @staticmethod
    def _decoder_stage(in_c, out_c):
        """Build one stride-2 transposed-convolution stage (deconv, batch norm, ReLU)."""
        layers = [
            nn.ConvTranspose2d(in_c, out_c, 4, 2, 1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` through four stages, then decode it with skip connections."""
        enc1 = self.down1(x)
        enc2 = self.down2(enc1)
        enc3 = self.down3(enc2)
        bottleneck = self.down4(enc3)

        decoded = self.up1(bottleneck)
        # Pair each remaining decoder stage with the encoder output it is
        # concatenated with (deepest skip first).
        for stage, skip in ((self.up2, enc3), (self.up3, enc2), (self.up4, enc1)):
            decoded = stage(torch.cat((decoded, skip), dim=1))
        return decoded
# ----------------- Load UNet Try-On Model -----------------
# Build the generator on the selected device, restore its weights from the
# "model_state_dict" entry of the local checkpoint file, and switch it to
# evaluation mode for inference.
tryon_model = UNetGenerator()
tryon_model = tryon_model.to(device)
checkpoint = torch.load(
    "viton_unet_full_checkpoint.pth",
    map_location=device,
)
tryon_model.load_state_dict(checkpoint["model_state_dict"])
tryon_model = tryon_model.eval()

# ----------------- Image Transforms -----------------
# Resize every input to 256 x 192 (height x width) and convert it to a tensor.
img_transform = transforms.Compose(
    [
        transforms.Resize((256, 192)),
        transforms.ToTensor(),
    ]
)
# ----------------- Helper Functions -----------------
def get_segmentation(image: Image.Image):
    """Run the human-parsing model on ``image`` and return its label map.

    Args:
        image: Person photo as a PIL image.

    Returns:
        NumPy array with the index of the highest-scoring class at every
        location of the model's logits, for the single image in the batch.
    """
    encoded = processor(images=image, return_tensors="pt").to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = parser_model(**encoded).logits

    # Dimension 1 of the logits indexes the classes; index 0 selects the
    # only image in the batch.
    # NOTE(review): the label map has the resolution of the logits, which may
    # differ from the size of ``image`` -- confirm callers resize as needed.
    label_map = logits.argmax(dim=1)[0]
    return label_map.cpu().numpy()
def generate_agnostic(image: Image.Image, segmentation, cloth_label: int = 4,
                      fill_color=(128, 128, 128)):
    """Return a 192 x 256 copy of ``image`` with one segmentation class painted over.

    Args:
        image: Person photo as a PIL image. It is converted to RGB first so
            that the three-value fill colour always matches the pixel layout
            (RGBA, greyscale or palette inputs would otherwise fail on the
            masked assignment).
        segmentation: 2-D array of integer class labels for ``image``; it may
            have any resolution and is resized to the output size here.
        cloth_label: Label value to mask out. Defaults to 4.
            NOTE(review): 4 is presumably the "upper clothes" class of the
            parsing model -- confirm against the model's label map.
        fill_color: RGB triple written over the masked pixels.

    Returns:
        A PIL RGB image of size 192 x 256 (width x height).
    """
    # PIL's resize and OpenCV's dsize both take (width, height).
    target_size = (192, 256)

    # np.array() already returns a fresh, writable array, so no extra copy
    # is needed before editing pixels in place.
    agnostic_np = np.array(image.convert("RGB").resize(target_size))

    # Nearest-neighbour interpolation keeps label values intact (no blending
    # between class indices).
    labels = cv2.resize(
        np.asarray(segmentation).astype(np.uint8),
        target_size,
        interpolation=cv2.INTER_NEAREST,
    )

    agnostic_np[labels == cloth_label] = fill_color
    return Image.fromarray(agnostic_np)
def generate_tryon_output(agnostic_img, cloth_img):
    """Run the try-on generator on an agnostic person image and a cloth image.

    Both images go through ``img_transform``, are stacked along the channel
    axis into one 6-channel batch of size 1, and are fed to ``tryon_model``.

    Args:
        agnostic_img: PIL image of the person with the garment region masked.
        cloth_img: PIL image of the garment to try on.

    Returns:
        The generated result as a PIL RGB image.
    """
    # Force three channels per image so the concatenated input always has the
    # six channels the generator was built for (RGBA / greyscale uploads
    # would otherwise produce a channel mismatch).
    agnostic_tensor = img_transform(agnostic_img.convert("RGB")).unsqueeze(0).to(device)
    cloth_tensor = img_transform(cloth_img.convert("RGB")).unsqueeze(0).to(device)
    input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)

    with torch.no_grad():
        output = tryon_model(input_tensor)

    # (1, C, H, W) -> (H, W, C) for PIL.
    output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()

    # The generator ends in Tanh, so values can be negative. Casting negative
    # floats to uint8 wraps around (or is undefined), which shows up as
    # speckled bright pixels; clip to [0, 1] before scaling to 8 bits.
    # NOTE(review): this keeps the original [0, 1] -> [0, 255] mapping. If the
    # checkpoint was trained on targets normalised to [-1, 1], the mapping
    # should be (x + 1) / 2 instead -- confirm against the training code.
    output_img = (np.clip(output_img, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(output_img)
# ----------------- Gradio Interface -----------------
def virtual_tryon(person_image, cloth_image):
    """Gradio callback: build the agnostic image and the try-on result.

    Args:
        person_image: PIL image of the person, or ``None`` when nothing was
            uploaded.
        cloth_image: PIL image of the garment, or ``None`` when nothing was
            uploaded.

    Returns:
        Tuple ``(agnostic, result)`` of PIL images: the person with the
        garment region masked, and the generated try-on output.

    Raises:
        gr.Error: If either image is missing, so the UI shows a readable
            message instead of an opaque failure inside the models.
    """
    if person_image is None or cloth_image is None:
        raise gr.Error("Please upload both a person image and a cloth image.")

    segmentation = get_segmentation(person_image)
    agnostic = generate_agnostic(person_image, segmentation)
    result = generate_tryon_output(agnostic, cloth_image)
    return agnostic, result
# Two image inputs (person, cloth) mapped to two image outputs (agnostic
# preview, generated result) through virtual_tryon.
demo = gr.Interface(
    fn=virtual_tryon,
    inputs=[
        gr.Image(type="pil", label="Person Image"),
        gr.Image(type="pil", label="Cloth Image"),
    ],
    outputs=[
        gr.Image(type="pil", label="Agnostic (Torso Masked)"),
        gr.Image(type="pil", label="Virtual Try-On Output"),
    ],
    # "\U0001F455" is the T-shirt emoji. It is written as an escape so the
    # title cannot be corrupted again by a wrong file encoding (it previously
    # appeared as mis-decoded UTF-8 bytes).
    title="\U0001F455 Virtual Try-On (UNet + Segformer)",
    description="Upload a person image and a cloth image to try on the cloth virtually.",
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()