File size: 7,319 Bytes
76e8b57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
import cv2
import numpy as np
import torch
import torchvision
from torchvision.transforms import functional as F
import gradio as gr
# Color of the sticker border: white. Every channel is 255, so the value is the
# same regardless of whether the image is in RGB or BGR channel order.
BORDER_COLOR = (255, 255, 255)
# COCO category names, indexed by the integer label id that Mask R-CNN emits
# (get_prediction looks names up as CLASSES[label_id]). Index 0 is the
# background class. Entries marked 'N/A' are placeholders for ids that the
# COCO dataset does not assign to any class; they must stay in place so that
# the remaining names keep their correct index.
CLASSES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
# Class names that get_prediction keeps; every other detection is discarded.
# Only persons are turned into stickers.
TARGET_CLASSES = ['person']
def get_prediction(img, threshold=0.7):
    """
    Run the module-level Mask R-CNN `model` on an image and keep only the
    confident detections whose class is listed in TARGET_CLASSES.

    Parameters:
    - img: Input image as an H x W x 3 uint8 array. The channels are passed to
      the model in the order given; torchvision detection models expect RGB.
    - threshold: Score threshold. Detections scoring at or below it are dropped.

    Returns:
    - masks: List of 2-D mask arrays (first channel of each predicted mask).
    - boxes: List of bounding boxes for the kept detections.
    - labels: List of class names for the kept detections.
    - scores: List of confidence scores for the kept detections.
    """
    # Convert the image to a CHW float tensor and place it on the same device
    # as the model. Without this, a model moved to the GPU receives a CPU
    # tensor and inference fails with a device-mismatch error.
    device = next(model.parameters()).device
    image_tensor = F.to_tensor(img).to(device)

    # Inference only: no autograd graph is needed, even if the caller forgot
    # to disable gradients.
    with torch.no_grad():
        prediction = model([image_tensor])[0]

    # Bring everything back to the CPU. Label ids become plain Python ints so
    # they can index CLASSES directly.
    pred_classes = [CLASSES[idx] for idx in prediction['labels'].cpu().tolist()]
    pred_masks = prediction['masks'].detach().cpu().numpy()
    pred_boxes = prediction['boxes'].detach().cpu().numpy()
    pred_scores = prediction['scores'].detach().cpu().numpy()

    # Indices of detections that are both confident enough and of a wanted class.
    keep = [
        i for i, score in enumerate(pred_scores)
        if score > threshold and pred_classes[i] in TARGET_CLASSES
    ]

    # Each predicted mask has a leading channel axis; [0] drops it to get H x W.
    masks = [pred_masks[i][0] for i in keep]
    boxes = [pred_boxes[i] for i in keep]
    labels = [pred_classes[i] for i in keep]
    scores = [pred_scores[i] for i in keep]
    return masks, boxes, labels, scores
def create_sticker(img, mask, border_thickness=3):
    """
    Create a sticker image by applying the mask to the image.
    Pixels outside the mask become fully transparent and an opaque border in
    BORDER_COLOR is drawn along the outline of the mask.

    Parameters:
    - img: Input 3-channel uint8 image. The channel order is preserved as-is.
    - mask: 2-D mask for the object, same height/width as img. Values above
      0.5 are treated as foreground.
    - border_thickness: Thickness in pixels of the border around the mask.
      A value of 0 or less disables the border.

    Returns:
    - sticker: 4-channel uint8 image (the three input channels plus alpha).
    """
    # Binarize the mask: values > 0.5 become 1, everything else 0.
    mask_bin = (mask > 0.5).astype(np.uint8)

    # Append the alpha channel: 255 (opaque) inside the mask, 0 outside.
    first, second, third = cv2.split(img)
    alpha_channel = (mask_bin * 255).astype(np.uint8)
    sticker = cv2.merge([first, second, third, alpha_channel])

    # OpenCV interprets a non-positive thickness as "fill the contour", which
    # would paint the whole object in the border color, so skip drawing then.
    if border_thickness <= 0:
        return sticker

    # Outline of the foreground region(s).
    contours, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # The sticker has four channels, so the color needs an explicit alpha.
    # A 3-element color leaves alpha at 0 and the border would be invisible.
    opaque_border = (BORDER_COLOR[0], BORDER_COLOR[1], BORDER_COLOR[2], 255)
    cv2.drawContours(sticker, contours, -1, opaque_border, border_thickness)
    return sticker
def process_image(input_img):
    """
    Process the uploaded image:
    - Run detection for persons.
    - If one or more persons are found, merge their masks (pixel-wise maximum)
      and create a single sticker from the merged mask.
    - If no person is detected, return a copy of the image with a notification
      text drawn on it.

    Parameters:
    - input_img: The image provided by the user (numpy array), or None when
      nothing was uploaded. Gradio supplies numpy images in RGB channel order.

    Returns:
    - Processed image: Either a 4-channel sticker with transparency, or the
      3-channel input image with a notification. None if input_img is None.
    """
    # Nothing uploaded: there is nothing to process.
    if input_img is None:
        return None

    # Normalize to three channels. These conversions only add or drop
    # channels, so they are valid for both RGB and BGR ordering.
    if input_img.ndim == 2:
        input_img = cv2.cvtColor(input_img, cv2.COLOR_GRAY2BGR)
    elif input_img.shape[2] == 4:
        input_img = cv2.cvtColor(input_img, cv2.COLOR_BGRA2BGR)

    # Disable gradient calculation for inference.
    with torch.no_grad():
        masks, *_ = get_prediction(input_img)

    if not masks:
        # No person found: annotate a copy of the image and return it.
        # No GUI window is opened here: this runs inside a web request, where
        # cv2.imshow/cv2.waitKey would block forever or fail without a display.
        print("No person detected")
        output_img = input_img.copy()
        # (255, 0, 0) is red in the RGB order used by Gradio images.
        cv2.putText(output_img, "No person detected.", (30, 30), cv2.FONT_HERSHEY_SIMPLEX,
                    1, (255, 0, 0), 2)
        return output_img

    # Merge all person masks via pixel-wise maximum. With a single detection
    # the loop body never runs and that mask is used directly.
    combined_mask = masks[0]
    for extra_mask in masks[1:]:
        combined_mask = np.maximum(combined_mask, extra_mask)

    return create_sticker(input_img, combined_mask)
# Build the pre-trained Mask R-CNN (the weights are downloaded) and switch it
# to evaluation mode, since it is only used for inference.
print("Loading Mask R-CNN model...")
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()

# Run on the GPU when one is available for faster inference; otherwise the
# model stays on the CPU.
if not torch.cuda.is_available():
    print("Using CPU for inference")
else:
    model.cuda()
    print("Using GPU for inference")

# Text shown under the title of the web page.
description = (
    "\n"
    "Upload an image containing one or more persons. The model will detect the person(s) and convert them into a sticker with a transparent background and a white border."
    "\n"
)

# Input and output widgets: both exchange images as numpy arrays.
upload_widget = gr.Image(type="numpy", label="Upload Image")
sticker_widget = gr.Image(type="numpy", label="Sticker Output")

# Wire process_image to the widgets; flagging is disabled.
iface = gr.Interface(
    fn=process_image,
    inputs=upload_widget,
    outputs=sticker_widget,
    title="Person Sticker Maker",
    description=description,
    allow_flagging="never",
)

# Start the Gradio app.
iface.launch()
|