sengourav012 commited on
Commit
31d417a
·
verified ·
1 Parent(s): 3e7b4c7

make app agnostic to all the models

Browse files
Files changed (1) hide show
  1. app.py +54 -12
app.py CHANGED
@@ -68,17 +68,7 @@ class UNetGenerator(nn.Module):
68
  # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
69
  # ])
70
  #new changes
71
- if model_type == "UNet":
72
- img_transform = transforms.Compose([
73
- transforms.Resize((256, 192)),
74
- transforms.ToTensor()
75
- ])
76
- else:
77
- img_transform = transforms.Compose([
78
- transforms.Resize((256, 192)),
79
- transforms.ToTensor(),
80
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
81
- ])
82
  #end new changes
83
 
84
  # ----------------- Helper Functions -----------------
@@ -144,7 +134,60 @@ def load_model(model_type):
144
 
145
  # return Image.fromarray(final_output)
146
  #new changes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model, model_type):
 
 
 
 
 
 
 
 
 
 
 
 
148
  agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
149
  cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
150
  input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
@@ -168,7 +211,6 @@ def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, mod
168
 
169
  final_output = blend_mask * person_np + (1 - blend_mask) * output_img
170
  final_output = (final_output * 255).astype(np.uint8)
171
-
172
  return Image.fromarray(final_output)
173
  #new changes end
174
 
 
68
  # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
69
  # ])
70
  #new changes
71
+
 
 
 
 
 
 
 
 
 
 
72
  #end new changes
73
 
74
  # ----------------- Helper Functions -----------------
 
134
 
135
  # return Image.fromarray(final_output)
136
  #new changes
137
+ # def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model, model_type):
138
+
139
+ # if model_type == "UNet":
140
+ # img_transform = transforms.Compose([
141
+ # transforms.Resize((256, 192)),
142
+ # transforms.ToTensor()
143
+ # ])
144
+
145
+ # else:
146
+
147
+ # img_transform = transforms.Compose([
148
+ # transforms.Resize((256, 192)),
149
+ # transforms.ToTensor(),
150
+ # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
151
+ # ])
152
+
153
+ # agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
154
+ # cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
155
+ # input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
156
+
157
+ # with torch.no_grad():
158
+ # output = model(input_tensor)
159
+
160
+ # if model_type == "UNet":
161
+ # output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
162
+ # output_img = (output_img * 255).astype(np.uint8)
163
+ # return Image.fromarray(output_img)
164
+ # else:
165
+ # output_img = output[0].cpu().permute(1, 2, 0).numpy()
166
+ # output_img = (output_img + 1) / 2
167
+ # output_img = np.clip(output_img, 0, 1)
168
+
169
+ # person_np = np.array(person_img.resize((192, 256))).astype(np.float32) / 255.0
170
+ # segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
171
+ # blend_mask = (segmentation_resized == 0).astype(np.float32)
172
+ # blend_mask = np.expand_dims(blend_mask, axis=2)
173
+
174
+ # final_output = blend_mask * person_np + (1 - blend_mask) * output_img
175
+ # final_output = (final_output * 255).astype(np.uint8)
176
+
177
+ # return Image.fromarray(final_output)
178
  def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model, model_type):
179
+ if model_type == "UNet":
180
+ img_transform = transforms.Compose([
181
+ transforms.Resize((256, 192)),
182
+ transforms.ToTensor()
183
+ ])
184
+ else:
185
+ img_transform = transforms.Compose([
186
+ transforms.Resize((256, 192)),
187
+ transforms.ToTensor(),
188
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
189
+ ])
190
+
191
  agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
192
  cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
193
  input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
 
211
 
212
  final_output = blend_mask * person_np + (1 - blend_mask) * output_img
213
  final_output = (final_output * 255).astype(np.uint8)
 
214
  return Image.fromarray(final_output)
215
  #new changes end
216