[email protected] committed on
Commit 76e8b57 · 1 Parent(s): ed117f5
Files changed (2)
  1. app.py +183 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,183 @@
+ import cv2
+ import numpy as np
+ import torch
+ import torchvision
+ from torchvision.transforms import functional as F
+ import gradio as gr
+
+ # Color for the sticker border: white, with full alpha so the outline stays
+ # opaque when drawn on the 4-channel (RGBA) sticker image.
+ BORDER_COLOR = (255, 255, 255, 255)
+
+ # COCO classes (used by Mask R-CNN). Note: some entries are marked as 'N/A'
+ # because the COCO dataset does not assign a class to those indices.
+ CLASSES = [
+     '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+     'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
+     'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+     'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
+     'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+     'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+     'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+     'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
+     'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
+     'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+     'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
+     'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
+ ]
+
+ # We focus on the 'person' class for sticker creation.
+ TARGET_CLASSES = ['person']
+
+ def get_prediction(img, threshold=0.7):
+     """
+     Get model predictions filtered for person detections.
+
+     Parameters:
+     - img: Input image as an RGB numpy array (as supplied by Gradio).
+     - threshold: Score threshold used to filter out weak predictions.
+
+     Returns:
+     - masks: List of masks for detected persons.
+     - boxes: Bounding boxes for detected persons.
+     - labels: Class labels for each detected object (only persons in our case).
+     - scores: Confidence scores for each detection.
+     """
+     # Convert the image to the tensor format expected by the model and move it
+     # to the same device as the model (CPU or GPU).
+     transform = F.to_tensor(img).to(next(model.parameters()).device)
+
+     # Get predictions from the model
+     prediction = model([transform])
+
+     # Initialize lists to store the filtered predictions
+     masks = []
+     boxes = []
+     labels = []
+     scores = []
+
+     # Convert model output labels to human-readable class names
+     pred_classes = [CLASSES[i] for i in prediction[0]['labels']]
+     # Get masks, boxes, and scores from the prediction
+     pred_masks = prediction[0]['masks'].detach().cpu().numpy()
+     pred_boxes = prediction[0]['boxes'].detach().cpu().numpy()
+     pred_scores = prediction[0]['scores'].detach().cpu().numpy()
+
+     # Keep detections above the score threshold that belong to the target class (person)
+     for i, score in enumerate(pred_scores):
+         if score > threshold and pred_classes[i] in TARGET_CLASSES:
+             masks.append(pred_masks[i][0])
+             boxes.append(pred_boxes[i])
+             labels.append(pred_classes[i])
+             scores.append(score)
+
+     return masks, boxes, labels, scores
+
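+ # Illustrative note (hypothetical numbers): for a 720x1280 RGB photo in which two
+ # people clear the score threshold, get_prediction would return two 720x1280
+ # float masks with values in [0, 1], two [x1, y1, x2, y2] boxes, the labels
+ # ['person', 'person'], and the two confidence scores.
+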
+ def create_sticker(img, mask, border_thickness=3):
+     """
+     Create a sticker image by applying the mask to the image.
+     The area outside the mask is made transparent and a white border is drawn
+     along the mask outline.
+
+     Parameters:
+     - img: Input image as an RGB numpy array.
+     - mask: Soft mask (values in [0, 1]) for the object (person).
+     - border_thickness: Thickness of the white border around the mask.
+
+     Returns:
+     - sticker: The resulting sticker image with transparency (RGBA format).
+     """
+     # Create a binary mask: values > 0.5 become 1, everything else 0.
+     mask_bin = (mask > 0.5).astype(np.uint8)
+
+     # Convert the image to RGBA (RGB + alpha channel), using the mask as alpha.
+     r_channel, g_channel, b_channel = cv2.split(img)
+     alpha_channel = (mask_bin * 255).astype(np.uint8)
+     sticker = cv2.merge([r_channel, g_channel, b_channel, alpha_channel])
+
+     # Find contours of the binary mask to draw a border along the outline
+     contours, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+     for cnt in contours:
+         # Draw the border on the sticker image
+         cv2.drawContours(sticker, [cnt], -1, BORDER_COLOR, border_thickness)
+
+     return sticker
+
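+ # Usage sketch (illustrative file name): the returned RGBA array can be written
+ # to disk with its transparency preserved, e.g.
+ #   cv2.imwrite("sticker.png", cv2.cvtColor(sticker, cv2.COLOR_RGBA2BGRA))
+ # since OpenCV expects BGRA channel order when saving 4-channel PNGs.
+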
+ def process_image(input_img):
+     """
+     Process the uploaded image:
+     - Run person detection.
+     - If one person is detected, create a sticker from that mask.
+     - If multiple persons are detected, combine their masks and create a single sticker.
+     - If no person is detected, return the original image with a notification drawn on it.
+
+     Parameters:
+     - input_img: The image provided by the user (RGB numpy array from Gradio).
+
+     Returns:
+     - Processed image: either a sticker with transparency or the original image with a notification.
+     """
+     # If the image has an alpha channel (RGBA), drop it so we work with 3 channels.
+     if input_img.ndim == 3 and input_img.shape[2] == 4:
+         input_img = cv2.cvtColor(input_img, cv2.COLOR_RGBA2RGB)
+
+     # Disable gradient calculation for inference
+     with torch.no_grad():
+         masks, boxes, labels, scores = get_prediction(input_img)
+
+     # Check whether any persons were detected
+     if len(masks) == 0:
+         # No person detected: draw a notification on a copy of the input image.
+         print("No person detected")
+         output_img = input_img.copy()
+         cv2.putText(output_img, "No person detected.", (30, 30), cv2.FONT_HERSHEY_SIMPLEX,
+                     1, (0, 0, 255), 2)
+         return output_img
+     elif len(masks) == 1:
+         # Exactly one person detected: create a sticker from the single mask.
+         sticker = create_sticker(input_img, masks[0])
+     else:
+         # Multiple persons detected: combine their masks via a pixel-wise maximum.
+         combined_mask = np.zeros_like(masks[0])
+         for m in masks:
+             combined_mask = np.maximum(combined_mask, m)
+         sticker = create_sticker(input_img, combined_mask)
+
+     return sticker
+
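+ # Note: the no-person branch returns a 3-channel RGB image, while the sticker
+ # branches return a 4-channel RGBA array; Gradio's numpy image output handles both.
+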
+ # Load the pre-trained Mask R-CNN model.
+ # The model is downloaded and set to evaluation mode.
+ print("Loading Mask R-CNN model...")
+ model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
+ model.eval()
+
+ # Check if GPU is available and move the model to GPU for faster inference.
+ if torch.cuda.is_available():
+     model.cuda()
+     print("Using GPU for inference")
+ else:
+     print("Using CPU for inference")
+
+ # Create a Gradio interface with an image upload widget.
+ description = """
+ Upload an image containing one or more persons. The model will detect the person(s) and convert them into a sticker with a transparent background and a white border.
+ """
+
+ # Define the Gradio interface.
+ iface = gr.Interface(
+     fn=process_image,  # Function to process the image.
+     inputs=gr.Image(type="numpy", label="Upload Image"),
+     outputs=gr.Image(type="numpy", label="Sticker Output"),
+     title="Person Sticker Maker",
+     description=description,
+     allow_flagging="never"
+ )
+
+ # Launch the Gradio app.
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ opencv-python
+ numpy
+ torch
+ torchvision
+ gradio