# Mask-RCNN / app.py (commit 76e8b57, "init")
import cv2
import numpy as np
import torch
import torchvision
from torchvision.transforms import functional as F
import gradio as gr
# Colour of the outline drawn around each sticker. It is pure white, so the
# channel order of the image (BGR or RGB) makes no difference.
BORDER_COLOR = (255, 255, 255)

# Category names for torchvision's COCO-trained Mask R-CNN, indexed by the
# integer label id the model predicts. Id 0 is the background class, and the
# ids that COCO leaves without a category are filled with 'N/A'. The list
# therefore has 91 entries and a label id can index it directly.
CLASSES = [
    # ids 0-6
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    # ids 7-13
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    # ids 14-21
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    # ids 22-30
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    # ids 31-37
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    # ids 38-43
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    # ids 44-51
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    # ids 52-59
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    # ids 60-67
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    # ids 68-77
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    # ids 78-84
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    # ids 85-90
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',
]

# Category names that are turned into stickers; every other detection is ignored.
TARGET_CLASSES = ['person']
def get_prediction(img, threshold=0.7, target_classes=None):
    """
    Run Mask R-CNN on an image and keep confident detections of the target classes.

    Parameters:
    - img: H x W x 3 image as a numpy array. It is converted with
      F.to_tensor and its channels are NOT reordered. torchvision's detection
      models expect RGB input, which is what gr.Image(type="numpy") supplies.
    - threshold: A detection is kept only if its score is strictly greater
      than this value.
    - target_classes: Class names to keep. None means TARGET_CLASSES.

    Returns:
    - masks: List of H x W soft masks (values in [0, 1]), one per kept detection.
    - boxes: List of [x1, y1, x2, y2] bounding boxes.
    - labels: List of class names (strings from CLASSES).
    - scores: List of confidence scores.
    """
    if target_classes is None:
        target_classes = TARGET_CLASSES

    # Put the input on the same device as the model weights. Without this,
    # inference fails as soon as the model has been moved to the GPU.
    device = next(model.parameters()).device
    tensor = F.to_tensor(img).to(device)

    # Inference only: no autograd graph is needed. This is harmless if the
    # caller has already disabled gradients.
    with torch.no_grad():
        output = model([tensor])[0]

    # Bring everything back to host memory. Label ids become plain ints so
    # they index CLASSES directly.
    pred_label_ids = output['labels'].tolist()
    pred_masks = output['masks'].cpu().numpy()
    pred_boxes = output['boxes'].cpu().numpy()
    pred_scores = output['scores'].cpu().numpy()

    masks, boxes, labels, scores = [], [], [], []
    for label_id, mask, box, score in zip(pred_label_ids, pred_masks, pred_boxes, pred_scores):
        name = CLASSES[label_id]
        if score > threshold and name in target_classes:
            # Each predicted mask has a leading channel axis of size 1; drop it.
            masks.append(mask[0])
            boxes.append(box)
            labels.append(name)
            scores.append(score)
    return masks, boxes, labels, scores
def create_sticker(img, mask, border_thickness=3, border_color=None):
    """
    Cut the masked object out of an image as a sticker with an outline.

    Pixels where mask > 0.5 become fully opaque. Everything else becomes fully
    transparent. An opaque outline is drawn along the outer contour(s) of the
    mask.

    Parameters:
    - img: 3-channel image (numpy array). The channel order is passed through
      unchanged.
    - mask: 2-D mask with the same height and width as img (soft values are
      thresholded at 0.5).
    - border_thickness: Outline thickness in pixels. A value <= 0 disables the
      outline.
    - border_color: 3-component outline colour. None means BORDER_COLOR.

    Returns:
    - sticker: 4-channel uint8 image (the three input channels plus alpha).
    """
    if border_color is None:
        border_color = BORDER_COLOR
    # On a 4-channel image, OpenCV drawing uses all four colour components.
    # A 3-tuple leaves the 4th (alpha) at 0, which would make the outline
    # transparent and cut into the sticker's edge. Force full opacity.
    draw_color = tuple(border_color[:3]) + (255,)

    mask_bin = (mask > 0.5).astype(np.uint8)

    # Append the mask as an alpha channel to the unchanged colour channels.
    c0, c1, c2 = cv2.split(img)
    alpha = (mask_bin * 255).astype(np.uint8)
    sticker = cv2.merge([c0, c1, c2, alpha])

    if border_thickness > 0:
        # Outer contours only, so holes inside the mask get no outline.
        contours, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(sticker, contours, -1, draw_color, border_thickness)
    return sticker
def process_image(input_img):
    """
    Turn the person(s) in an uploaded image into a single transparent sticker.

    Steps:
    - Normalise the input to exactly 3 channels (drop alpha, expand grayscale).
    - Detect persons with Mask R-CNN.
    - Union all person masks (pixel-wise maximum) and cut out one sticker.
    - If nobody is found, return a copy of the image with a notice drawn on it.

    Parameters:
    - input_img: Image from the Gradio widget as a numpy array, or None when
      the user submits without uploading anything.

    Returns:
    - The 4-channel sticker (colour + alpha) if at least one person is found.
    - Otherwise, the 3-channel image annotated with a notice.
    - None when there is no input.
    """
    if input_img is None:
        return None

    # Detection and sticker creation both expect exactly three colour channels.
    if input_img.ndim == 2 or input_img.shape[2] == 1:
        input_img = cv2.cvtColor(input_img, cv2.COLOR_GRAY2BGR)
    elif input_img.shape[2] == 4:
        input_img = cv2.cvtColor(input_img, cv2.COLOR_BGRA2BGR)

    with torch.no_grad():
        masks, _boxes, _labels, _scores = get_prediction(input_img)

    if not masks:
        print("No person detected")
        output_img = input_img.copy()
        # gr.Image(type="numpy") delivers RGB arrays, so (255, 0, 0) is red.
        cv2.putText(output_img, "No person detected.", (30, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
        # Deliberately no cv2.imshow()/waitKey() here. This runs inside the web
        # server, where a blocking GUI window would hang the request, or fail
        # outright on a headless machine.
        return output_img

    # Union of all person masks. With a single detection this is just that mask.
    combined_mask = np.maximum.reduce(masks)
    return create_sticker(input_img, combined_mask)
# Load the COCO-pretrained Mask R-CNN once, when the module is loaded, and put
# it in evaluation mode so it returns detections rather than training losses.
print("Loading Mask R-CNN model...")
try:
    # torchvision >= 0.13 uses the `weights=` API; `pretrained=True` is deprecated.
    _weights = torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=_weights)
except AttributeError:
    # Older torchvision has no weights enums; fall back to the legacy flag.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()

# Move the model to the GPU when one is available, for faster inference.
if torch.cuda.is_available():
    model.cuda()
    print("Using GPU for inference")
else:
    print("Using CPU for inference")

# Text shown above the Gradio widgets.
description = """
Upload an image containing one or more persons. The model will detect the person(s) and convert them into a sticker with a transparent background and a white border.
"""

# One image in, one image out; flagging is disabled.
iface = gr.Interface(
    fn=process_image,  # Function to process the image.
    inputs=gr.Image(type="numpy", label="Upload Image"),
    outputs=gr.Image(type="numpy", label="Sticker Output"),
    title="Person Sticker Maker",
    description=description,
    allow_flagging="never",
)

# Launch the server only when run as a script. Importing this module (e.g. by
# Gradio's reload mode or by tests) then does not start a second server.
if __name__ == "__main__":
    iface.launch()