# File size: 4,397 Bytes
# b99dca9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
import cv2

from transformers import AutoImageProcessor, SegformerForSemanticSegmentation

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------------- Load Human Parser Model from Hugging Face Hub -----------------
# Both objects are created once at import time and reused by get_segmentation().
# `processor` prepares images for the network; `parser_model` is moved to
# `device` and put in eval mode because it is only used for inference here.
processor = AutoImageProcessor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
parser_model = SegformerForSemanticSegmentation.from_pretrained(
    "matei-dorian/segformer-b5-finetuned-human-parsing"
).to(device).eval()

# ----------------- UNet Generator Definition -----------------
class UNetGenerator(nn.Module):
    """Four-level encoder/decoder generator with skip connections.

    The encoder halves the spatial size four times with stride-2
    convolutions. The decoder doubles it back with transposed convolutions,
    concatenating the matching encoder feature map onto the input of every
    decoder stage after the first. The last stage ends in ``Tanh``.

    Height and width divisible by 16 guarantee that the skip tensors line up
    with the decoder tensors.

    Args:
        in_channels: Channels of the input tensor (default 6).
        out_channels: Channels of the generated tensor (default 3).
    """

    @staticmethod
    def _down_stage(in_c, out_c):
        """Build Conv(4x4, stride 2, pad 1) -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _up_stage(in_c, out_c):
        """Build ConvTranspose(4x4, stride 2, pad 1) -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        )

    def __init__(self, in_channels=6, out_channels=3):
        super().__init__()

        # Encoder: in_channels -> 64 -> 128 -> 256 -> 512.
        self.down1 = self._down_stage(in_channels, 64)
        self.down2 = self._down_stage(64, 128)
        self.down3 = self._down_stage(128, 256)
        self.down4 = self._down_stage(256, 512)

        # Decoder: from up2 onwards the input width is doubled by the
        # concatenated skip connection (256+256, 128+128, 64+64).
        self.up1 = self._up_stage(512, 256)
        self.up2 = self._up_stage(512, 128)
        self.up3 = self._up_stage(256, 64)
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to (N, out_channels, H, W)."""
        enc1 = self.down1(x)
        enc2 = self.down2(enc1)
        enc3 = self.down3(enc2)
        bottleneck = self.down4(enc3)

        decoded = self.up1(bottleneck)
        # Each remaining stage consumes the previous decoder output joined
        # channel-wise with the encoder map of the same resolution.
        for stage, skip in ((self.up2, enc3), (self.up3, enc2), (self.up4, enc1)):
            decoded = stage(torch.cat((decoded, skip), dim=1))
        return decoded

# ----------------- Load UNet Try-On Model -----------------
# Build the generator with its default channel counts and restore trained weights.
tryon_model = UNetGenerator().to(device)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. The file is read as a dict holding a 'model_state_dict' entry.
checkpoint = torch.load("viton_unet_full_checkpoint.pth", map_location=device)
tryon_model.load_state_dict(checkpoint['model_state_dict'])
# Inference only: eval mode makes BatchNorm use its stored running statistics.
tryon_model.eval()

# ----------------- Image Transforms -----------------
# Shared preprocessing for both try-on inputs. Resize takes (height, width),
# so images become 256 high x 192 wide. ToTensor then yields a float CHW
# tensor; no normalisation step is applied afterwards.
img_transform = transforms.Compose([
    transforms.Resize((256, 192)),
    transforms.ToTensor()
])

# ----------------- Helper Functions -----------------
def get_segmentation(image: Image.Image):
    """Run the human-parsing model on *image* and return per-pixel class ids.

    Args:
        image: PIL image of the person to parse.

    Returns:
        NumPy array with the arg-max class id per position for the single
        input image. The map keeps the spatial resolution of the model's
        logits; it is not resized here.
    """
    batch = processor(images=image, return_tensors="pt").to(device)
    # Inference only, so skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        class_scores = parser_model(**batch).logits
    # Pick the best class along the channel axis, then drop the batch axis.
    label_map = class_scores.argmax(dim=1)[0]
    return label_map.cpu().numpy()

def generate_agnostic(image: Image.Image, segmentation):
    """Return a 192x256 copy of *image* with the upper-clothes region greyed out.

    Args:
        image: Person photo in any PIL mode. It is converted to RGB first so
            the 3-value fill colour always matches the channel count; RGBA,
            greyscale, or palette inputs would otherwise fail on the masked
            assignment.
        segmentation: 2-D array of per-pixel class ids, at any resolution. It
            is resized with nearest-neighbour interpolation so that class ids
            are never blended.

    Returns:
        RGB PIL image, 192 wide by 256 high, in which every pixel labelled as
        upper clothes is replaced by mid-grey.
    """
    target_size = (192, 256)  # (width, height), as PIL and cv2 both expect
    upper_clothes_label = 4
    fill_colour = [128, 128, 128]

    # np.array() already returns a fresh writable array, so no extra copy is needed.
    agnostic_np = np.array(image.convert("RGB").resize(target_size))
    segmentation_resized = cv2.resize(
        segmentation.astype(np.uint8), target_size, interpolation=cv2.INTER_NEAREST
    )
    agnostic_np[segmentation_resized == upper_clothes_label] = fill_colour  # Mask upper clothes
    return Image.fromarray(agnostic_np)

def generate_tryon_output(agnostic_img, cloth_img):
    """Run the try-on generator and return the synthesised image.

    Args:
        agnostic_img: PIL image of the person with the garment region masked.
        cloth_img: PIL image of the garment to put on.

    Returns:
        RGB PIL image produced by the generator from the two inputs.
    """
    # Force three channels each so the concatenated input always has the six
    # channels the generator is built for.
    agnostic_tensor = img_transform(agnostic_img.convert("RGB")).unsqueeze(0).to(device)
    cloth_tensor = img_transform(cloth_img.convert("RGB")).unsqueeze(0).to(device)
    input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)

    with torch.no_grad():
        output = tryon_model(input_tensor)
    # (1, C, H, W) -> (H, W, C) on the CPU for PIL.
    output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
    # The generator ends in Tanh, so values can be negative. Casting negative
    # floats to uint8 wraps around and produces bright speckles, so clip to
    # the [0, 1] range used by the inputs before scaling.
    # NOTE(review): if the checkpoint was trained against [-1, 1] targets, use
    # (output_img + 1) / 2 instead of clipping; confirm against the training code.
    output_img = (np.clip(output_img, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(output_img)

# ----------------- Gradio Interface -----------------
def virtual_tryon(person_image, cloth_image):
    """Gradio callback: mask the person's top and render the new garment.

    Args:
        person_image: PIL image of the person, or None if nothing was uploaded.
        cloth_image: PIL image of the garment, or None if nothing was uploaded.

    Returns:
        Tuple ``(agnostic, result)``: the person image with the upper-clothes
        region masked, and the generator's try-on output.

    Raises:
        gr.Error: If either image is missing, so the UI shows a readable
            message instead of an internal AttributeError.
    """
    # Gradio passes None for an image slot the user left empty.
    if person_image is None or cloth_image is None:
        raise gr.Error("Please upload both a person image and a cloth image.")

    segmentation = get_segmentation(person_image)
    agnostic = generate_agnostic(person_image, segmentation)
    result = generate_tryon_output(agnostic, cloth_image)
    return agnostic, result

# Two PIL image inputs are passed positionally to virtual_tryon(); its two
# return values fill the two output image slots in the same order.
demo = gr.Interface(
    fn=virtual_tryon,
    inputs=[
        gr.Image(type="pil", label="Person Image"),
        gr.Image(type="pil", label="Cloth Image")
    ],
    outputs=[
        gr.Image(type="pil", label="Agnostic (Torso Masked)"),
        gr.Image(type="pil", label="Virtual Try-On Output")
    ],
    title="👕 Virtual Try-On (UNet + Segformer)",
    description="Upload a person image and a cloth image to try on the cloth virtually."
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()