sengourav012 committed
Commit b99dca9 · verified · 1 Parent(s): 8e88383

Create app.py

Files changed (1): app.py (+124, -0)
app.py ADDED
@@ -0,0 +1,124 @@
import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
import cv2

from transformers import AutoImageProcessor, SegformerForSemanticSegmentation

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------------- Load Human Parser Model from Hugging Face Hub -----------------
processor = AutoImageProcessor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
parser_model = SegformerForSemanticSegmentation.from_pretrained(
    "matei-dorian/segformer-b5-finetuned-human-parsing"
).to(device).eval()

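# Note: generate_agnostic() below masks segmentation class 4, which this
# checkpoint's label map uses for upper clothes (per the original masking
# comment). If you swap in a different human-parsing model, inspect
# parser_model.config.id2label and adjust the class index to match.
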
# ----------------- UNet Generator Definition -----------------
class UNetGenerator(nn.Module):
    def __init__(self, in_channels=6, out_channels=3):
        super(UNetGenerator, self).__init__()

        def block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 4, 2, 1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True)
            )

        def up_block(in_c, out_c):
            return nn.Sequential(
                nn.ConvTranspose2d(in_c, out_c, 4, 2, 1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True)
            )

        # Encoder: each block halves the spatial resolution
        self.down1 = block(in_channels, 64)
        self.down2 = block(64, 128)
        self.down3 = block(128, 256)
        self.down4 = block(256, 512)

        # Decoder: input channel counts include the skip connections
        # concatenated in forward()
        self.up1 = up_block(512, 256)
        self.up2 = up_block(512, 128)
        self.up3 = up_block(256, 64)
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)

        u1 = self.up1(d4)
        u2 = self.up2(torch.cat([u1, d3], dim=1))
        u3 = self.up3(torch.cat([u2, d2], dim=1))
        u4 = self.up4(torch.cat([u3, d1], dim=1))
        return u4

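# Illustrative shape check (not executed here): the four stride-2 encoder
# stages shrink 256x192 to 16x12, and the four decoder stages restore it,
# so a 6-channel input yields a 3-channel image of the same size:
#
#     >>> x = torch.randn(1, 6, 256, 192)
#     >>> UNetGenerator()(x).shape
#     torch.Size([1, 3, 256, 192])
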
# ----------------- Load UNet Try-On Model -----------------
tryon_model = UNetGenerator().to(device)
checkpoint = torch.load("viton_unet_full_checkpoint.pth", map_location=device)
tryon_model.load_state_dict(checkpoint['model_state_dict'])
tryon_model.eval()

# ----------------- Image Transforms -----------------
img_transform = transforms.Compose([
    transforms.Resize((256, 192)),  # (height, width) in torchvision
    transforms.ToTensor()           # scales pixel values to [0, 1]
])

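# 256x192 is the standard VITON resolution; it also keeps both dimensions
# divisible by 16, which the four stride-2 stages of UNetGenerator need in
# order to reproduce the input size exactly.
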
# ----------------- Helper Functions -----------------
def get_segmentation(image: Image.Image):
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = parser_model(**inputs)
    logits = outputs.logits
    predicted = torch.argmax(logits, dim=1)[0].cpu().numpy()
    return predicted

def generate_agnostic(image: Image.Image, segmentation):
    img_np = np.array(image.resize((192, 256)))  # PIL resize takes (width, height)
    agnostic_np = img_np.copy()
    segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
    agnostic_np[segmentation_resized == 4] = [128, 128, 128]  # Mask upper clothes
    return Image.fromarray(agnostic_np)

def generate_tryon_output(agnostic_img, cloth_img):
    agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
    cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
    input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)  # 6 channels

    with torch.no_grad():
        output = tryon_model(input_tensor)
    output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
    # The generator ends in Tanh, so values may fall outside [0, 1]; clip
    # before scaling to avoid uint8 wrap-around on negative values.
    output_img = (output_img.clip(0, 1) * 255).astype(np.uint8)
    return Image.fromarray(output_img)

# ----------------- Gradio Interface -----------------
def virtual_tryon(person_image, cloth_image):
    # Normalize uploads to 3-channel RGB; RGBA or grayscale inputs would
    # break the channel arithmetic above.
    person_image = person_image.convert("RGB")
    cloth_image = cloth_image.convert("RGB")
    segmentation = get_segmentation(person_image)
    agnostic = generate_agnostic(person_image, segmentation)
    result = generate_tryon_output(agnostic, cloth_image)
    return agnostic, result

demo = gr.Interface(
    fn=virtual_tryon,
    inputs=[
        gr.Image(type="pil", label="Person Image"),
        gr.Image(type="pil", label="Cloth Image")
    ],
    outputs=[
        gr.Image(type="pil", label="Agnostic (Torso Masked)"),
        gr.Image(type="pil", label="Virtual Try-On Output")
    ],
    title="👕 Virtual Try-On (UNet + Segformer)",
    description="Upload a person image and a cloth image to try on the cloth virtually."
)

if __name__ == "__main__":
    demo.launch()
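
# A minimal sketch of calling the running app programmatically with
# gradio_client (assumes the default local URL; the image paths here are
# hypothetical):
#
#     from gradio_client import Client, handle_file
#     client = Client("http://127.0.0.1:7860")
#     agnostic, result = client.predict(
#         handle_file("person.jpg"),
#         handle_file("cloth.jpg"),
#     )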