# virtual-try-on/app.py
import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
import cv2
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
from improved_viton import ImprovedUNetGenerator
# ----------------- Device -----------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ----------------- Load Human Parser Model -----------------
processor = AutoImageProcessor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
parser_model = SegformerForSemanticSegmentation.from_pretrained(
"matei-dorian/segformer-b5-finetuned-human-parsing"
).to(device).eval()
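# Note: SegFormer predicts logits at 1/4 of the input resolution; the argmax in
# get_segmentation below is taken at that scale, and downstream code resizes
# the label map with nearest-neighbour interpolation to the 256x192 working
# resolution used throughout this app.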
# ----------------- UNet Generator Definition -----------------
class UNetGenerator(nn.Module):
    """Plain U-Net generator: four strided-conv encoder blocks, three
    skip-connected decoder blocks, and a Tanh output head."""
    def __init__(self, in_channels=6, out_channels=3):
        super().__init__()
def block(in_c, out_c):
return nn.Sequential(
nn.Conv2d(in_c, out_c, 4, 2, 1),
nn.BatchNorm2d(out_c),
nn.ReLU(inplace=True)
)
def up_block(in_c, out_c):
return nn.Sequential(
nn.ConvTranspose2d(in_c, out_c, 4, 2, 1),
nn.BatchNorm2d(out_c),
nn.ReLU(inplace=True)
)
self.down1 = block(in_channels, 64)
self.down2 = block(64, 128)
self.down3 = block(128, 256)
self.down4 = block(256, 512)
self.up1 = up_block(512, 256)
self.up2 = up_block(512, 128)
self.up3 = up_block(256, 64)
self.up4 = nn.Sequential(
nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
nn.Tanh()
)
def forward(self, x):
d1 = self.down1(x)
d2 = self.down2(d1)
d3 = self.down3(d2)
d4 = self.down4(d3)
u1 = self.up1(d4)
u2 = self.up2(torch.cat([u1, d3], dim=1))
u3 = self.up3(torch.cat([u2, d2], dim=1))
u4 = self.up4(torch.cat([u3, d1], dim=1))
return u4
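# Shape sanity check for the generator (a minimal sketch, assuming the 256x192
# working resolution used throughout this app):
#
#   g = UNetGenerator()
#   x = torch.randn(1, 6, 256, 192)  # 3 agnostic + 3 cloth channels
#   assert g(x).shape == (1, 3, 256, 192)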
# ----------------- Image Transforms -----------------
# Transforms are defined per model inside generate_tryon_output: the UNet
# checkpoint expects inputs in [0, 1], while the GAN/Diffusion checkpoints
# expect inputs normalized to [-1, 1].
# ----------------- Helper Functions -----------------
def get_segmentation(image: Image.Image):
    """Return the argmax human-parsing label map for `image` as a numpy array."""
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = parser_model(**inputs)
    logits = outputs.logits
    predicted = torch.argmax(logits, dim=1)[0].cpu().numpy()
    return predicted
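# The parser is assumed to follow the ATR-style human-parsing label map
# (verify against the model card): 0=Background, 1=Hat, 2=Hair, 3=Sunglasses,
# 4=Upper-clothes, 5=Skirt, 6=Pants, 7=Dress, ... generate_agnostic below
# masks only label 4 (upper clothes).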
def generate_agnostic(image: Image.Image, segmentation):
    """Grey out the clothing region so the generator can paint the new garment."""
    img_np = np.array(image.resize((192, 256)))
    agnostic_np = img_np.copy()
    # Resize the parser mask with nearest-neighbour interpolation so label
    # indices are preserved rather than interpolated.
    segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
    clothing_labels = [4]  # upper clothes in the assumed label map
    for label in clothing_labels:
        agnostic_np[segmentation_resized == label] = [128, 128, 128]
    return Image.fromarray(agnostic_np)
def load_model(model_type):
    """Instantiate the generator for `model_type` and load its checkpoint."""
    if model_type == "UNet":
        model = UNetGenerator().to(device)
        checkpoint_path = "viton_unet_full_checkpoint.pth"
    elif model_type == "GAN":
        model = ImprovedUNetGenerator(in_channels=6, out_channels=3).to(device)
        checkpoint_path = "viton_gan_full_checkpoint.pth"
    elif model_type == "Diffusion":
        model = ImprovedUNetGenerator(in_channels=6, out_channels=3).to(device)
        checkpoint_path = "viton_diffusion_full_checkpoint.pth"
    else:
        raise ValueError(f"Invalid model type: {model_type}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Checkpoints may store the generator weights under either key, depending
    # on whether they came from a GAN-style or a plain training loop.
    state_dict = checkpoint.get("model_G_state_dict") or checkpoint.get("model_state_dict")
    if state_dict is None:
        raise KeyError(f"No valid state_dict found for model type {model_type}")
    model.load_state_dict(state_dict)
    model.eval()
    return model
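# The loader above assumes each .pth file is a dict holding the generator
# weights under "model_G_state_dict" (GAN-style loops) or "model_state_dict"
# (plain loops). A minimal sketch of writing a compatible checkpoint from a
# hypothetical training script:
#
#   torch.save({"model_G_state_dict": model.state_dict()},
#              "viton_unet_full_checkpoint.pth")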
def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model, model_type):
    """Run the selected generator on the (agnostic, cloth) pair and compose the result."""
    # The UNet checkpoint was trained on inputs in [0, 1]; the GAN/Diffusion
    # checkpoints expect inputs normalized to [-1, 1].
    if model_type == "UNet":
        img_transform = transforms.Compose([
            transforms.Resize((256, 192)),
            transforms.ToTensor()
        ])
    else:
        img_transform = transforms.Compose([
            transforms.Resize((256, 192)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
    cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
    # Concatenate along channels: 3 (agnostic) + 3 (cloth) = 6 input channels.
    input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
    with torch.no_grad():
        output = model(input_tensor)
    if model_type == "UNet":
        output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
        # Clip before the uint8 cast: the generator's Tanh can emit values
        # outside [0, 1], which would otherwise wrap around in uint8.
        output_img = np.clip(output_img, 0, 1)
        output_img = (output_img * 255).astype(np.uint8)
        return Image.fromarray(output_img)
    else:
        # Map the Tanh output from [-1, 1] back to [0, 1].
        output_img = output[0].cpu().permute(1, 2, 0).numpy()
        output_img = (output_img + 1) / 2
        output_img = np.clip(output_img, 0, 1)
        # Keep background pixels (label 0) from the original person image and
        # take everything else from the generator output.
        person_np = np.array(person_img.resize((192, 256))).astype(np.float32) / 255.0
        segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
        blend_mask = (segmentation_resized == 0).astype(np.float32)
        blend_mask = np.expand_dims(blend_mask, axis=2)
        final_output = blend_mask * person_np + (1 - blend_mask) * output_img
        final_output = (final_output * 255).astype(np.uint8)
        return Image.fromarray(final_output)
# ----------------- Inference Pipeline -----------------
def virtual_tryon(person_image, cloth_image, model_type):
    """Full pipeline: parse the person, build the agnostic image, run the generator."""
    segmentation = get_segmentation(person_image)
    agnostic = generate_agnostic(person_image, segmentation)
    model = load_model(model_type)
    result = generate_tryon_output(person_image, agnostic, cloth_image, segmentation, model, model_type)
    return agnostic, result
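# Headless usage sketch (assumes person.jpg / cloth.jpg exist locally; the
# Gradio app below wraps exactly this call):
#
#   person = Image.open("person.jpg").convert("RGB")
#   cloth = Image.open("cloth.jpg").convert("RGB")
#   agnostic, result = virtual_tryon(person, cloth, model_type="UNet")
#   result.save("tryon_result.png")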
# ----------------- Gradio Interface -----------------
demo = gr.Interface(
fn=virtual_tryon,
inputs=[
gr.Image(type="pil", label="Person Image"),
gr.Image(type="pil", label="Cloth Image"),
gr.Radio(choices=["UNet", "GAN", "Diffusion"], label="Model Type", value="UNet")
],
outputs=[
gr.Image(type="pil", label="Agnostic (Torso Masked)"),
gr.Image(type="pil", label="Virtual Try-On Output")
],
title="👕 Virtual Try-On App",
description="Upload a person image and a clothing image, select a model (UNet, GAN, or Diffusion), and try it on virtually."
)
if __name__ == "__main__":
demo.launch(share=True)