# virtual-try-on / app.py
# Author: sengourav012 — commit 6706839 ("Create app.py", verified)
import functools

import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
import cv2
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation

from improved_viton import ImprovedUNetGenerator
# ----------------- Device -----------------
# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ----------------- Load Human Parser Model -----------------
# Hugging Face Hub id shared by the image pre-processor and the SegFormer parser.
PARSER_CHECKPOINT = "matei-dorian/segformer-b5-finetuned-human-parsing"
processor = AutoImageProcessor.from_pretrained(PARSER_CHECKPOINT)
parser_model = SegformerForSemanticSegmentation.from_pretrained(PARSER_CHECKPOINT)
parser_model = parser_model.to(device)
# Inference only: put normalisation/dropout layers into evaluation behaviour.
parser_model.eval()
# ----------------- UNet Generator Definition -----------------
class UNetGenerator(nn.Module):
    """Four-level convolutional encoder/decoder with skip connections.

    Each encoder stage halves the spatial size (kernel 4, stride 2, padding 1)
    and each decoder stage doubles it back. Decoder stages after the first
    concatenate the matching encoder feature map along the channel axis, so
    height and width must be divisible by 16 for the shapes to line up. The
    final layer is ``Tanh``, so raw outputs lie in [-1, 1].

    The attribute names ``down1``..``down4`` / ``up1``..``up4`` and the layer
    order inside each ``nn.Sequential`` determine the ``state_dict`` keys that
    saved checkpoints are loaded against — keep them stable.
    """

    def __init__(self, in_channels=6, out_channels=3):
        super().__init__()

        def encoder_stage(c_in, c_out):
            # Conv -> BatchNorm -> ReLU, halving H and W.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            )

        def decoder_stage(c_in, c_out):
            # TransposedConv -> BatchNorm -> ReLU, doubling H and W.
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            )

        self.down1 = encoder_stage(in_channels, 64)
        self.down2 = encoder_stage(64, 128)
        self.down3 = encoder_stage(128, 256)
        self.down4 = encoder_stage(256, 512)
        # From up2 onwards the input width is doubled by the skip concatenation.
        self.up1 = decoder_stage(512, 256)
        self.up2 = decoder_stage(256 * 2, 128)
        self.up3 = decoder_stage(128 * 2, 64)
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(64 * 2, out_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        encoded = []
        h = x
        for stage in (self.down1, self.down2, self.down3, self.down4):
            h = stage(h)
            encoded.append(h)
        h = self.up1(h)
        # Pair up2/up3/up4 with the encoder outputs d3, d2, d1 respectively.
        for stage, skip in zip((self.up2, self.up3, self.up4), reversed(encoded[:-1])):
            h = stage(torch.cat([h, skip], dim=1))
        return h
# ----------------- Image Segmentation -----------------
def get_segmentation(image: Image.Image):
    """Run the SegFormer human parser on *image* and return its label map.

    Returns:
        A 2-D numpy array of per-pixel class indices (argmax over the logits'
        class axis for the first batch item). It is at the model's logit
        resolution, which may differ from the input size; callers in this file
        resize it with nearest-neighbour interpolation before use.
    """
    batch = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        scores = parser_model(**batch).logits
    label_map = scores.argmax(dim=1)[0]
    return label_map.cpu().numpy()
# ----------------- Agnostic Creation -----------------
def generate_agnostic(image: Image.Image, segmentation, clothing_labels=(4,), size=(192, 256)):
    """Grey out the clothing region of *image* to build the cloth-agnostic input.

    Args:
        image: Person photo. It is converted to RGB first so RGBA, palette or
            greyscale inputs also get a 3-channel result (the grey fill below
            assumes three channels).
        segmentation: 2-D integer label map covering the whole image, at any
            resolution (e.g. the output of ``get_segmentation``); it is resized
            to *size* with nearest-neighbour interpolation.
        clothing_labels: Parser label ids to mask out. The default ``(4,)``
            reproduces the previous hard-coded value; presumably label 4 is the
            parser's upper-clothes class — TODO confirm against
            ``parser_model.config.id2label``.
        size: Output ``(width, height)`` in PIL/OpenCV order.

    Returns:
        An RGB PIL image of *size* with the selected labels filled with
        (128, 128, 128).
    """
    agnostic_np = np.array(image.convert("RGB").resize(size))
    segmentation_resized = cv2.resize(
        segmentation.astype(np.uint8), size, interpolation=cv2.INTER_NEAREST
    )
    agnostic_np[np.isin(segmentation_resized, clothing_labels)] = (128, 128, 128)
    return Image.fromarray(agnostic_np)
# ----------------- Load Model -----------------
def load_model(model_type):
if model_type == "UNet":
model = UNetGenerator().to(device)
checkpoint = torch.load("viton_unet_full_checkpoint.pth", map_location=device)
state_dict = checkpoint.get("model_G_state_dict") or checkpoint.get("model_state_dict")
elif model_type == "GAN":
model = ImprovedUNetGenerator(in_channels=6, out_channels=3).to(device)
checkpoint = torch.load("viton_gan_full_checkpoint.pth", map_location=device)
state_dict = checkpoint.get("model_G_state_dict") or checkpoint.get("model_state_dict")
elif model_type == "Diffusion":
model = ImprovedUNetGenerator(in_channels=6, out_channels=3).to(device)
checkpoint = torch.load("viton_diffusion_full_checkpoint.pth", map_location=device)
state_dict = checkpoint.get("model_G_state_dict") or checkpoint.get("model_state_dict")
else:
raise ValueError("Invalid model type")
if state_dict is None:
raise KeyError(f"No valid state_dict found for model type {model_type}")
model.load_state_dict(state_dict)
model.eval()
return model
# ----------------- Generate Try-On -----------------
def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model, model_type):
    """Run *model* on (agnostic, cloth) and turn its output into a PIL image.

    Both images are resized to 256x192 (H x W), converted to tensors and
    concatenated along the channel axis into a 6-channel input.

    * ``model_type == "UNet"``: inputs stay in [0, 1]. The output is clipped to
      [0, 1] and scaled to uint8.
    * Any other type: inputs are normalised to [-1, 1]. The output is mapped
      back with ``(x + 1) / 2``. Pixels the parser labelled 0 in
      *segmentation* are then copied from the resized person photo; every
      other pixel comes from the generator.

    Args:
        person_img: Original person PIL image (used only for the blend).
        agnostic_img: Cloth-agnostic PIL image.
        cloth_img: Garment PIL image.
        segmentation: 2-D label map covering the whole person image.
        model: Generator module taking a (1, 6, 256, 192) tensor.
        model_type: Selects the pre- and post-processing described above.

    Returns:
        An RGB PIL image. For the blend path it is 192x256 (W x H); this
        presumes the generator preserves the input spatial size — TODO confirm
        for ImprovedUNetGenerator.
    """
    steps = [transforms.Resize((256, 192)), transforms.ToTensor()]
    if model_type != "UNet":
        steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    img_transform = transforms.Compose(steps)

    # Force 3 channels: an RGBA or greyscale image would otherwise give the
    # generator the wrong number of input channels.
    agnostic_tensor = img_transform(agnostic_img.convert("RGB")).unsqueeze(0).to(device)
    cloth_tensor = img_transform(cloth_img.convert("RGB")).unsqueeze(0).to(device)
    input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
    with torch.no_grad():
        output = model(input_tensor)
    output_img = output[0].cpu().permute(1, 2, 0).numpy()

    if model_type == "UNet":
        # UNetGenerator ends in Tanh, so values can be negative. Clip before the
        # uint8 cast, which would otherwise wrap negatives to bright pixels.
        output_img = np.clip(output_img, 0, 1)
        return Image.fromarray((output_img * 255).astype(np.uint8))

    output_img = np.clip((output_img + 1) / 2, 0, 1)
    person_np = np.asarray(person_img.convert("RGB").resize((192, 256)), dtype=np.float32) / 255.0
    segmentation_resized = cv2.resize(
        segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST
    )
    # 1 where the parser saw label 0 (keep the original photo), 0 elsewhere.
    blend_mask = (segmentation_resized == 0).astype(np.float32)[:, :, np.newaxis]
    final_output = blend_mask * person_np + (1 - blend_mask) * output_img
    return Image.fromarray((final_output * 255).astype(np.uint8))
