[email protected]
commited on
Commit
·
76e8b57
1
Parent(s):
ed117f5
init
Browse files- app.py +183 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torchvision
|
5 |
+
from torchvision.transforms import functional as F
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
# Set up color for sticker border (white)
|
9 |
+
BORDER_COLOR = (255, 255, 255)
|
10 |
+
|
11 |
+
# COCO classes (used by Mask R-CNN). Note: some entries are marked as 'N/A'
|
12 |
+
# as the COCO dataset does not include a class for those indices.
|
13 |
+
CLASSES = [
|
14 |
+
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
15 |
+
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
|
16 |
+
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
|
17 |
+
'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
|
18 |
+
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
19 |
+
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
|
20 |
+
'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
|
21 |
+
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
|
22 |
+
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
|
23 |
+
'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
|
24 |
+
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
|
25 |
+
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
|
26 |
+
]
|
27 |
+
|
28 |
+
# We are focusing on the person class for sticker creation.
|
29 |
+
TARGET_CLASSES = ['person']
|
30 |
+
|
31 |
+
def get_prediction(img, threshold=0.7):
|
32 |
+
"""
|
33 |
+
Get model predictions filtered for person detection.
|
34 |
+
|
35 |
+
Parameters:
|
36 |
+
- img: Input image in BGR format.
|
37 |
+
- threshold: Score threshold to filter weak predictions.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
- masks: List of masks for detected persons.
|
41 |
+
- boxes: Bounding boxes for detected persons.
|
42 |
+
- labels: Class labels for each detected object (only persons in our case).
|
43 |
+
- scores: Confidence scores for each detection.
|
44 |
+
"""
|
45 |
+
# Convert image to tensor expected by the model
|
46 |
+
transform = F.to_tensor(img)
|
47 |
+
|
48 |
+
# Get predictions from the model
|
49 |
+
prediction = model([transform])
|
50 |
+
|
51 |
+
# Initialize lists to store filtered predictions
|
52 |
+
masks = []
|
53 |
+
boxes = []
|
54 |
+
labels = []
|
55 |
+
scores = []
|
56 |
+
|
57 |
+
# Convert model output labels to human-readable classes
|
58 |
+
pred_classes = [CLASSES[i] for i in prediction[0]['labels']]
|
59 |
+
# Get masks, boxes, and scores from the prediction
|
60 |
+
pred_masks = prediction[0]['masks'].detach().cpu().numpy()
|
61 |
+
pred_boxes = prediction[0]['boxes'].detach().cpu().numpy()
|
62 |
+
pred_scores = prediction[0]['scores'].detach().cpu().numpy()
|
63 |
+
|
64 |
+
# Filter detections based on threshold and target class (person)
|
65 |
+
for i, score in enumerate(pred_scores):
|
66 |
+
if score > threshold and pred_classes[i] in TARGET_CLASSES:
|
67 |
+
masks.append(pred_masks[i][0])
|
68 |
+
boxes.append(pred_boxes[i])
|
69 |
+
labels.append(pred_classes[i])
|
70 |
+
scores.append(score)
|
71 |
+
|
72 |
+
return masks, boxes, labels, scores
|
73 |
+
|
74 |
+
def create_sticker(img, mask, border_thickness=3):
|
75 |
+
"""
|
76 |
+
Create a sticker image by applying the mask to the image.
|
77 |
+
The area outside the mask is made transparent and a border is drawn.
|
78 |
+
|
79 |
+
Parameters:
|
80 |
+
- img: Input image in BGR format.
|
81 |
+
- mask: Mask corresponding to the object (person).
|
82 |
+
- border_thickness: Thickness of the white border around the mask.
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
- sticker: The resulting sticker image with transparency (BGRA format).
|
86 |
+
"""
|
87 |
+
# Create a binary mask where mask values > 0.5 become 1 and others 0.
|
88 |
+
mask_bin = (mask > 0.5).astype(np.uint8)
|
89 |
+
|
90 |
+
# Convert the image to BGRA (BGR + Alpha channel)
|
91 |
+
b_channel, g_channel, r_channel = cv2.split(img)
|
92 |
+
alpha_channel = (mask_bin * 255).astype(np.uint8)
|
93 |
+
sticker = cv2.merge([b_channel, g_channel, r_channel, alpha_channel])
|
94 |
+
|
95 |
+
# Find contours on the binary mask to create a border
|
96 |
+
contours, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
97 |
+
for cnt in contours:
|
98 |
+
# Draw the border on the sticker image
|
99 |
+
cv2.drawContours(sticker, [cnt], -1, BORDER_COLOR, border_thickness)
|
100 |
+
|
101 |
+
return sticker
|
102 |
+
|
103 |
+
def process_image(input_img):
|
104 |
+
"""
|
105 |
+
Process the uploaded image:
|
106 |
+
- Run detection for persons.
|
107 |
+
- If found:
|
108 |
+
* If one person is detected, create a sticker using that mask.
|
109 |
+
* If multiple persons are detected, combine their masks and then create a sticker.
|
110 |
+
- If no person is detected, return the original image with a notification.
|
111 |
+
|
112 |
+
Parameters:
|
113 |
+
- input_img: The image provided by the user (numpy array).
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
- Processed image: Either a sticker with transparency or the original image with a notification.
|
117 |
+
"""
|
118 |
+
# Ensure image is in BGR format. If image has 4 channels (BGRA), convert to BGR.
|
119 |
+
if input_img.shape[2] == 4:
|
120 |
+
input_img = cv2.cvtColor(input_img, cv2.COLOR_BGRA2BGR)
|
121 |
+
|
122 |
+
# Disable gradient calculation for inference
|
123 |
+
with torch.no_grad():
|
124 |
+
masks, boxes, labels, scores = get_prediction(input_img)
|
125 |
+
|
126 |
+
# Check if any persons were detected
|
127 |
+
if len(masks) == 0:
|
128 |
+
# If no person is detected, add a notification to the image.
|
129 |
+
print("No person detected")
|
130 |
+
output_img = input_img.copy()
|
131 |
+
cv2.putText(output_img, "No person detected.", (30, 30), cv2.FONT_HERSHEY_SIMPLEX,
|
132 |
+
1, (0, 0, 255), 2)
|
133 |
+
|
134 |
+
# Optionally, display a separate notification window (this is optional and might not work in some environments)
|
135 |
+
img = np.zeros((200, 400, 3), dtype=np.uint8)
|
136 |
+
cv2.putText(img, "No person detected.", (30, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
|
137 |
+
cv2.imshow("Test", img)
|
138 |
+
cv2.waitKey(0)
|
139 |
+
cv2.destroyAllWindows()
|
140 |
+
|
141 |
+
return output_img
|
142 |
+
elif len(masks) == 1:
|
143 |
+
# If only one person is detected, create a sticker from the single mask.
|
144 |
+
sticker = create_sticker(input_img, masks[0])
|
145 |
+
else:
|
146 |
+
# If multiple persons are detected, combine their masks via pixel-wise maximum.
|
147 |
+
combined_mask = np.zeros_like(masks[0])
|
148 |
+
for m in masks:
|
149 |
+
combined_mask = np.maximum(combined_mask, m)
|
150 |
+
sticker = create_sticker(input_img, combined_mask)
|
151 |
+
|
152 |
+
return sticker
|
153 |
+
|
154 |
+
# Load the pre-trained Mask R-CNN model.
|
155 |
+
# The model is downloaded and set to evaluation mode.
|
156 |
+
print("Loading Mask R-CNN model...")
|
157 |
+
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
|
158 |
+
model.eval()
|
159 |
+
|
160 |
+
# Check if GPU is available and move the model to GPU for faster inference.
|
161 |
+
if torch.cuda.is_available():
|
162 |
+
model.cuda()
|
163 |
+
print("Using GPU for inference")
|
164 |
+
else:
|
165 |
+
print("Using CPU for inference")
|
166 |
+
|
167 |
+
# Create a Gradio interface with an image upload widget.
|
168 |
+
description = """
|
169 |
+
Upload an image containing one or more persons. The model will detect the person(s) and convert them into a sticker with a transparent background and a white border.
|
170 |
+
"""
|
171 |
+
|
172 |
+
# Define the Gradio interface.
|
173 |
+
iface = gr.Interface(
|
174 |
+
fn=process_image, # Function to process the image.
|
175 |
+
inputs=gr.Image(type="numpy", label="Upload Image"),
|
176 |
+
outputs=gr.Image(type="numpy", label="Sticker Output"),
|
177 |
+
title="Person Sticker Maker",
|
178 |
+
description=description,
|
179 |
+
allow_flagging="never"
|
180 |
+
)
|
181 |
+
|
182 |
+
# Launch the Gradio app.
|
183 |
+
iface.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
opencv-python
|
2 |
+
numpy
|
3 |
+
torch
|
4 |
+
torchvision
|
5 |
+
gradio
|