sengourav012 commited on
Commit
6184732
·
verified ·
1 Parent(s): 9982d28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -64,7 +64,8 @@ class UNetGenerator(nn.Module):
64
  # ----------------- Image Transforms -----------------
65
  img_transform = transforms.Compose([
66
  transforms.Resize((256, 192)),
67
- transforms.ToTensor()
 
68
  ])
69
 
70
  # ----------------- Helper Functions -----------------
@@ -108,24 +109,34 @@ def load_model(model_type):
108
  model.eval()
109
  return model
110
 
111
- def generate_tryon_output(agnostic_img, cloth_img, model):
112
  agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
113
  cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
114
  input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
115
 
116
  with torch.no_grad():
117
  output = model(input_tensor)
118
- output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
 
119
  output_img = (output_img + 1) / 2
120
- output_img = (output_img * 255).astype(np.uint8)
121
- return Image.fromarray(output_img)
 
 
 
 
 
 
 
 
 
122
 
123
  # ----------------- Inference Pipeline -----------------
124
  def virtual_tryon(person_image, cloth_image, model_type):
125
  segmentation = get_segmentation(person_image)
126
  agnostic = generate_agnostic(person_image, segmentation)
127
  model = load_model(model_type)
128
- result = generate_tryon_output(agnostic, cloth_image, model)
129
  return agnostic, result
130
 
131
  # ----------------- Gradio Interface -----------------
 
64
  # ----------------- Image Transforms -----------------
65
  img_transform = transforms.Compose([
66
  transforms.Resize((256, 192)),
67
+ transforms.ToTensor(),
68
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
69
  ])
70
 
71
  # ----------------- Helper Functions -----------------
 
109
  model.eval()
110
  return model
111
 
112
+ def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model):
113
  agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
114
  cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
115
  input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
116
 
117
  with torch.no_grad():
118
  output = model(input_tensor)
119
+
120
+ output_img = output[0].cpu().permute(1, 2, 0).numpy()
121
  output_img = (output_img + 1) / 2
122
+ output_img = np.clip(output_img, 0, 1)
123
+
124
+ person_np = np.array(person_img.resize((192, 256))).astype(np.float32) / 255.0
125
+ segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
126
+ blend_mask = (segmentation_resized == 0).astype(np.float32)
127
+ blend_mask = np.expand_dims(blend_mask, axis=2)
128
+
129
+ final_output = blend_mask * person_np + (1 - blend_mask) * output_img
130
+ final_output = (final_output * 255).astype(np.uint8)
131
+
132
+ return Image.fromarray(final_output)
133
 
134
  # ----------------- Inference Pipeline -----------------
135
  def virtual_tryon(person_image, cloth_image, model_type):
136
  segmentation = get_segmentation(person_image)
137
  agnostic = generate_agnostic(person_image, segmentation)
138
  model = load_model(model_type)
139
+ result = generate_tryon_output(person_image, agnostic, cloth_image, segmentation, model)
140
  return agnostic, result
141
 
142
  # ----------------- Gradio Interface -----------------