# -*- coding: utf-8 -*-
"""VTON_GarmentMasker.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1Y22abu3jZQ5qCKP7DTR6kYvXdQbHnJCu
Using YOLO Clothing Classification Model
"""
# !pip install gradio
# !pip install ultralytics
# !pip install segment-anything
# !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
import torch
import numpy as np
from PIL import Image
from ultralytics import YOLO
from segment_anything import SamPredictor, sam_model_registry
from transformers import YolosForObjectDetection, YolosImageProcessor
import gradio as gr
import os
import urllib.request
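# Note: the SAM ViT-H checkpoint fetched below is large (roughly 2.4 GB), so
# the first launch on a fresh machine can take a while.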
class GarmentMaskingPipeline:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self.yolo_model, self.sam_predictor, self.classification_model = self.load_models()
        # Map each garment category to the body parts it should cover
        self.clothing_to_body_parts = {
            'shirt': ['torso', 'arms'],
            't-shirt': ['torso', 'upper_arms'],
            'blouse': ['torso', 'arms'],
            'dress': ['torso', 'legs'],
            'skirt': ['lower_torso', 'legs'],
            'pants': ['legs'],
            'shorts': ['upper_legs'],
            'jacket': ['torso', 'arms'],
            'coat': ['torso', 'arms']
        }
        # Approximate vertical extent of each body part as (top, bottom)
        # fractions of the image height
        self.body_parts_positions = {
            'face': (0.0, 0.2),
            'torso': (0.2, 0.5),
            'arms': (0.2, 0.5),
            'upper_arms': (0.2, 0.35),
            'lower_torso': (0.4, 0.6),
            'legs': (0.5, 0.9),
            'upper_legs': (0.5, 0.7),
            'feet': (0.9, 1.0)
        }
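        # Worked example of how these fractions are applied: on a 1000 px tall
        # image, 'torso' (0.2, 0.5) selects rows int(1000*0.2)=200 through
        # int(1000*0.5)=500. This coarse prior assumes an upright person who
        # roughly fills the frame.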
    def load_models(self):
        print("Loading models...")
        # Download model weights if they don't exist locally
        self.download_models()
        # Load YOLO model for person detection
        yolo_model = YOLO('yolov8n.pt')
        # Load SAM model for segmentation
        sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
        sam.to(self.device)
        predictor = SamPredictor(sam)
        # Load YOLOS-Fashionpedia model for clothing classification
        print("Loading YOLOS-Fashionpedia model...")
        model_name = "valentinafeve/yolos-fashionpedia"
        # Keep the processor on the instance so classify_clothing does not
        # re-instantiate it on every call
        self.processor = YolosImageProcessor.from_pretrained(model_name)
        classification_model = YolosForObjectDetection.from_pretrained(model_name)
        classification_model.to(self.device)
        classification_model.eval()
        print("Models loaded successfully!")
        return yolo_model, predictor, classification_model
    def download_models(self):
        """Download required model files if they don't exist"""
        models = {
            "yolov8n.pt": "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt",
            "sam_vit_h_4b8939.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
        }
        for filename, url in models.items():
            if not os.path.exists(filename):
                print(f"Downloading {filename}...")
                urllib.request.urlretrieve(url, filename)
                print(f"Downloaded {filename}")
            else:
                print(f"{filename} already exists")
        # The YOLOS-Fashionpedia model will be downloaded automatically by transformers
    def classify_clothing(self, clothing_image):
        if not isinstance(clothing_image, Image.Image):
            clothing_image = Image.fromarray(clothing_image)
        # Preprocess the image with the YOLOS processor loaded in load_models
        inputs = self.processor(images=clothing_image, return_tensors="pt").to(self.device)
        # Run inference
        with torch.no_grad():
            outputs = self.classification_model(**inputs)
        # Post-process detections back to image coordinates
        target_sizes = torch.tensor([clothing_image.size[::-1]]).to(self.device)
        results = self.processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=0.1
        )[0]
        # Extract detected labels and confidence scores
        labels = results["labels"]
        scores = results["scores"]
        # Get class names from the model config
        id2label = self.classification_model.config.id2label
        # Map Fashionpedia labels onto our garment categories
        fashionpedia_to_clothing = {
            'shirt': 'shirt',
            'blouse': 'shirt',
            'top': 't-shirt',
            't-shirt': 't-shirt',
            'sweater': 'shirt',
            'jacket': 'jacket',
            'cardigan': 'jacket',
            'coat': 'coat',
            'jumper': 'shirt',
            'dress': 'dress',
            'skirt': 'skirt',
            'shorts': 'shorts',
            'pants': 'pants',
            'jeans': 'pants',
            'leggings': 'pants',
            'jumpsuit': 'dress'
        }
        # Walk the detections from highest to lowest confidence and return
        # the first one that maps onto a known garment category
        if len(labels) > 0:
            detections = [(id2label[label.item()].lower(), score.item())
                          for label, score in zip(labels, scores)]
            detections.sort(key=lambda x: x[1], reverse=True)
            for label, score in detections:
                # Look for clothing keywords in the label
                for keyword, category in fashionpedia_to_clothing.items():
                    if keyword in label:
                        return category
            # No detection mapped onto a known category; fall back to t-shirt
            return 't-shirt'
        # Default to t-shirt if nothing was detected
        return 't-shirt'
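    # A quick way to sanity-check the classifier in isolation (the image
    # path here is hypothetical):
    #
    #   pipeline = GarmentMaskingPipeline()
    #   print(pipeline.classify_clothing(Image.open("garment.jpg")))  # e.g. 't-shirt'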
    def create_garment_mask(self, person_image, garment_image, clothing_type=None):
        # Classify the garment unless the caller already did (avoids running
        # the classifier twice per request)
        if clothing_type is None:
            clothing_type = self.classify_clothing(garment_image)
        parts_to_mask = self.clothing_to_body_parts.get(clothing_type, [])
        # Detect people (COCO class 0) and keep only the largest detection
        results = self.yolo_model(person_image, classes=[0])
        mask = np.zeros(person_image.shape[:2], dtype=np.uint8)
        if results and len(results[0].boxes.data) > 0:
            person_boxes = results[0].boxes.data
            person_areas = [(box[2] - box[0]) * (box[3] - box[1]) for box in person_boxes]
            largest_person_index = np.argmax(person_areas)
            person_box = person_boxes[largest_person_index][:4].cpu().numpy().astype(int)
            # Segment the person with SAM, prompted by the YOLO box
            self.sam_predictor.set_image(person_image)
            masks, _, _ = self.sam_predictor.predict(box=person_box, multimask_output=False)
            person_mask = masks[0].astype(np.uint8)
            # Intersect the person mask with a horizontal band for each body
            # part; the band ratios are fractions of the full image height
            h, w = person_mask.shape
            for part in parts_to_mask:
                if part in self.body_parts_positions:
                    top_ratio, bottom_ratio = self.body_parts_positions[part]
                    top_px, bottom_px = int(h * top_ratio), int(h * bottom_ratio)
                    part_mask = np.zeros_like(person_mask)
                    part_mask[top_px:bottom_px, :] = 1
                    part_mask = np.logical_and(part_mask, person_mask).astype(np.uint8)
                    mask = np.logical_or(mask, part_mask).astype(np.uint8)
            # Remove the face from the mask
            face_top_px, face_bottom_px = int(h * 0.0), int(h * 0.2)
            face_mask = np.zeros_like(person_mask)
            face_mask[face_top_px:face_bottom_px, :] = 1
            face_mask = np.logical_and(face_mask, person_mask).astype(np.uint8)
            mask = np.logical_and(mask, np.logical_not(face_mask)).astype(np.uint8)
            # Remove the feet from the mask
            feet_top_px, feet_bottom_px = int(h * 0.9), int(h * 1.0)
            feet_mask = np.zeros_like(person_mask)
            feet_mask[feet_top_px:feet_bottom_px, :] = 1
            feet_mask = np.logical_and(feet_mask, person_mask).astype(np.uint8)
            mask = np.logical_and(mask, np.logical_not(feet_mask)).astype(np.uint8)
        # Scale the binary mask to 0/255 for display
        return mask * 255
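    # Minimal, model-free sketch of the band intersection used above (all
    # shapes and values are made up for illustration):
    #
    #   person_mask = np.ones((100, 50), dtype=np.uint8)  # stand-in for a SAM mask
    #   band = np.zeros_like(person_mask)
    #   band[20:50, :] = 1                                # 'torso' band at (0.2, 0.5)
    #   torso = np.logical_and(band, person_mask).astype(np.uint8)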
    def process(self, person_image_pil, garment_image_pil, mask_color_hex="#00FF00", opacity=0.5):
        """Process the input images and return the masked result"""
        # Convert PIL to numpy arrays
        person_image = np.array(person_image_pil)
        garment_image = np.array(garment_image_pil)
        # Drop the alpha channel if present (RGBA -> RGB)
        if person_image.ndim == 3 and person_image.shape[2] == 4:
            person_image = person_image[:, :, :3]
        if garment_image.ndim == 3 and garment_image.shape[2] == 4:
            garment_image = garment_image[:, :, :3]
        # Classify the garment once and reuse the result below
        clothing_type = self.classify_clothing(garment_image)
        parts_to_mask = self.clothing_to_body_parts.get(clothing_type, [])
        # Create the garment mask (values 0 or 255)
        garment_mask = self.create_garment_mask(person_image, garment_image, clothing_type)
        # Convert the hex color to RGB
        r = int(mask_color_hex[1:3], 16)
        g = int(mask_color_hex[3:5], 16)
        b = int(mask_color_hex[5:7], 16)
        color = (r, g, b)
        # Create a colored version of the mask
        colored_mask = np.zeros_like(person_image)
        for i in range(3):
            colored_mask[:, :, i] = garment_mask * (color[i] / 255.0)
        # Three-channel binary mask for visualization
        binary_mask = np.stack([garment_mask, garment_mask, garment_mask], axis=2)
        # Alpha-blend the colored mask onto the original image
        mask_3d = garment_mask[:, :, np.newaxis] / 255.0
        overlay = person_image * (1 - opacity * mask_3d) + colored_mask * opacity
        overlay = overlay.astype(np.uint8)
        return overlay, binary_mask, f"Detected garment: {clothing_type}\nBody parts to mask: {', '.join(parts_to_mask)}"
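# Programmatic usage outside the Gradio UI (file names are hypothetical):
#
#   pipeline = GarmentMaskingPipeline()
#   overlay, mask, info = pipeline.process(
#       Image.open("person.jpg"), Image.open("garment.jpg"), "#FF0000", 0.6)
#   Image.fromarray(overlay).save("overlay.png")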
# Cache the pipeline across requests so the models are loaded only once,
# not on every button click
_pipeline = None

def process_images(person_img, garment_img, mask_color, opacity):
    """Gradio processing function"""
    global _pipeline
    try:
        if _pipeline is None:
            _pipeline = GarmentMaskingPipeline()
        return _pipeline.process(person_img, garment_img, mask_color, opacity)
    except Exception as e:
        import traceback
        error_msg = f"Error processing images: {e}\n{traceback.format_exc()}"
        print(error_msg)
        return None, None, error_msg
def create_gradio_interface():
    """Create and launch the Gradio interface"""
    with gr.Blocks(title="VTON SAM Garment Masking Pipeline") as interface:
        gr.Markdown("""
        # Virtual Try-On Garment Masking Pipeline with SAM and YOLOS-Fashionpedia
        Upload a person image and a garment image to generate a mask for a virtual try-on application.
        The system will:
        1. Detect the person using YOLO
        2. Create a high-quality segmentation using SAM (Segment Anything Model)
        3. Classify the garment type using YOLOS-Fashionpedia
        4. Generate a mask of the area where the garment should be placed
        **Note**: This system uses state-of-the-art AI segmentation and fashion detection models for accurate results.
        """)
        with gr.Row():
            with gr.Column():
                person_input = gr.Image(label="Person Image (Image A)", type="pil")
                garment_input = gr.Image(label="Garment Image (Image B)", type="pil")
                with gr.Row():
                    mask_color = gr.ColorPicker(label="Mask Color", value="#00FF00")
                    opacity = gr.Slider(label="Mask Opacity", minimum=0.1, maximum=0.9, value=0.5, step=0.1)
                submit_btn = gr.Button("Generate Mask")
            with gr.Column():
                masked_output = gr.Image(label="Person with Masked Region")
                mask_output = gr.Image(label="Standalone Mask")
                result_text = gr.Textbox(label="Detection Results", lines=3)
        # Set up the processing flow
        submit_btn.click(
            fn=process_images,
            inputs=[person_input, garment_input, mask_color, opacity],
            outputs=[masked_output, mask_output, result_text]
        )
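        # The outputs list above must match the order of the tuple returned by
        # process_images: (overlay image, binary mask, status text)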
        gr.Markdown("""
        ## How It Works
        1. **Person Detection**: Uses YOLO to detect and locate the person in the image
        2. **Segmentation**: Uses SAM (Segment Anything Model) to create a high-quality segmentation mask
        3. **Garment Classification**: Uses YOLOS-Fashionpedia to identify the garment type with fashion-specific detection
        4. **Mask Generation**: Creates a mask based on the garment type and body part mapping
        ## Supported Garment Types
        - Shirts, Blouses, Tops, and T-shirts
        - Sweaters and Cardigans
        - Dresses and Jumpsuits
        - Skirts
        - Pants, Jeans, and Leggings
        - Shorts
        - Jackets and Coats
        """)
    return interface

if __name__ == "__main__":
    # Create and launch the Gradio interface
    interface = create_gradio_interface()
    interface.launch(debug=True, share=True)