sengourav012 commited on
Commit
ab2903b
·
verified ·
1 Parent(s): 783d5ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -57
app.py CHANGED
@@ -5,13 +5,12 @@ import torch.nn as nn
5
  import numpy as np
6
  from torchvision import transforms
7
  import cv2
 
8
  from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
9
- from improved_viton import ImprovedUNetGenerator
10
 
11
- # ----------------- Device -----------------
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
- # ----------------- Load Human Parser Model -----------------
15
  processor = AutoImageProcessor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
16
  parser_model = SegformerForSemanticSegmentation.from_pretrained(
17
  "matei-dorian/segformer-b5-finetuned-human-parsing"
@@ -61,11 +60,16 @@ class UNetGenerator(nn.Module):
61
  u4 = self.up4(torch.cat([u3, d1], dim=1))
62
  return u4
63
 
 
 
 
 
 
 
64
  # ----------------- Image Transforms -----------------
65
  img_transform = transforms.Compose([
66
  transforms.Resize((256, 192)),
67
- transforms.ToTensor(),
68
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
69
  ])
70
 
71
  # ----------------- Helper Functions -----------------
@@ -81,79 +85,39 @@ def generate_agnostic(image: Image.Image, segmentation):
81
  img_np = np.array(image.resize((192, 256)))
82
  agnostic_np = img_np.copy()
83
  segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
84
- clothing_labels = [4]
85
- for label in clothing_labels:
86
- agnostic_np[segmentation_resized == label] = [128, 128, 128]
87
  return Image.fromarray(agnostic_np)
88
 
89
- def load_model(model_type):
90
- if model_type == "UNet":
91
- model = UNetGenerator().to(device)
92
- checkpoint = torch.load("viton_unet_full_checkpoint.pth", map_location=device)
93
- state_dict = checkpoint.get("model_G_state_dict") or checkpoint.get("model_state_dict")
94
- elif model_type == "GAN":
95
- model = ImprovedUNetGenerator(in_channels=6, out_channels=3).to(device)
96
- checkpoint = torch.load("viton_gan_full_checkpoint.pth", map_location=device)
97
- state_dict = checkpoint.get("model_G_state_dict") or checkpoint.get("model_state_dict")
98
- elif model_type == "Diffusion":
99
- model = ImprovedUNetGenerator(in_channels=6, out_channels=3).to(device)
100
- checkpoint = torch.load("viton_diffusion_full_checkpoint.pth", map_location=device)
101
- state_dict = checkpoint.get("model_G_state_dict") or checkpoint.get("model_state_dict")
102
- else:
103
- raise ValueError("Invalid model type")
104
-
105
- if state_dict is None:
106
- raise KeyError(f"No valid state_dict found for model type {model_type}")
107
-
108
- model.load_state_dict(state_dict)
109
- model.eval()
110
- return model
111
-
112
- def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model):
113
  agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
114
  cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
115
  input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
116
 
117
  with torch.no_grad():
118
- output = model(input_tensor)
119
-
120
- output_img = output[0].cpu().permute(1, 2, 0).numpy()
121
- output_img = (output_img + 1) / 2
122
- output_img = np.clip(output_img, 0, 1)
123
-
124
- person_np = np.array(person_img.resize((192, 256))).astype(np.float32) / 255.0
125
- segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
126
- blend_mask = (segmentation_resized == 0).astype(np.float32)
127
- blend_mask = np.expand_dims(blend_mask, axis=2)
128
 
129
- final_output = blend_mask * person_np + (1 - blend_mask) * output_img
130
- final_output = (final_output * 255).astype(np.uint8)
131
-
132
- return Image.fromarray(final_output)
133
-
134
-
135
- # ----------------- Inference Pipeline -----------------
136
- def virtual_tryon(person_image, cloth_image, model_type):
137
  segmentation = get_segmentation(person_image)
138
  agnostic = generate_agnostic(person_image, segmentation)
139
- model = load_model(model_type)
140
- result = generate_tryon_output(person_image, agnostic, cloth_image, segmentation, model)
141
  return agnostic, result
142
 
143
- # ----------------- Gradio Interface -----------------
144
  demo = gr.Interface(
145
  fn=virtual_tryon,
146
  inputs=[
147
  gr.Image(type="pil", label="Person Image"),
148
- gr.Image(type="pil", label="Cloth Image"),
149
- gr.Radio(choices=["UNet", "GAN", "Diffusion"], label="Model Type", value="UNet")
150
  ],
151
  outputs=[
152
  gr.Image(type="pil", label="Agnostic (Torso Masked)"),
153
  gr.Image(type="pil", label="Virtual Try-On Output")
154
  ],
155
- title="👕 Virtual Try-On App",
156
- description="Upload a person image and a clothing image, select a model (UNet, GAN, or Diffusion), and try it on virtually."
157
  )
158
 
159
  if __name__ == "__main__":
 
5
  import numpy as np
6
  from torchvision import transforms
7
  import cv2
8
+
9
  from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
 
10
 
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
+ # ----------------- Load Human Parser Model from Hugging Face Hub -----------------
14
  processor = AutoImageProcessor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
15
  parser_model = SegformerForSemanticSegmentation.from_pretrained(
16
  "matei-dorian/segformer-b5-finetuned-human-parsing"
 
60
  u4 = self.up4(torch.cat([u3, d1], dim=1))
61
  return u4
62
 
63
+ # ----------------- Load UNet Try-On Model -----------------
64
+ tryon_model = UNetGenerator().to(device)
65
+ checkpoint = torch.load("viton_unet_full_checkpoint.pth", map_location=device)
66
+ tryon_model.load_state_dict(checkpoint['model_state_dict'])
67
+ tryon_model.eval()
68
+
69
  # ----------------- Image Transforms -----------------
70
  img_transform = transforms.Compose([
71
  transforms.Resize((256, 192)),
72
+ transforms.ToTensor()
 
73
  ])
74
 
75
  # ----------------- Helper Functions -----------------
 
85
  img_np = np.array(image.resize((192, 256)))
86
  agnostic_np = img_np.copy()
87
  segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
88
+ agnostic_np[segmentation_resized == 4] = [128, 128, 128] # Mask upper clothes
 
 
89
  return Image.fromarray(agnostic_np)
90
 
91
+ def generate_tryon_output(agnostic_img, cloth_img):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
93
  cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
94
  input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
95
 
96
  with torch.no_grad():
97
+ output = tryon_model(input_tensor)
98
+ output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
99
+ output_img = (output_img * 255).astype(np.uint8)
100
+ return Image.fromarray(output_img)
 
 
 
 
 
 
101
 
102
+ # ----------------- Gradio Interface -----------------
103
+ def virtual_tryon(person_image, cloth_image):
 
 
 
 
 
 
104
  segmentation = get_segmentation(person_image)
105
  agnostic = generate_agnostic(person_image, segmentation)
106
+ result = generate_tryon_output(agnostic, cloth_image)
 
107
  return agnostic, result
108
 
 
109
  demo = gr.Interface(
110
  fn=virtual_tryon,
111
  inputs=[
112
  gr.Image(type="pil", label="Person Image"),
113
+ gr.Image(type="pil", label="Cloth Image")
 
114
  ],
115
  outputs=[
116
  gr.Image(type="pil", label="Agnostic (Torso Masked)"),
117
  gr.Image(type="pil", label="Virtual Try-On Output")
118
  ],
119
+ title="👕 Virtual Try-On (UNet + Segformer)",
120
+ description="Upload a person image and a cloth image to try on the cloth virtually."
121
  )
122
 
123
  if __name__ == "__main__":