Update app.py
app.py
CHANGED
@@ -5,12 +5,13 @@ import torch.nn as nn
 import numpy as np
 from torchvision import transforms
 import cv2
-
 from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
+from improved_viton import ImprovedUNetGenerator
 
+# ----------------- Device -----------------
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# ----------------- Load Human Parser Model
+# ----------------- Load Human Parser Model -----------------
 processor = AutoImageProcessor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
 parser_model = SegformerForSemanticSegmentation.from_pretrained(
     "matei-dorian/segformer-b5-finetuned-human-parsing"
@@ -60,17 +61,25 @@ class UNetGenerator(nn.Module):
         u4 = self.up4(torch.cat([u3, d1], dim=1))
         return u4
 
-# ----------------- Load UNet Try-On Model -----------------
-tryon_model = UNetGenerator().to(device)
-checkpoint = torch.load("viton_unet_full_checkpoint.pth", map_location=device)
-tryon_model.load_state_dict(checkpoint['model_state_dict'])
-tryon_model.eval()
-
 # ----------------- Image Transforms -----------------
-img_transform = transforms.Compose([
-    transforms.Resize((256, 192)),
-    transforms.ToTensor(),
-    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+# img_transform = transforms.Compose([
+#     transforms.Resize((256, 192)),
+#     transforms.ToTensor(),
+#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+# ])
+#new changes
+if model_type == "UNet":
+    img_transform = transforms.Compose([
+        transforms.Resize((256, 192)),
+        transforms.ToTensor()
+    ])
+else:
+    img_transform = transforms.Compose([
+        transforms.Resize((256, 192)),
+        transforms.ToTensor(),
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+    ])
+#end new changes
 
 # ----------------- Helper Functions -----------------
 def get_segmentation(image: Image.Image):
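Note on this hunk: the new if/else reads model_type at module import time, but model_type is only defined later as a parameter of virtual_tryon, so the Space will raise a NameError on startup. A minimal sketch of one way to defer the choice to call time; the helper name build_transform is ours, not part of the commit:

def build_transform(model_type):
    # UNet checkpoints expect raw [0, 1] tensors (ToTensor only);
    # the GAN/Diffusion checkpoints expect inputs normalized to [-1, 1].
    ops = [transforms.Resize((256, 192)), transforms.ToTensor()]
    if model_type != "UNet":
        ops.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(ops)

generate_tryon_output could then start with img_transform = build_transform(model_type) instead of relying on a module-level global.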
@@ -85,39 +94,108 @@ def generate_agnostic(image: Image.Image, segmentation):
     img_np = np.array(image.resize((192, 256)))
     agnostic_np = img_np.copy()
     segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
-    …
+    clothing_labels = [4]
+    for label in clothing_labels:
+        agnostic_np[segmentation_resized == label] = [128, 128, 128]
     return Image.fromarray(agnostic_np)
 
-def generate_tryon_output(agnostic_img, cloth_img):
+def load_model(model_type):
+    if model_type == "UNet":
+        model = UNetGenerator().to(device)
+        checkpoint = torch.load("viton_unet_full_checkpoint.pth", map_location=device)
+        state_dict = checkpoint.get("model_G_state_dict") or checkpoint.get("model_state_dict")
+    elif model_type == "GAN":
+        model = ImprovedUNetGenerator(in_channels=6, out_channels=3).to(device)
+        checkpoint = torch.load("viton_gan_full_checkpoint.pth", map_location=device)
+        state_dict = checkpoint.get("model_G_state_dict") or checkpoint.get("model_state_dict")
+    elif model_type == "Diffusion":
+        model = ImprovedUNetGenerator(in_channels=6, out_channels=3).to(device)
+        checkpoint = torch.load("viton_diffusion_full_checkpoint.pth", map_location=device)
+        state_dict = checkpoint.get("model_G_state_dict") or checkpoint.get("model_state_dict")
+    else:
+        raise ValueError("Invalid model type")
+
+    if state_dict is None:
+        raise KeyError(f"No valid state_dict found for model type {model_type}")
+
+    model.load_state_dict(state_dict)
+    model.eval()
+    return model
+
+# def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model):
+#     agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
+#     cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
+#     input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
+
+#     with torch.no_grad():
+#         output = model(input_tensor)
+
+#     output_img = output[0].cpu().permute(1, 2, 0).numpy()
+#     output_img = (output_img + 1) / 2
+#     output_img = np.clip(output_img, 0, 1)
+
+#     person_np = np.array(person_img.resize((192, 256))).astype(np.float32) / 255.0
+#     segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
+#     blend_mask = (segmentation_resized == 0).astype(np.float32)
+#     blend_mask = np.expand_dims(blend_mask, axis=2)
+
+#     final_output = blend_mask * person_np + (1 - blend_mask) * output_img
+#     final_output = (final_output * 255).astype(np.uint8)
+
+#     return Image.fromarray(final_output)
+#new changes
+def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model, model_type):
     agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
     cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
     input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
 
     with torch.no_grad():
-        output = tryon_model(input_tensor)
-        output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
-        output_img = (output_img * 255).astype(np.uint8)
-        return Image.fromarray(output_img)
+        output = model(input_tensor)
 
-
-def virtual_tryon(person_image, cloth_image):
+    if model_type == "UNet":
+        output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
+        output_img = (output_img * 255).astype(np.uint8)
+        return Image.fromarray(output_img)
+    else:
+        output_img = output[0].cpu().permute(1, 2, 0).numpy()
+        output_img = (output_img + 1) / 2
+        output_img = np.clip(output_img, 0, 1)
+
+        person_np = np.array(person_img.resize((192, 256))).astype(np.float32) / 255.0
+        segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
+        blend_mask = (segmentation_resized == 0).astype(np.float32)
+        blend_mask = np.expand_dims(blend_mask, axis=2)
+
+        final_output = blend_mask * person_np + (1 - blend_mask) * output_img
+        final_output = (final_output * 255).astype(np.uint8)
+
+        return Image.fromarray(final_output)
+#new changes end
+
+
+# ----------------- Inference Pipeline -----------------
+def virtual_tryon(person_image, cloth_image, model_type):
     segmentation = get_segmentation(person_image)
     agnostic = generate_agnostic(person_image, segmentation)
-    result = generate_tryon_output(agnostic, cloth_image)
+    model = load_model(model_type)
+    result = generate_tryon_output(person_image, agnostic, cloth_image, segmentation, model, model_type)
+    # result = generate_tryon_output(person_image, agnostic, cloth_image, segmentation, model)
     return agnostic, result
 
+# ----------------- Gradio Interface -----------------
 demo = gr.Interface(
     fn=virtual_tryon,
     inputs=[
         gr.Image(type="pil", label="Person Image"),
-        gr.Image(type="pil", label="Cloth Image")
+        gr.Image(type="pil", label="Cloth Image"),
+        gr.Radio(choices=["UNet", "GAN", "Diffusion"], label="Model Type", value="UNet")
     ],
     outputs=[
         gr.Image(type="pil", label="Agnostic (Torso Masked)"),
         gr.Image(type="pil", label="Virtual Try-On Output")
     ],
-    title="👕 Virtual Try-On …
-    description="Upload a person image and a …
+    title="👕 Virtual Try-On App",
+    description="Upload a person image and a clothing image, select a model (UNet, GAN, or Diffusion), and try it on virtually."
 )
 
 if __name__ == "__main__":
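Note on the post-processing split: the two branches in generate_tryon_output mirror the two transforms. ToTensor() alone yields values in [0, 1], while Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) computes (x - 0.5) / 0.5 = 2x - 1, mapping [0, 1] to [-1, 1], so (output + 1) / 2 is its inverse. The blend mask then keeps the original person pixels wherever the parser predicted label 0 (background) and takes generator pixels elsewhere. A standalone round-trip check of the range arithmetic, separate from the app:

import torch

x = torch.rand(3, 256, 192)             # stand-in for a ToTensor() image in [0, 1]
y = (x - 0.5) / 0.5                     # what Normalize((0.5, ...), (0.5, ...)) computes
assert torch.allclose((y + 1) / 2, x)   # the else-branch inverse recovers x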
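Once the module-level transform issue flagged above is addressed and the three checkpoint files are present, a quick smoke test of the refactored pipeline without launching the Gradio UI might look like this; the image file names are placeholders:

from PIL import Image

person = Image.open("person.jpg").convert("RGB")   # placeholder path
cloth = Image.open("cloth.jpg").convert("RGB")     # placeholder path
agnostic, result = virtual_tryon(person, cloth, "UNet")
agnostic.save("agnostic.png")
result.save("tryon_result.png")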