# ----------------- Traditional CV Pipeline -----------------
def create_agnostic_traditional(person_np, label_np):
    """Mask a dilated label-4 region of *person_np* with mid-grey.

    Args:
        person_np: (H, W, 3) image array; the caller in this file passes BGR uint8.
        label_np: (H, W) parser label map aligned with *person_np*.

    Returns:
        ``(agnostic, mask)``: *agnostic* is a copy of *person_np* with masked
        pixels set to 128 in every channel. *mask* is the 0/1 uint8 region
        after two 7x7 dilations of label 4.
    """
    region = (label_np == 4).astype(np.uint8)
    # Grow the region so the grey patch extends past the parsed boundary.
    region = cv2.dilate(region, np.ones((7, 7), np.uint8), iterations=2)
    agnostic = person_np.copy()
    agnostic[region.astype(bool)] = [128, 128, 128]
    return agnostic, region
def improved_warp_cloth(cloth_np, person_np, label_np):
    """Stretch the cloth image onto the bounding box of the largest label-4 blob.

    The box is padded by 5% of its width and height on each side and clamped
    to the person image. The full cloth image's four corners are mapped onto
    the box with a perspective transform; pixels outside it are left at the
    constant border value.

    Args:
        cloth_np: Cloth image array (H_c, W_c, C).
        person_np: Person image; only its height and width are used.
        label_np: (H, W) parser label map aligned with *person_np*.

    Returns:
        The warped cloth at the person image's size, or *cloth_np* unchanged
        if no label-4 pixels exist.
    """
    blob_mask = np.where(label_np == 4, 255, 0).astype(np.uint8)
    contours, _ = cv2.findContours(blob_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return cloth_np

    x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
    out_h, out_w = person_np.shape[:2]
    pad_x, pad_y = int(w * 0.05), int(h * 0.05)
    left = max(0, x - pad_x)
    top = max(0, y - pad_y)
    right = min(out_w - 1, x + w + pad_x)
    bottom = min(out_h - 1, y + h + pad_y)

    cloth_h, cloth_w = cloth_np.shape[:2]
    corners_src = np.float32(
        [[0, 0], [cloth_w - 1, 0], [cloth_w - 1, cloth_h - 1], [0, cloth_h - 1]]
    )
    corners_dst = np.float32([[left, top], [right, top], [right, bottom], [left, bottom]])
    homography = cv2.getPerspectiveTransform(corners_src, corners_dst)
    return cv2.warpPerspective(
        cloth_np, homography, (out_w, out_h), borderMode=cv2.BORDER_CONSTANT
    )
def improved_blend_traditional(agnostic_np, warped_cloth_np, label_np):
    """Composite the warped cloth into the agnostic image over the clothing area.

    The paste mask has three parts:
      * the label-4 region dilated twice with a 9x9 kernel,
      * intersected with the non-near-black pixels of the warped cloth
        (grey > 10),
      * then softened with a 5x5 Gaussian blur.

    The cloth is pasted with Poisson ``seamlessClone``. If OpenCV rejects the
    input, it falls back to plain alpha blending with the soft mask as weight.

    Args:
        agnostic_np: uint8 (H, W, 3) person image with clothing greyed out.
        warped_cloth_np: uint8 (H, W, 3) cloth already aligned to the same
            frame. It is converted with COLOR_BGR2GRAY, so BGR order is expected.
        label_np: (H, W) parser label map.

    Returns:
        A uint8 (H, W, 3) image. If the mask is empty, this is a copy of
        *agnostic_np*.
    """
    target_mask = (label_np == 4).astype(np.uint8)
    kernel = np.ones((9, 9), np.uint8)
    target_mask = cv2.dilate(target_mask, kernel, iterations=2) * 255
    gray = cv2.cvtColor(warped_cloth_np, cv2.COLOR_BGR2GRAY)
    # Near-black pixels are treated as the empty border left by the warp.
    _, cloth_mask = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
    combined_mask = cv2.bitwise_and(target_mask, cloth_mask)
    combined_mask = cv2.GaussianBlur(combined_mask, (5, 5), 0)

    if not combined_mask.any():
        # Nothing to paste (no clothing pixels, or no cloth inside them).
        return agnostic_np.astype(np.uint8)

    # seamlessClone puts the centre of the mask's bounding box at `center`.
    # The cloth is already aligned with agnostic_np, so use that box centre.
    # The mask centroid would shift the cloth whenever the mask is asymmetric.
    x, y, w, h = cv2.boundingRect(combined_mask)
    center = (x + w // 2, y + h // 2)
    try:
        output = cv2.seamlessClone(
            warped_cloth_np, agnostic_np, combined_mask, center, cv2.NORMAL_CLONE
        )
    except cv2.error:
        # Alpha-blend with the soft mask as the per-pixel cloth weight.
        alpha = (combined_mask / 255.0)[:, :, np.newaxis]
        output = warped_cloth_np * alpha + agnostic_np * (1 - alpha)
    return output.astype(np.uint8)
# ----------------- Main Pipeline -----------------
def _virtual_tryon_traditional(person_image, cloth_image, segmentation):
    """Run the classical OpenCV pipeline; return (agnostic, result) as RGB PIL images."""
    size = (192, 256)
    # The OpenCV helpers below work in BGR channel order.
    person_np = cv2.cvtColor(np.array(person_image.resize(size)), cv2.COLOR_RGB2BGR)
    cloth_np = cv2.cvtColor(np.array(cloth_image.resize(size)), cv2.COLOR_RGB2BGR)
    label_np = cv2.resize(segmentation.astype(np.uint8), size, interpolation=cv2.INTER_NEAREST)
    agnostic_np, _ = create_agnostic_traditional(person_np, label_np)
    warped_cloth = improved_warp_cloth(cloth_np, person_np, label_np)
    output_np = improved_blend_traditional(agnostic_np, warped_cloth, label_np)
    return (
        Image.fromarray(cv2.cvtColor(agnostic_np, cv2.COLOR_BGR2RGB)),
        Image.fromarray(cv2.cvtColor(output_np, cv2.COLOR_BGR2RGB)),
    )


def virtual_tryon(person_image, cloth_image, model_type):
    """Gradio callback: produce the agnostic image and the try-on result.

    Args:
        person_image: PIL image of the person (``None`` if not uploaded).
        cloth_image: PIL image of the garment (``None`` if not uploaded).
        model_type: "Traditional" for the OpenCV pipeline; otherwise a key
            accepted by ``load_model``.

    Returns:
        ``(agnostic, result)`` PIL images.

    Raises:
        gr.Error: If either image is missing, so the UI shows a message
            instead of a stack trace.
    """
    if person_image is None or cloth_image is None:
        raise gr.Error("Please upload both a person image and a cloth image.")
    # Normalise uploads to 3-channel RGB: the BGR flips, grey fills and the
    # 6-channel model input all assume exactly three channels.
    person_image = person_image.convert("RGB")
    cloth_image = cloth_image.convert("RGB")

    segmentation = get_segmentation(person_image)
    if model_type == "Traditional":
        return _virtual_tryon_traditional(person_image, cloth_image, segmentation)

    agnostic = generate_agnostic(person_image, segmentation)
    model = load_model(model_type)
    result = generate_tryon_output(person_image, agnostic, cloth_image, segmentation, model, model_type)
    return agnostic, result
# ----------------- Gradio UI -----------------
# Input components, in the positional order virtual_tryon expects.
_person_input = gr.Image(type="pil", label="Person Image")
_cloth_input = gr.Image(type="pil", label="Cloth Image")
_model_choice = gr.Radio(
    choices=["UNet", "GAN", "Diffusion", "Traditional"],
    label="Model Type",
    value="UNet",
)
# Output components, matching the (agnostic, result) tuple virtual_tryon returns.
_result_outputs = [
    gr.Image(type="pil", label="Agnostic (Torso Masked)"),
    gr.Image(type="pil", label="Virtual Try-On Output"),
]

demo = gr.Interface(
    fn=virtual_tryon,
    inputs=[_person_input, _cloth_input, _model_choice],
    outputs=_result_outputs,
    title="👕 Virtual Try-On App",
    description="Upload a person image and a clothing image, select a model (UNet, GAN, Diffusion, Traditional), and try it on virtually.",
)

if __name__ == "__main__":
    demo.launch(share=True)