import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
import cv2
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation

from improved_viton import ImprovedUNetGenerator

# ----------------- Device -----------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------------- Load Human Parser Model -----------------
processor = AutoImageProcessor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
parser_model = SegformerForSemanticSegmentation.from_pretrained(
    "matei-dorian/segformer-b5-finetuned-human-parsing"
).to(device).eval()

# ----------------- UNet Generator Definition -----------------
class UNetGenerator(nn.Module):
    """Small U-Net that maps a 6-channel (agnostic + cloth) input to an RGB image."""

    def __init__(self, in_channels=6, out_channels=3):
        super().__init__()

        def block(in_c, out_c):
            # Downsampling: stride-2 conv halves the spatial resolution.
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 4, 2, 1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True)
            )

        def up_block(in_c, out_c):
            # Upsampling: stride-2 transposed conv doubles the spatial resolution.
            return nn.Sequential(
                nn.ConvTranspose2d(in_c, out_c, 4, 2, 1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True)
            )

        self.down1 = block(in_channels, 64)
        self.down2 = block(64, 128)
        self.down3 = block(128, 256)
        self.down4 = block(256, 512)
        # Decoder input channels are doubled by the skip connections in forward().
        self.up1 = up_block(512, 256)
        self.up2 = up_block(512, 128)
        self.up3 = up_block(256, 64)
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        u1 = self.up1(d4)
        u2 = self.up2(torch.cat([u1, d3], dim=1))
        u3 = self.up3(torch.cat([u2, d2], dim=1))
        u4 = self.up4(torch.cat([u3, d1], dim=1))
        return u4

# ----------------- Image Segmentation -----------------
def get_segmentation(image: Image.Image):
    """Run the SegFormer human parser and return a per-pixel label map."""
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = parser_model(**inputs)
    logits = outputs.logits
    predicted = torch.argmax(logits, dim=1)[0].cpu().numpy()
    return predicted

# ----------------- Agnostic Creation -----------------
def generate_agnostic(image: Image.Image, segmentation):
    """Grey out the upper-clothes region (label 4) to make the person cloth-agnostic."""
    img_np = np.array(image.resize((192, 256)))
    agnostic_np = img_np.copy()
    segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256),
                                      interpolation=cv2.INTER_NEAREST)
    clothing_labels = [4]
    for label in clothing_labels:
        agnostic_np[segmentation_resized == label] = [128, 128, 128]
    return Image.fromarray(agnostic_np)

# ----------------- Load Model -----------------
def load_model(model_type):
    checkpoints = {
        "UNet": ("viton_unet_full_checkpoint.pth", UNetGenerator),
        "GAN": ("viton_gan_full_checkpoint.pth", ImprovedUNetGenerator),
        "Diffusion": ("viton_diffusion_full_checkpoint.pth", ImprovedUNetGenerator),
    }
    if model_type not in checkpoints:
        raise ValueError(f"Invalid model type: {model_type}")
    ckpt_path, model_cls = checkpoints[model_type]
    model = model_cls(in_channels=6, out_channels=3).to(device)
    checkpoint = torch.load(ckpt_path, map_location=device)
    # Checkpoints may store the generator weights under either key.
    state_dict = checkpoint.get("model_G_state_dict") or checkpoint.get("model_state_dict")
    if state_dict is None:
        raise KeyError(f"No valid state_dict found for model type {model_type}")
    model.load_state_dict(state_dict)
    model.eval()
    return model
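# --- Optional sanity check (illustrative sketch, not called by the app) ---
# Verifies that the skip-connected generator maps a 6-channel agnostic+cloth
# input back to a 3-channel image at the 256x192 resolution used throughout.
def _check_generator_shapes():
    gen = UNetGenerator(in_channels=6, out_channels=3).eval()
    dummy = torch.randn(1, 6, 256, 192)  # one agnostic (3ch) + cloth (3ch) pair
    with torch.no_grad():
        out = gen(dummy)
    assert out.shape == (1, 3, 256, 192), f"unexpected output shape {out.shape}"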
# ----------------- Generate Try-On -----------------
def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model, model_type):
    if model_type == "UNet":
        # The UNet checkpoint expects unnormalized [0, 1] tensors.
        img_transform = transforms.Compose([
            transforms.Resize((256, 192)),
            transforms.ToTensor()
        ])
    else:
        # GAN/Diffusion checkpoints expect inputs normalized to [-1, 1].
        img_transform = transforms.Compose([
            transforms.Resize((256, 192)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
    cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
    input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)

    with torch.no_grad():
        output = model(input_tensor)

    if model_type == "UNet":
        output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
        # Clip before the uint8 cast: the Tanh head can emit values outside [0, 1].
        output_img = (np.clip(output_img, 0, 1) * 255).astype(np.uint8)
        return Image.fromarray(output_img)
    else:
        output_img = output[0].cpu().permute(1, 2, 0).numpy()
        output_img = (output_img + 1) / 2  # map [-1, 1] back to [0, 1]
        output_img = np.clip(output_img, 0, 1)
        # Keep the original person pixels wherever the parser saw background.
        person_np = np.array(person_img.resize((192, 256))).astype(np.float32) / 255.0
        segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256),
                                          interpolation=cv2.INTER_NEAREST)
        blend_mask = (segmentation_resized == 0).astype(np.float32)
        blend_mask = np.expand_dims(blend_mask, axis=2)
        final_output = blend_mask * person_np + (1 - blend_mask) * output_img
        final_output = (final_output * 255).astype(np.uint8)
        return Image.fromarray(final_output)

# ----------------- Traditional CV Pipeline -----------------
def create_agnostic_traditional(person_np, label_np):
    # Dilate the upper-clothes mask so the grey region fully covers the garment.
    mask = (label_np == 4).astype(np.uint8)
    kernel = np.ones((7, 7), np.uint8)
    dilated = cv2.dilate(mask, kernel, iterations=2)
    agnostic = person_np.copy()
    agnostic[dilated == 1] = [128, 128, 128]
    return agnostic, dilated

def improved_warp_cloth(cloth_np, person_np, label_np):
    # Fit the flat cloth image onto the bounding box of the clothing region
    # with a perspective warp, padded slightly on every side.
    mask = (label_np == 4).astype(np.uint8) * 255
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return cloth_np
    cnt = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(cnt)
    src_h, src_w = cloth_np.shape[:2]
    src_points = np.array([[0, 0], [src_w - 1, 0], [src_w - 1, src_h - 1], [0, src_h - 1]],
                          dtype=np.float32)
    padding_x, padding_y = int(w * 0.05), int(h * 0.05)
    dst_points = np.array([
        [max(0, x - padding_x), max(0, y - padding_y)],
        [min(person_np.shape[1] - 1, x + w + padding_x), max(0, y - padding_y)],
        [min(person_np.shape[1] - 1, x + w + padding_x), min(person_np.shape[0] - 1, y + h + padding_y)],
        [max(0, x - padding_x), min(person_np.shape[0] - 1, y + h + padding_y)]
    ], dtype=np.float32)
    M = cv2.getPerspectiveTransform(src_points, dst_points)
    warped = cv2.warpPerspective(cloth_np, M, (person_np.shape[1], person_np.shape[0]),
                                 borderMode=cv2.BORDER_CONSTANT)
    return warped

def improved_blend_traditional(agnostic_np, warped_cloth_np, label_np):
    target_mask = (label_np == 4).astype(np.uint8)
    kernel = np.ones((9, 9), np.uint8)
    target_mask = cv2.dilate(target_mask, kernel, iterations=2) * 255
    # Keep only warped-cloth pixels that are both non-black and inside the target region.
    gray = cv2.cvtColor(warped_cloth_np, cv2.COLOR_BGR2GRAY)
    _, cloth_mask = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
    combined_mask = cv2.bitwise_and(target_mask, cloth_mask)
    combined_mask = cv2.GaussianBlur(combined_mask, (5, 5), 0)
    M = cv2.moments(combined_mask)
    center = (int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"])) if M["m00"] != 0 else (96, 128)
    try:
        output = cv2.seamlessClone(warped_cloth_np, agnostic_np, combined_mask, center,
                                   cv2.NORMAL_CLONE)
    except cv2.error:
        # Fall back to plain alpha blending if Poisson cloning rejects the mask.
        mask_3d = np.stack([combined_mask / 255.0] * 3, axis=2)
        output = warped_cloth_np * mask_3d + agnostic_np * (1 - mask_3d)
        output = output.astype(np.uint8)
    return output
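# --- Optional offline sketch (illustrative; the file names are hypothetical) ---
# Chains the three traditional-CV steps on one image pair, mirroring what
# virtual_tryon() below does for model_type == "Traditional".
def _traditional_demo(person_path="person.jpg", cloth_path="cloth.jpg"):
    person = Image.open(person_path).convert("RGB")
    cloth = Image.open(cloth_path).convert("RGB")
    # OpenCV expects BGR channel order.
    person_np = cv2.cvtColor(np.array(person.resize((192, 256))), cv2.COLOR_RGB2BGR)
    cloth_np = cv2.cvtColor(np.array(cloth.resize((192, 256))), cv2.COLOR_RGB2BGR)
    label_np = cv2.resize(get_segmentation(person).astype(np.uint8), (192, 256),
                          interpolation=cv2.INTER_NEAREST)
    agnostic_np, _ = create_agnostic_traditional(person_np, label_np)
    warped = improved_warp_cloth(cloth_np, person_np, label_np)
    output_np = improved_blend_traditional(agnostic_np, warped, label_np)
    cv2.imwrite("tryon_traditional.png", output_np)  # arrays are already BGR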
# ----------------- Main Pipeline -----------------
def virtual_tryon(person_image, cloth_image, model_type):
    if model_type == "Traditional":
        person_np = np.array(person_image.resize((192, 256)))[:, :, ::-1]
        cloth_np = np.array(cloth_image.resize((192, 256)))[:, :, ::-1]
        segmentation = get_segmentation(person_image)
        label_np = cv2.resize(segmentation.astype(np.uint8), (192, 256),
                              interpolation=cv2.INTER_NEAREST)
        agnostic_np, _ = create_agnostic_traditional(person_np, label_np)
        warped_cloth = improved_warp_cloth(cloth_np, person_np, label_np)
        output_np = improved_blend_traditional(agnostic_np, warped_cloth, label_np)
        return Image.fromarray(agnostic_np[:, :, ::-1]), Image.fromarray(output_np[:, :, ::-1])

    segmentation = get_segmentation(person_image)
    agnostic = generate_agnostic(person_image, segmentation)
    model = load_model(model_type)
    result = generate_tryon_output(person_image, agnostic, cloth_image, segmentation, model, model_type)
    return agnostic, result

# ----------------- Gradio UI -----------------
demo = gr.Interface(
    fn=virtual_tryon,
    inputs=[
        gr.Image(type="pil", label="Person Image"),
        gr.Image(type="pil", label="Cloth Image"),
        gr.Radio(choices=["UNet", "GAN", "Diffusion", "Traditional"], label="Model Type", value="UNet")
    ],
    outputs=[
        gr.Image(type="pil", label="Agnostic (Torso Masked)"),
        gr.Image(type="pil", label="Virtual Try-On Output")
    ],
    title="👕 Virtual Try-On App",
    description="Upload a person image and a clothing image, select a model (UNet, GAN, Diffusion, Traditional), and try it on virtually."
)

if __name__ == "__main__":
    demo.launch(share=True)
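# Usage notes (illustrative): running `python app.py` starts the Gradio demo and,
# with share=True, also prints a temporary public URL. The UNet/GAN/Diffusion
# options expect their checkpoint files (viton_unet_full_checkpoint.pth, etc.)
# and the improved_viton module alongside this script; "Traditional" runs the
# pure OpenCV pipeline and needs no checkpoint.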