sengourav012 committed on
Commit
e208ed3
·
verified ·
1 Parent(s): e43c7ba

Also included GAN and Diffusion models

Browse files
Files changed (1) hide show
  1. app.py +32 -17
app.py CHANGED
@@ -5,12 +5,12 @@ import torch.nn as nn
5
  import numpy as np
6
  from torchvision import transforms
7
  import cv2
8
-
9
  from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
10
 
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
- # ----------------- Load Human Parser Model from Hugging Face Hub -----------------
14
  processor = AutoImageProcessor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
15
  parser_model = SegformerForSemanticSegmentation.from_pretrained(
16
  "matei-dorian/segformer-b5-finetuned-human-parsing"
@@ -60,12 +60,6 @@ class UNetGenerator(nn.Module):
60
  u4 = self.up4(torch.cat([u3, d1], dim=1))
61
  return u4
62
 
63
- # ----------------- Load UNet Try-On Model -----------------
64
- tryon_model = UNetGenerator().to(device)
65
- checkpoint = torch.load("viton_unet_full_checkpoint.pth", map_location=device)
66
- tryon_model.load_state_dict(checkpoint['model_state_dict'])
67
- tryon_model.eval()
68
-
69
  # ----------------- Image Transforms -----------------
70
  img_transform = transforms.Compose([
71
  transforms.Resize((256, 192)),
@@ -85,39 +79,60 @@ def generate_agnostic(image: Image.Image, segmentation):
85
  img_np = np.array(image.resize((192, 256)))
86
  agnostic_np = img_np.copy()
87
  segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
88
- agnostic_np[segmentation_resized == 4] = [128, 128, 128] # Mask upper clothes
 
 
89
  return Image.fromarray(agnostic_np)
90
 
91
- def generate_tryon_output(agnostic_img, cloth_img):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
93
  cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
94
  input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
95
 
96
  with torch.no_grad():
97
- output = tryon_model(input_tensor)
98
  output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
 
99
  output_img = (output_img * 255).astype(np.uint8)
100
  return Image.fromarray(output_img)
101
 
102
- # ----------------- Gradio Interface -----------------
103
- def virtual_tryon(person_image, cloth_image):
104
  segmentation = get_segmentation(person_image)
105
  agnostic = generate_agnostic(person_image, segmentation)
106
- result = generate_tryon_output(agnostic, cloth_image)
 
107
  return agnostic, result
108
 
 
109
  demo = gr.Interface(
110
  fn=virtual_tryon,
111
  inputs=[
112
  gr.Image(type="pil", label="Person Image"),
113
- gr.Image(type="pil", label="Cloth Image")
 
114
  ],
115
  outputs=[
116
  gr.Image(type="pil", label="Agnostic (Torso Masked)"),
117
  gr.Image(type="pil", label="Virtual Try-On Output")
118
  ],
119
- title="👕 Virtual Try-On (UNet + Segformer)",
120
- description="Upload a person image and a cloth image to try on the cloth virtually."
121
  )
122
 
123
  if __name__ == "__main__":
 
5
  import numpy as np
6
  from torchvision import transforms
7
  import cv2
 
8
  from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
9
 
10
+ # ----------------- Device -----------------
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
+ # ----------------- Load Human Parser Model -----------------
14
  processor = AutoImageProcessor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
15
  parser_model = SegformerForSemanticSegmentation.from_pretrained(
16
  "matei-dorian/segformer-b5-finetuned-human-parsing"
 
60
  u4 = self.up4(torch.cat([u3, d1], dim=1))
61
  return u4
62
 
 
 
 
 
 
 
63
  # ----------------- Image Transforms -----------------
64
  img_transform = transforms.Compose([
65
  transforms.Resize((256, 192)),
 
79
  img_np = np.array(image.resize((192, 256)))
80
  agnostic_np = img_np.copy()
81
  segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
82
+ clothing_labels = [4, 5, 6, 7, 8, 16]
83
+ for label in clothing_labels:
84
+ agnostic_np[segmentation_resized == label] = [128, 128, 128]
85
  return Image.fromarray(agnostic_np)
86
 
87
+ def load_model(model_type):
88
+ model = UNetGenerator().to(device)
89
+ if model_type == "UNet":
90
+ checkpoint = torch.load("viton_unet_full_checkpoint.pth", map_location=device)
91
+ elif model_type == "GAN":
92
+ checkpoint = torch.load("viton_gan_generator_checkpoint.pth", map_location=device)
93
+ elif model_type == "Diffusion":
94
+ checkpoint = torch.load("viton_diffusion_generator_checkpoint.pth", map_location=device)
95
+ else:
96
+ raise ValueError("Invalid model type")
97
+
98
+ model.load_state_dict(checkpoint['model_state_dict'])
99
+ model.eval()
100
+ return model
101
+
102
+ def generate_tryon_output(agnostic_img, cloth_img, model):
103
  agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
104
  cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
105
  input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
106
 
107
  with torch.no_grad():
108
+ output = model(input_tensor)
109
  output_img = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
110
+ output_img = (output_img + 1) / 2
111
  output_img = (output_img * 255).astype(np.uint8)
112
  return Image.fromarray(output_img)
113
 
114
+ # ----------------- Inference Pipeline -----------------
115
+ def virtual_tryon(person_image, cloth_image, model_type):
116
  segmentation = get_segmentation(person_image)
117
  agnostic = generate_agnostic(person_image, segmentation)
118
+ model = load_model(model_type)
119
+ result = generate_tryon_output(agnostic, cloth_image, model)
120
  return agnostic, result
121
 
122
+ # ----------------- Gradio Interface -----------------
123
  demo = gr.Interface(
124
  fn=virtual_tryon,
125
  inputs=[
126
  gr.Image(type="pil", label="Person Image"),
127
+ gr.Image(type="pil", label="Cloth Image"),
128
+ gr.Radio(choices=["UNet", "GAN", "Diffusion"], label="Model Type", value="UNet")
129
  ],
130
  outputs=[
131
  gr.Image(type="pil", label="Agnostic (Torso Masked)"),
132
  gr.Image(type="pil", label="Virtual Try-On Output")
133
  ],
134
+ title="👕 Virtual Try-On App",
135
+ description="Upload a person image and a clothing image, select a model (UNet, GAN, Diffusion), and try it on virtually."
136
  )
137
 
138
  if __name__ == "__main__":