Daraphan committed on
Commit
be322b6
·
verified ·
1 Parent(s): a7c9ce1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +316 -0
app.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """VTON_GarmentMasker.ipynb
3
+ Automatically generated by Colab.
4
+ Original file is located at
5
+ https://colab.research.google.com/drive/1Y22abu3jZQ5qCKP7DTR6kYvXdQbHnJCu
6
+ Using YOLO Clothing Classification Model
7
+ """
8
+
9
+ # !pip install gradio
10
+ # !pip install ultralytics
11
+ # !pip install segment-anything
12
+
13
+ # !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
14
+
15
+ import torch
16
+ import numpy as np
17
+ import cv2
18
+ from PIL import Image
19
+ from torchvision import transforms
20
+ from ultralytics import YOLO
21
+ from segment_anything import SamPredictor, sam_model_registry
22
+ from transformers import YolosForObjectDetection, YolosImageProcessor
23
+ import gradio as gr
24
+ import os
25
+ import urllib.request
26
+
27
class GarmentMaskingPipeline:
    """Build a garment-placement mask for a person photo.

    Three models are loaded by ``load_models``:

    * a YOLO detector (``yolov8n.pt``), queried with ``classes=[0]``; the
      returned boxes are treated as people,
    * a SAM predictor, prompted with the largest of those boxes to segment
      that person, and
    * the ``valentinafeve/yolos-fashionpedia`` detector, used to decide what
      kind of garment the garment image shows.

    The garment type selects body parts through ``clothing_to_body_parts``.
    Each body part is a vertical band (``body_parts_positions``) of the
    detected person's bounding box, intersected with the SAM segmentation.
    """

    # Hugging Face model id of the garment detector.
    FASHION_MODEL_NAME = "valentinafeve/yolos-fashionpedia"

    # Garment type reported when nothing usable is detected.
    DEFAULT_CLOTHING_TYPE = 't-shirt'

    # Keyword looked for inside a detector label -> garment category used as
    # a key of ``clothing_to_body_parts``.
    FASHIONPEDIA_TO_CLOTHING = {
        'shirt': 'shirt',
        'blouse': 'shirt',
        'top': 't-shirt',
        't-shirt': 't-shirt',
        'sweater': 'shirt',
        'jacket': 'jacket',
        'cardigan': 'jacket',
        'coat': 'coat',
        'jumper': 'shirt',
        'dress': 'dress',
        'skirt': 'skirt',
        'shorts': 'shorts',
        'pants': 'pants',
        'jeans': 'pants',
        'leggings': 'pants',
        'jumpsuit': 'dress',
    }

    def __init__(self):
        """Pick the compute device, load every model and set up the lookup tables."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self.yolo_model, self.sam_predictor, self.classification_model = self.load_models()

        # Garment category -> body parts the garment covers.
        self.clothing_to_body_parts = {
            'shirt': ['torso', 'arms'],
            't-shirt': ['torso', 'upper_arms'],
            'blouse': ['torso', 'arms'],
            'dress': ['torso', 'legs'],
            'skirt': ['lower_torso', 'legs'],
            'pants': ['legs'],
            'shorts': ['upper_legs'],
            'jacket': ['torso', 'arms'],
            'coat': ['torso', 'arms'],
        }

        # Body part -> (top, bottom) as fractions of the person's height,
        # measured downwards from the top of the person box.
        self.body_parts_positions = {
            'face': (0.0, 0.2),
            'torso': (0.2, 0.5),
            'arms': (0.2, 0.5),
            'upper_arms': (0.2, 0.35),
            'lower_torso': (0.4, 0.6),
            'legs': (0.5, 0.9),
            'upper_legs': (0.5, 0.7),
            'feet': (0.9, 1.0),
        }

    def load_models(self):
        """Load the YOLO, SAM and YOLOS-Fashionpedia models.

        The image processor that belongs to the Fashionpedia model is kept on
        ``self.processor`` so ``classify_clothing`` does not reload it on
        every call.

        Returns:
            Tuple ``(yolo_model, sam_predictor, classification_model)``.
        """
        print("Loading models...")
        # Download models if they don't exist
        self.download_models()

        # Load YOLO model
        yolo_model = YOLO('yolov8n.pt')

        # Load SAM model
        sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
        sam.to(self.device)
        predictor = SamPredictor(sam)

        # Load YOLOS-Fashionpedia model for clothing classification
        print("Loading YOLOS-Fashionpedia model...")
        model_name = self.FASHION_MODEL_NAME
        self.processor = YolosImageProcessor.from_pretrained(model_name)
        classification_model = YolosForObjectDetection.from_pretrained(model_name)
        classification_model.to(self.device)
        classification_model.eval()

        print("Models loaded successfully!")
        return yolo_model, predictor, classification_model

    def download_models(self):
        """Download required model files if they don't exist.

        Each file is fetched under a temporary ``.part`` name and renamed once
        the transfer has finished, so an interrupted download is not mistaken
        for a complete checkpoint on the next start.
        """
        models = {
            "yolov8n.pt": "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt",
            "sam_vit_h_4b8939.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        }

        for filename, url in models.items():
            if os.path.exists(filename):
                print(f"{filename} already exists")
                continue

            print(f"Downloading {filename}...")
            partial = filename + ".part"
            try:
                urllib.request.urlretrieve(url, partial)
                os.replace(partial, filename)
            finally:
                # Only still present when the download or the rename failed.
                if os.path.exists(partial):
                    os.remove(partial)
            print(f"Downloaded {filename}")

        # The YOLOS-Fashionpedia model will be downloaded automatically by transformers

    def _get_processor(self):
        """Return the cached Fashionpedia image processor, loading it if needed."""
        processor = getattr(self, "processor", None)
        if processor is None:
            processor = YolosImageProcessor.from_pretrained(self.FASHION_MODEL_NAME)
            self.processor = processor
        return processor

    @classmethod
    def _category_for_label(cls, label):
        """Map a lower-cased detector label to a garment category.

        Keywords are tried longest first so that a specific keyword such as
        ``'t-shirt'`` wins over a shorter keyword it contains (``'shirt'``).

        Returns:
            The garment category, or ``None`` when no keyword occurs in
            ``label``.
        """
        for keyword in sorted(cls.FASHIONPEDIA_TO_CLOTHING, key=len, reverse=True):
            if keyword in label:
                return cls.FASHIONPEDIA_TO_CLOTHING[keyword]
        return None

    def classify_clothing(self, clothing_image):
        """Return the garment category shown in ``clothing_image``.

        Args:
            clothing_image: A PIL image, or an array accepted by
                ``Image.fromarray``.

        Returns:
            A key of ``FASHIONPEDIA_TO_CLOTHING``'s value set, taken from the
            highest-scoring detection whose label contains a known keyword;
            ``'t-shirt'`` when no detection can be mapped.
        """
        if not isinstance(clothing_image, Image.Image):
            clothing_image = Image.fromarray(clothing_image)

        # Process image with YOLOS processor
        processor = self._get_processor()
        inputs = processor(images=clothing_image, return_tensors="pt").to(self.device)

        # Run inference
        with torch.no_grad():
            outputs = self.classification_model(**inputs)

        # PIL reports (width, height); the post-processor is given (height, width).
        target_sizes = torch.tensor([clothing_image.size[::-1]]).to(self.device)
        results = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=0.1
        )[0]

        # Get class names from model config
        id2label = self.classification_model.config.id2label

        # Walk the detections from most to least confident.
        detections = sorted(
            (
                (id2label[label.item()].lower(), score.item())
                for label, score in zip(results["labels"], results["scores"])
            ),
            key=lambda detection: detection[1],
            reverse=True,
        )
        for label, _score in detections:
            category = self._category_for_label(label)
            if category is not None:
                return category

        # Nothing detected, or no detection maps to a known garment.
        return self.DEFAULT_CLOTHING_TYPE

    @staticmethod
    def _band_mask(person_mask, top, height, ratios):
        """Restrict ``person_mask`` to one horizontal band.

        Args:
            person_mask: 2-D 0/1 array.
            top: Row index of the top of the person box.
            height: Height of the person box in rows.
            ratios: ``(top_ratio, bottom_ratio)`` fractions of ``height``.

        Returns:
            ``uint8`` array shaped like ``person_mask`` that is 1 where the
            person mask is set inside the band and 0 elsewhere.
        """
        top_ratio, bottom_ratio = ratios
        band = np.zeros_like(person_mask)
        band[top + int(height * top_ratio):top + int(height * bottom_ratio), :] = 1
        return np.logical_and(band, person_mask).astype(np.uint8)

    def create_garment_mask(self, person_image, garment_image, clothing_type=None):
        """Create the mask of the region the garment would cover.

        Args:
            person_image: ``H x W x 3`` array of the person photo.
            garment_image: Garment image, only used to classify the garment
                when ``clothing_type`` is not given.
            clothing_type: Optional, already known garment category; skips
                the classification pass.

        Returns:
            ``uint8`` array of shape ``(H, W)`` holding 255 inside the garment
            region and 0 elsewhere. All zeros when no person is detected.
        """
        if clothing_type is None:
            clothing_type = self.classify_clothing(garment_image)
        parts_to_mask = self.clothing_to_body_parts.get(clothing_type, [])

        results = self.yolo_model(person_image, classes=[0])
        mask = np.zeros(person_image.shape[:2], dtype=np.uint8)

        if not results or len(results[0].boxes.data) == 0:
            return mask * 255

        # Pick the largest box. The areas are computed with torch so this
        # also works while the boxes still live on the GPU.
        person_boxes = results[0].boxes.data
        box_widths = person_boxes[:, 2] - person_boxes[:, 0]
        box_heights = person_boxes[:, 3] - person_boxes[:, 1]
        largest_person_index = int(torch.argmax(box_widths * box_heights))
        person_box = person_boxes[largest_person_index][:4].cpu().numpy().astype(int)

        self.sam_predictor.set_image(person_image)
        masks, _, _ = self.sam_predictor.predict(box=person_box, multimask_output=False)
        person_mask = masks[0].astype(np.uint8)

        # Body-part bands are fractions of the person's height, so anchor
        # them to the person box (clipped to the image), not the whole frame.
        image_height = person_mask.shape[0]
        box_top = min(max(int(person_box[1]), 0), image_height)
        box_bottom = min(max(int(person_box[3]), 0), image_height)
        box_height = max(box_bottom - box_top, 0)

        for part in parts_to_mask:
            ratios = self.body_parts_positions.get(part)
            if ratios is None:
                continue
            part_mask = self._band_mask(person_mask, box_top, box_height, ratios)
            mask = np.logical_or(mask, part_mask).astype(np.uint8)

        # Remove face and feet from the mask
        for excluded_part, default_ratios in (('face', (0.0, 0.2)), ('feet', (0.9, 1.0))):
            ratios = self.body_parts_positions.get(excluded_part, default_ratios)
            excluded_mask = self._band_mask(person_mask, box_top, box_height, ratios)
            mask = np.logical_and(mask, np.logical_not(excluded_mask)).astype(np.uint8)

        return mask * 255

    @staticmethod
    def _to_rgb_array(image, role):
        """Convert an input image to an ``H x W x 3`` array.

        Args:
            image: A PIL image (anything with a ``convert`` method) or an
                array-like image.
            role: Name used in the error message (e.g. ``"person"``).

        Raises:
            ValueError: If ``image`` is ``None``.
        """
        if image is None:
            raise ValueError(f"No {role} image provided")

        if hasattr(image, "convert"):
            # PIL image: normalises grayscale, palette and RGBA inputs.
            image = image.convert("RGB")

        array = np.array(image)
        if array.ndim == 2:
            # Grayscale array: replicate the single channel.
            array = np.stack([array, array, array], axis=2)
        elif array.ndim == 3 and array.shape[2] == 4:
            # RGBA array: drop the alpha channel.
            array = array[:, :, :3]
        return array

    @staticmethod
    def _parse_color(value):
        """Parse a colour string into an ``(r, g, b)`` tuple of 0-255 ints.

        Accepts ``#RRGGBB`` (extra trailing digits such as alpha are
        ignored), ``#RGB`` and ``rgb(...)`` / ``rgba(...)`` strings.

        Raises:
            ValueError: If the string cannot be parsed.
        """
        text = str(value).strip()

        if text.lower().startswith("rgb"):
            inner = text[text.index("(") + 1:text.rindex(")")]
            parts = inner.split(",")
            if len(parts) < 3:
                raise ValueError(f"Invalid color: {value!r}")
            return tuple(
                max(0, min(255, int(round(float(part))))) for part in parts[:3]
            )

        digits = text[1:] if text.startswith("#") else text
        if len(digits) == 3:
            digits = "".join(ch * 2 for ch in digits)
        if len(digits) < 6:
            raise ValueError(f"Invalid color: {value!r}")
        return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))

    def process(self, person_image_pil, garment_image_pil, mask_color_hex="#00FF00", opacity=0.5):
        """Process the input images and return the masked result.

        Args:
            person_image_pil: Person photo.
            garment_image_pil: Garment photo.
            mask_color_hex: Overlay colour, see ``_parse_color``.
            opacity: Weight of the overlay colour inside the mask.

        Returns:
            Tuple ``(overlay, binary_mask, message)``: the person image with
            the mask blended in, the mask repeated over three channels, and a
            text summary of the detected garment and masked body parts.

        Raises:
            ValueError: If an image is missing or the colour cannot be parsed.
        """
        person_image = self._to_rgb_array(person_image_pil, "person")
        garment_image = self._to_rgb_array(garment_image_pil, "garment")

        # Classify once and reuse the result for both the mask and the report.
        clothing_type = self.classify_clothing(garment_image)
        parts_to_mask = self.clothing_to_body_parts.get(clothing_type, [])

        # Create garment mask
        garment_mask = self.create_garment_mask(
            person_image, garment_image, clothing_type=clothing_type
        )

        color = self._parse_color(mask_color_hex)

        # Create a colored mask: the overlay colour inside the mask, black outside.
        channel_scale = np.array(color) / 255.0
        colored_mask = (garment_mask[:, :, np.newaxis] * channel_scale).astype(np.uint8)

        # Create binary mask for visualization
        binary_mask = np.stack([garment_mask, garment_mask, garment_mask], axis=2)

        # Overlay mask on original image
        mask_3d = garment_mask[:, :, np.newaxis] / 255.0
        overlay = person_image * (1 - opacity * mask_3d) + colored_mask * opacity
        overlay = overlay.astype(np.uint8)

        return overlay, binary_mask, f"Detected garment: {clothing_type}\nBody parts to mask: {', '.join(parts_to_mask)}"
244
+
245
# Shared pipeline instance; created on the first request because building it
# loads every model.
_PIPELINE = None


def _get_pipeline():
    """Return the shared GarmentMaskingPipeline, building it on first use."""
    global _PIPELINE
    if _PIPELINE is None:
        _PIPELINE = GarmentMaskingPipeline()
    return _PIPELINE


def process_images(person_img, garment_img, mask_color, opacity):
    """Gradio processing function.

    Runs the shared pipeline on the two uploaded images. The pipeline is
    built once and reused, instead of reloading all models on every click.

    Args:
        person_img: Person image from the UI.
        garment_img: Garment image from the UI.
        mask_color: Colour string from the colour picker.
        opacity: Overlay opacity from the slider.

    Returns:
        The ``(overlay, mask, message)`` tuple from the pipeline, or
        ``(None, None, error_message)`` when anything fails, so the error is
        shown in the results box instead of breaking the UI.
    """
    try:
        return _get_pipeline().process(person_img, garment_img, mask_color, opacity)
    except Exception as e:
        # UI boundary: report the failure (with traceback) rather than raise.
        import traceback
        error_msg = f"Error processing images: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, None, error_msg
256
+
257
def create_gradio_interface():
    """Build the Gradio Blocks interface and return it.

    The interface is only constructed here; launching it is left to the
    caller. The button is wired to ``process_images``, whose three return
    values fill the overlay image, the standalone mask and the text box.

    Returns:
        The configured ``gr.Blocks`` instance.
    """
    with gr.Blocks(title="VTON SAM Garment Masking Pipeline") as interface:
        gr.Markdown("""
        # Virtual Try-On Garment Masking Pipeline with SAM and YOLOS-Fashionpedia
        Upload a person image and a garment image to generate a mask for a virtual try-on application.
        The system will:
        1. Detect the person using YOLO
        2. Create a high-quality segmentation using SAM (Segment Anything Model)
        3. Classify the garment type using YOLOS-Fashionpedia
        4. Generate a mask of the area where the garment should be placed
        **Note**: This system uses state-of-the-art AI segmentation and fashion detection models for accurate results.
        """)

        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column():
                person_input = gr.Image(label="Person Image (Image A)", type="pil")
                garment_input = gr.Image(label="Garment Image (Image B)", type="pil")

                with gr.Row():
                    mask_color = gr.ColorPicker(label="Mask Color", value="#00FF00")
                    opacity = gr.Slider(label="Mask Opacity", minimum=0.1, maximum=0.9, value=0.5, step=0.1)

                submit_btn = gr.Button("Generate Mask")

            # Right column: results.
            with gr.Column():
                masked_output = gr.Image(label="Person with Masked Region")
                mask_output = gr.Image(label="Standalone Mask")
                result_text = gr.Textbox(label="Detection Results", lines=3)

        # Set up the processing flow
        submit_btn.click(
            fn=process_images,
            inputs=[person_input, garment_input, mask_color, opacity],
            outputs=[masked_output, mask_output, result_text]
        )

        gr.Markdown("""
        ## How It Works
        1. **Person Detection**: Uses YOLO to detect and locate the person in the image
        2. **Segmentation**: Uses SAM (Segment Anything Model) to create a high-quality segmentation mask
        3. **Garment Classification**: Uses YOLOS-Fashionpedia to identify the garment type with fashion-specific detection
        4. **Mask Generation**: Creates a mask based on the garment type and body part mapping
        ## Supported Garment Types
        - Shirts, Blouses, Tops, and T-shirts
        - Sweaters and Cardigans
        - Dresses and Jumpsuits
        - Skirts
        - Pants, Jeans, and Leggings
        - Shorts
        - Jackets and Coats
        """)

    return interface
312
+
313
if __name__ == "__main__":
    # Script entry point: build the UI first, then serve it with debug
    # output enabled and a public share link requested.
    interface = create_gradio_interface()
    interface.launch(share=True, debug=True)