Create app.py
app.py
CHANGED
@@ -61,17 +61,7 @@ class UNetGenerator(nn.Module):
         u4 = self.up4(torch.cat([u3, d1], dim=1))
         return u4

-# ----------------- Image
-# img_transform = transforms.Compose([
-#     transforms.Resize((256, 192)),
-#     transforms.ToTensor(),
-#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-# ])
-#new changes
-
-#end new changes
-
-# ----------------- Helper Functions -----------------
+# ----------------- Image Segmentation -----------------
 def get_segmentation(image: Image.Image):
     inputs = processor(images=image, return_tensors="pt").to(device)
     with torch.no_grad():
@@ -80,6 +70,7 @@ def get_segmentation(image: Image.Image):
     predicted = torch.argmax(logits, dim=1)[0].cpu().numpy()
     return predicted

+# ----------------- Agnostic Creation -----------------
 def generate_agnostic(image: Image.Image, segmentation):
     img_np = np.array(image.resize((192, 256)))
     agnostic_np = img_np.copy()
@@ -89,6 +80,7 @@ def generate_agnostic(image: Image.Image, segmentation):
         agnostic_np[segmentation_resized == label] = [128, 128, 128]
     return Image.fromarray(agnostic_np)

+# ----------------- Load Model -----------------
 def load_model(model_type):
     if model_type == "UNet":
         model = UNetGenerator().to(device)
@@ -112,69 +104,7 @@ def load_model(model_type):
     model.eval()
     return model

-#
-# agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
-# cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
-# input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
-
-# with torch.no_grad():
-#     output = model(input_tensor)
-
-# output_img = output[0].cpu().permute(1, 2, 0).numpy()
-# output_img = (output_img + 1) / 2
-# output_img = np.clip(output_img, 0, 1)
-
-# person_np = np.array(person_img.resize((192, 256))).astype(np.float32) / 255.0
-# segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
-# blend_mask = (segmentation_resized == 0).astype(np.float32)
-# blend_mask = np.expand_dims(blend_mask, axis=2)
-
-# final_output = blend_mask * person_np + (1 - blend_mask) * output_img
-# final_output = (final_output * 255).astype(np.uint8)
-
-# return Image.fromarray(final_output)
-#new changes
-# def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model, model_type):
-
-#     if model_type == "UNet":
-#         img_transform = transforms.Compose([
-#             transforms.Resize((256, 192)),
-#             transforms.ToTensor()
-#         ])
-
-#     else:
-
-#         img_transform = transforms.Compose([
-#             transforms.Resize((256, 192)),
-#             transforms.ToTensor(),
-#             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-#         ])
-
-#     agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
-#     cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
-#     input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
-
-#     with torch.no_grad():
-#         output = model(input_tensor)
-
-#     if model_type == "UNet":
-#         output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
-#         output_img = (output_img * 255).astype(np.uint8)
-#         return Image.fromarray(output_img)
-#     else:
-#         output_img = output[0].cpu().permute(1, 2, 0).numpy()
-#         output_img = (output_img + 1) / 2
-#         output_img = np.clip(output_img, 0, 1)
-
-#         person_np = np.array(person_img.resize((192, 256))).astype(np.float32) / 255.0
-#         segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
-#         blend_mask = (segmentation_resized == 0).astype(np.float32)
-#         blend_mask = np.expand_dims(blend_mask, axis=2)
-
-#         final_output = blend_mask * person_np + (1 - blend_mask) * output_img
-#         final_output = (final_output * 255).astype(np.uint8)
-
-#         return Image.fromarray(final_output)
+# ----------------- Generate Try-On -----------------
 def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model, model_type):
     if model_type == "UNet":
         img_transform = transforms.Compose([
@@ -212,32 +142,92 @@ def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model, model_type):
         final_output = blend_mask * person_np + (1 - blend_mask) * output_img
         final_output = (final_output * 255).astype(np.uint8)
         return Image.fromarray(final_output)
-#new changes end
-

-# -----------------
+# ----------------- Traditional CV Pipeline -----------------
+def create_agnostic_traditional(person_np, label_np):
+    mask = (label_np == 4).astype(np.uint8)
+    kernel = np.ones((7, 7), np.uint8)
+    dilated = cv2.dilate(mask, kernel, iterations=2)
+    agnostic = person_np.copy()
+    agnostic[dilated == 1] = [128, 128, 128]
+    return agnostic, dilated
+
+def improved_warp_cloth(cloth_np, person_np, label_np):
+    mask = (label_np == 4).astype(np.uint8) * 255
+    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    if not contours:
+        return cloth_np
+
+    cnt = max(contours, key=cv2.contourArea)
+    x, y, w, h = cv2.boundingRect(cnt)
+    src_h, src_w = cloth_np.shape[:2]
+    src_points = np.array([[0,0],[src_w-1,0],[src_w-1,src_h-1],[0,src_h-1]], dtype=np.float32)
+    padding_x, padding_y = int(w*0.05), int(h*0.05)
+    dst_points = np.array([
+        [max(0, x - padding_x), max(0, y - padding_y)],
+        [min(person_np.shape[1] - 1, x + w + padding_x), max(0, y - padding_y)],
+        [min(person_np.shape[1] - 1, x + w + padding_x), min(person_np.shape[0] - 1, y + h + padding_y)],
+        [max(0, x - padding_x), min(person_np.shape[0] - 1, y + h + padding_y)]
+    ], dtype=np.float32)
+    M = cv2.getPerspectiveTransform(src_points, dst_points)
+    warped = cv2.warpPerspective(cloth_np, M, (person_np.shape[1], person_np.shape[0]), borderMode=cv2.BORDER_CONSTANT)
+    return warped
+
+def improved_blend_traditional(agnostic_np, warped_cloth_np, label_np):
+    target_mask = (label_np == 4).astype(np.uint8)
+    kernel = np.ones((9, 9), np.uint8)
+    target_mask = cv2.dilate(target_mask, kernel, iterations=2) * 255
+
+    gray = cv2.cvtColor(warped_cloth_np, cv2.COLOR_BGR2GRAY)
+    _, cloth_mask = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
+    combined_mask = cv2.bitwise_and(target_mask, cloth_mask)
+    combined_mask = cv2.GaussianBlur(combined_mask, (5, 5), 0)
+
+    M = cv2.moments(combined_mask)
+    center = (int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"])) if M["m00"] != 0 else (96, 128)
+
+    try:
+        output = cv2.seamlessClone(warped_cloth_np, agnostic_np, combined_mask, center, cv2.NORMAL_CLONE)
+    except:
+        mask_3d = np.stack([combined_mask / 255.0] * 3, axis=2)
+        output = warped_cloth_np * mask_3d + agnostic_np * (1 - mask_3d)
+        output = output.astype(np.uint8)
+    return output
+
+# ----------------- Main Pipeline -----------------
 def virtual_tryon(person_image, cloth_image, model_type):
+    if model_type == "Traditional":
+        person_np = np.array(person_image.resize((192, 256)))[:, :, ::-1]
+        cloth_np = np.array(cloth_image.resize((192, 256)))[:, :, ::-1]
+        segmentation = get_segmentation(person_image)
+        label_np = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
+
+        agnostic_np, _ = create_agnostic_traditional(person_np, label_np)
+        warped_cloth = improved_warp_cloth(cloth_np, person_np, label_np)
+        output_np = improved_blend_traditional(agnostic_np, warped_cloth, label_np)
+
+        return Image.fromarray(agnostic_np[:, :, ::-1]), Image.fromarray(output_np[:, :, ::-1])
+
     segmentation = get_segmentation(person_image)
     agnostic = generate_agnostic(person_image, segmentation)
     model = load_model(model_type)
     result = generate_tryon_output(person_image, agnostic, cloth_image, segmentation, model, model_type)
-    # result = generate_tryon_output(person_image, agnostic, cloth_image, segmentation, model)
     return agnostic, result

-# ----------------- Gradio
+# ----------------- Gradio UI -----------------
 demo = gr.Interface(
     fn=virtual_tryon,
     inputs=[
         gr.Image(type="pil", label="Person Image"),
         gr.Image(type="pil", label="Cloth Image"),
-        gr.Radio(choices=["UNet", "GAN", "Diffusion"], label="Model Type", value="UNet")
+        gr.Radio(choices=["UNet", "GAN", "Diffusion", "Traditional"], label="Model Type", value="UNet")
     ],
     outputs=[
         gr.Image(type="pil", label="Agnostic (Torso Masked)"),
         gr.Image(type="pil", label="Virtual Try-On Output")
     ],
     title="👕 Virtual Try-On App",
-    description="Upload a person image and a clothing image, select a model (UNet, GAN, Diffusion), and try it on virtually."
+    description="Upload a person image and a clothing image, select a model (UNet, GAN, Diffusion, Traditional), and try it on virtually."
 )

 if __name__ == "__main__":
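
Note: the hunks above rely on globals defined earlier in app.py that fall outside the changed regions (processor, device, and the segmentation model called inside get_segmentation), and the closing if __name__ == "__main__": block is cut off by the last hunk. The sketch below is a minimal, assumed version of that scaffolding, not part of this commit: the checkpoint name mattmdjaga/segformer_b2_clothes and the variable name seg_model are guesses, chosen because label 4 in that checkpoint is "Upper-clothes", which matches the hard-coded label_np == 4 in the Traditional pipeline.

# Hypothetical scaffolding, NOT shown in this diff: imports, device setup, and a
# clothes-parsing SegFormer that would make the functions above runnable.
import torch
import numpy as np
import cv2
import gradio as gr
from PIL import Image
from torchvision import transforms
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# "processor" is the name used by get_segmentation(); "seg_model" is a guessed name.
processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
seg_model = SegformerForSemanticSegmentation.from_pretrained(
    "mattmdjaga/segformer_b2_clothes"
).to(device)
seg_model.eval()

# The hidden middle of get_segmentation() presumably resembles:
#     outputs = seg_model(**inputs)
#     logits = outputs.logits
# before the argmax shown in the diff.

# Typical closing block for a Gradio Space (the diff ends mid-statement):
# if __name__ == "__main__":
#     demo.launch()

# Quick local smoke test of the Traditional branch (file names are placeholders):
# person = Image.open("person.jpg").convert("RGB")
# cloth = Image.open("cloth.jpg").convert("RGB")
# agnostic, result = virtual_tryon(person, cloth, "Traditional")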