sengourav012 committed
Commit 6706839 · verified · 1 Parent(s): 31d417a

Create app.py

Files changed (1)
  1. app.py +71 -81
app.py CHANGED
@@ -61,17 +61,7 @@ class UNetGenerator(nn.Module):
         u4 = self.up4(torch.cat([u3, d1], dim=1))
         return u4
 
-# ----------------- Image Transforms -----------------
-# img_transform = transforms.Compose([
-#     transforms.Resize((256, 192)),
-#     transforms.ToTensor(),
-#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-# ])
-#new changes
-
-#end new changes
-
-# ----------------- Helper Functions -----------------
+# ----------------- Image Segmentation -----------------
 def get_segmentation(image: Image.Image):
     inputs = processor(images=image, return_tensors="pt").to(device)
     with torch.no_grad():
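The hunk above uses processor, device, and a segmentation network that are all defined above line 61 of app.py, outside the visible diff, and the forward pass itself (old lines 78-79) is folded out. A minimal sketch of that assumed setup; the checkpoint and the seg_model name are guesses, chosen because the label == 4 checks later in the file match the usual clothes-parsing label for upper clothes:

import torch
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hypothetical checkpoint and variable names; the real ones live above this diff.
processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
seg_model = AutoModelForSemanticSegmentation.from_pretrained(
    "mattmdjaga/segformer_b2_clothes"
).to(device)

# The folded body of get_segmentation then presumably runs:
#     outputs = seg_model(**inputs)
#     logits = outputs.logits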
@@ -80,6 +70,7 @@ def get_segmentation(image: Image.Image):
     predicted = torch.argmax(logits, dim=1)[0].cpu().numpy()
     return predicted
 
+# ----------------- Agnostic Creation -----------------
 def generate_agnostic(image: Image.Image, segmentation):
     img_np = np.array(image.resize((192, 256)))
     agnostic_np = img_np.copy()
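generate_agnostic paints the garment region of the person image mid-gray to produce the clothing-agnostic input; the label-map resize and the label loop sit in the folded lines (old 86-88). A plausible reconstruction under the same clothes-parsing assumption, not the verbatim source:

segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256),
                                  interpolation=cv2.INTER_NEAREST)
for label in [4]:  # hypothetical label set; the real list is hidden by the fold
    agnostic_np[segmentation_resized == label] = [128, 128, 128]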
@@ -89,6 +80,7 @@ def generate_agnostic(image: Image.Image, segmentation):
         agnostic_np[segmentation_resized == label] = [128, 128, 128]
     return Image.fromarray(agnostic_np)
 
+# ----------------- Load Model -----------------
 def load_model(model_type):
     if model_type == "UNet":
         model = UNetGenerator().to(device)
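Only the UNet branch and the eval()/return tail of load_model are visible; the GAN and Diffusion branches and the weight loading (old lines 95-111) are folded out, so the visible contract is just that every branch returns an eval-mode model on device. A hedged sketch of the UNet branch under that contract, with a hypothetical checkpoint path:

def load_model_sketch(model_type):
    if model_type == "UNet":
        model = UNetGenerator().to(device)
        # "unet_tryon.pth" is a placeholder; the real path is in the folded lines.
        model.load_state_dict(torch.load("unet_tryon.pth", map_location=device))
    else:
        raise NotImplementedError("GAN/Diffusion branches are folded out of this hunk")
    model.eval()
    return model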
@@ -112,69 +104,7 @@ def load_model(model_type):
     model.eval()
     return model
 
-# def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model):
-#     agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
-#     cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
-#     input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
-
-#     with torch.no_grad():
-#         output = model(input_tensor)
-
-#     output_img = output[0].cpu().permute(1, 2, 0).numpy()
-#     output_img = (output_img + 1) / 2
-#     output_img = np.clip(output_img, 0, 1)
-
-#     person_np = np.array(person_img.resize((192, 256))).astype(np.float32) / 255.0
-#     segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
-#     blend_mask = (segmentation_resized == 0).astype(np.float32)
-#     blend_mask = np.expand_dims(blend_mask, axis=2)
-
-#     final_output = blend_mask * person_np + (1 - blend_mask) * output_img
-#     final_output = (final_output * 255).astype(np.uint8)
-
-#     return Image.fromarray(final_output)
-#new changes
-# def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model, model_type):
-
-#     if model_type == "UNet":
-#         img_transform = transforms.Compose([
-#             transforms.Resize((256, 192)),
-#             transforms.ToTensor()
-#         ])
-
-#     else:
-
-#         img_transform = transforms.Compose([
-#             transforms.Resize((256, 192)),
-#             transforms.ToTensor(),
-#             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-#         ])
-
-#     agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
-#     cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
-#     input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
-
-#     with torch.no_grad():
-#         output = model(input_tensor)
-
-#     if model_type == "UNet":
-#         output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
-#         output_img = (output_img * 255).astype(np.uint8)
-#         return Image.fromarray(output_img)
-#     else:
-#         output_img = output[0].cpu().permute(1, 2, 0).numpy()
-#         output_img = (output_img + 1) / 2
-#         output_img = np.clip(output_img, 0, 1)
-
-#     person_np = np.array(person_img.resize((192, 256))).astype(np.float32) / 255.0
-#     segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
-#     blend_mask = (segmentation_resized == 0).astype(np.float32)
-#     blend_mask = np.expand_dims(blend_mask, axis=2)
-
-#     final_output = blend_mask * person_np + (1 - blend_mask) * output_img
-#     final_output = (final_output * 255).astype(np.uint8)
-
-#     return Image.fromarray(final_output)
+# ----------------- Generate Try-On -----------------
 def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model, model_type):
     if model_type == "UNet":
         img_transform = transforms.Compose([
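The model_type branch that survives this hunk exists because transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) maps a [0, 1] tensor to [-1, 1], matching tanh-output GAN/Diffusion generators, while the UNet path is fed unnormalized [0, 1] tensors; the (output_img + 1) / 2 step seen in the deleted comments (and kept in the non-UNet path) is the exact inverse. A quick sanity check of that round trip:

import torch

x = torch.rand(3, 256, 192)   # an image tensor in [0, 1]
x_norm = (x - 0.5) / 0.5      # what Normalize((0.5,)*3, (0.5,)*3) computes: [-1, 1]
x_back = (x_norm + 1) / 2     # the de-normalization used in generate_tryon_output
assert torch.allclose(x, x_back, atol=1e-6)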
@@ -212,32 +142,92 @@ def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, mod
     final_output = blend_mask * person_np + (1 - blend_mask) * output_img
     final_output = (final_output * 255).astype(np.uint8)
     return Image.fromarray(final_output)
-#new changes end
-
 
-# ----------------- Inference Pipeline -----------------
+# ----------------- Traditional CV Pipeline -----------------
+def create_agnostic_traditional(person_np, label_np):
+    mask = (label_np == 4).astype(np.uint8)
+    kernel = np.ones((7, 7), np.uint8)
+    dilated = cv2.dilate(mask, kernel, iterations=2)
+    agnostic = person_np.copy()
+    agnostic[dilated == 1] = [128, 128, 128]
+    return agnostic, dilated
+
+def improved_warp_cloth(cloth_np, person_np, label_np):
+    mask = (label_np == 4).astype(np.uint8) * 255
+    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    if not contours:
+        return cloth_np
+
+    cnt = max(contours, key=cv2.contourArea)
+    x, y, w, h = cv2.boundingRect(cnt)
+    src_h, src_w = cloth_np.shape[:2]
+    src_points = np.array([[0,0],[src_w-1,0],[src_w-1,src_h-1],[0,src_h-1]], dtype=np.float32)
+    padding_x, padding_y = int(w*0.05), int(h*0.05)
+    dst_points = np.array([
+        [max(0, x - padding_x), max(0, y - padding_y)],
+        [min(person_np.shape[1] - 1, x + w + padding_x), max(0, y - padding_y)],
+        [min(person_np.shape[1] - 1, x + w + padding_x), min(person_np.shape[0] - 1, y + h + padding_y)],
+        [max(0, x - padding_x), min(person_np.shape[0] - 1, y + h + padding_y)]
+    ], dtype=np.float32)
+    M = cv2.getPerspectiveTransform(src_points, dst_points)
+    warped = cv2.warpPerspective(cloth_np, M, (person_np.shape[1], person_np.shape[0]), borderMode=cv2.BORDER_CONSTANT)
+    return warped
+
+def improved_blend_traditional(agnostic_np, warped_cloth_np, label_np):
+    target_mask = (label_np == 4).astype(np.uint8)
+    kernel = np.ones((9, 9), np.uint8)
+    target_mask = cv2.dilate(target_mask, kernel, iterations=2) * 255
+
+    gray = cv2.cvtColor(warped_cloth_np, cv2.COLOR_BGR2GRAY)
+    _, cloth_mask = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
+    combined_mask = cv2.bitwise_and(target_mask, cloth_mask)
+    combined_mask = cv2.GaussianBlur(combined_mask, (5, 5), 0)
+
+    M = cv2.moments(combined_mask)
+    center = (int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"])) if M["m00"] != 0 else (96, 128)
+
+    try:
+        output = cv2.seamlessClone(warped_cloth_np, agnostic_np, combined_mask, center, cv2.NORMAL_CLONE)
+    except:
+        mask_3d = np.stack([combined_mask / 255.0] * 3, axis=2)
+        output = warped_cloth_np * mask_3d + agnostic_np * (1 - mask_3d)
+        output = output.astype(np.uint8)
+    return output
+
+# ----------------- Main Pipeline -----------------
 def virtual_tryon(person_image, cloth_image, model_type):
+    if model_type == "Traditional":
+        person_np = np.array(person_image.resize((192, 256)))[:, :, ::-1]
+        cloth_np = np.array(cloth_image.resize((192, 256)))[:, :, ::-1]
+        segmentation = get_segmentation(person_image)
+        label_np = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
+
+        agnostic_np, _ = create_agnostic_traditional(person_np, label_np)
+        warped_cloth = improved_warp_cloth(cloth_np, person_np, label_np)
+        output_np = improved_blend_traditional(agnostic_np, warped_cloth, label_np)
+
+        return Image.fromarray(agnostic_np[:, :, ::-1]), Image.fromarray(output_np[:, :, ::-1])
+
     segmentation = get_segmentation(person_image)
     agnostic = generate_agnostic(person_image, segmentation)
     model = load_model(model_type)
     result = generate_tryon_output(person_image, agnostic, cloth_image, segmentation, model, model_type)
-    # result = generate_tryon_output(person_image, agnostic, cloth_image, segmentation, model)
     return agnostic, result
 
-# ----------------- Gradio Interface -----------------
+# ----------------- Gradio UI -----------------
 demo = gr.Interface(
     fn=virtual_tryon,
     inputs=[
         gr.Image(type="pil", label="Person Image"),
         gr.Image(type="pil", label="Cloth Image"),
-        gr.Radio(choices=["UNet", "GAN", "Diffusion"], label="Model Type", value="UNet")
+        gr.Radio(choices=["UNet", "GAN", "Diffusion", "Traditional"], label="Model Type", value="UNet")
    ],
     outputs=[
         gr.Image(type="pil", label="Agnostic (Torso Masked)"),
         gr.Image(type="pil", label="Virtual Try-On Output")
     ],
     title="👕 Virtual Try-On App",
-    description="Upload a person image and a clothing image, select a model (UNet, GAN, or Diffusion), and try it on virtually."
+    description="Upload a person image and a clothing image, select a model (UNet, GAN, Diffusion, Traditional), and try it on virtually."
 )
 
 if __name__ == "__main__":
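In the new Traditional branch, the [:, :, ::-1] slices convert between PIL's RGB channel order and the BGR order OpenCV expects, and the try/except around cv2.seamlessClone falls back to plain mask blending when Poisson cloning rejects the inputs. A minimal local smoke test of the branch, assuming app.py's functions are importable and with placeholder file names:

from PIL import Image

person = Image.open("person.jpg").convert("RGB")   # placeholder path
cloth = Image.open("cloth.jpg").convert("RGB")     # placeholder path
agnostic, result = virtual_tryon(person, cloth, "Traditional")
result.save("tryon_traditional.png")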
 