File size: 7,319 Bytes
76e8b57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import cv2
import numpy as np
import torch
import torchvision
from torchvision.transforms import functional as F
import gradio as gr

# Outline colour for stickers: white, given as a 3-component tuple with no
# alpha value. When drawing with OpenCV on a 4-channel (colour + alpha) image,
# a 3-component colour is padded with 0 for the fourth channel, so an opaque
# outline needs an explicit alpha value appended by the caller.
BORDER_COLOR = (255, 255, 255)

# COCO category names, indexed by the integer label the detector returns
# (get_prediction looks names up as CLASSES[label]). The list has 91 entries
# (indices 0-90): index 0 is the background, and 'N/A' fills ids that have no
# COCO category, so entries must never be reordered, inserted or removed.
CLASSES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

# Names from CLASSES whose detections get_prediction keeps; every other
# detection is discarded. Only people are turned into stickers.
TARGET_CLASSES = ['person']

def get_prediction(img, threshold=0.7):
    """
    Run Mask R-CNN on one image and keep only confident person detections.

    Parameters:
    - img: H x W x 3 uint8 image array. The channels are handed to the model
      in the order given; torchvision's Mask R-CNN expects RGB, which is what
      ``gr.Image(type="numpy")`` supplies.
    - threshold: A detection is kept only if its score is strictly greater
      than this value.

    Returns:
    - masks: List of soft masks (H x W float arrays), one per kept detection.
    - boxes: List of bounding boxes, one per kept detection.
    - labels: List of class names (only names listed in TARGET_CLASSES).
    - scores: List of confidence scores, one per kept detection.
    """
    # The module-level setup moves the model to CUDA when a GPU is available,
    # so the input must live on the same device or the forward pass raises a
    # device-mismatch error.
    device = next(model.parameters()).device
    tensor = F.to_tensor(img).to(device)

    # Inference only: skip building an autograd graph (harmless if the caller
    # has already disabled gradients).
    with torch.no_grad():
        output = model([tensor])[0]

    # Map integer label ids to human-readable names via the CLASSES table.
    pred_classes = [CLASSES[label_id] for label_id in output['labels'].tolist()]
    pred_masks = output['masks'].cpu().numpy()
    pred_boxes = output['boxes'].cpu().numpy()
    pred_scores = output['scores'].cpu().numpy()

    masks, boxes, labels, scores = [], [], [], []
    for cls_name, mask, box, score in zip(pred_classes, pred_masks, pred_boxes, pred_scores):
        if score > threshold and cls_name in TARGET_CLASSES:
            # Each mask carries a leading singleton axis; [0] yields the H x W plane.
            masks.append(mask[0])
            boxes.append(box)
            labels.append(cls_name)
            scores.append(score)

    return masks, boxes, labels, scores

def create_sticker(img, mask, border_thickness=3):
    """
    Cut the masked object out of ``img`` and outline it with an opaque border.

    Pixels where ``mask`` > 0.5 stay opaque; everything else becomes fully
    transparent. An outline in ``BORDER_COLOR`` is drawn along the outer
    contour(s) of the thresholded mask.

    Parameters:
    - img: H x W x 3 uint8 image. Its channel order is preserved in the output.
    - mask: H x W mask (e.g. a soft Mask R-CNN mask); values > 0.5 count as
      foreground.
    - border_thickness: Thickness of the outline in pixels.

    Returns:
    - sticker: H x W x 4 uint8 image: the three input channels plus an alpha
      channel.
    """
    # Threshold the (possibly soft) mask into a 0/1 foreground map.
    mask_bin = (mask > 0.5).astype(np.uint8)

    # Append an alpha channel: 255 (opaque) on the foreground, 0 elsewhere.
    c0, c1, c2 = cv2.split(img)
    alpha_channel = (mask_bin * 255).astype(np.uint8)
    sticker = cv2.merge([c0, c1, c2, alpha_channel])

    # OpenCV pads a 3-component colour with 0 for the 4th channel, so drawing
    # BORDER_COLOR as-is on this 4-channel image would write alpha=0 and carve
    # a transparent ring into the subject instead of drawing a visible outline.
    border_color = (*BORDER_COLOR[:3], 255)

    # Outline only the outermost boundaries of the mask. Taking [-2] selects the
    # contour list under both the OpenCV 3 (3-tuple) and OpenCV 4 (2-tuple)
    # return conventions of findContours.
    contours = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if contours:
        cv2.drawContours(sticker, contours, -1, border_color, border_thickness)

    return sticker

def process_image(input_img):
    """
    Turn every person found in the uploaded image into a single sticker.

    Steps:
      - Normalise the input to three channels (drop alpha, expand grayscale).
      - Run person detection.
      - If at least one person is found, merge all person masks and cut them
        out as one sticker with a transparent background and a border.
      - If nobody is found, return a copy of the input with a red
        "No person detected." caption.

    Parameters:
    - input_img: Image array from Gradio (RGB/RGBA/grayscale), or None when the
      user submitted without an image.

    Returns:
    - A 4-channel sticker array, the captioned input image, or None if no image
      was provided.
    """
    # Gradio passes None when the form is submitted without an image.
    if input_img is None:
        return None

    # The detector needs exactly three colour channels.
    if input_img.ndim == 2 or input_img.shape[2] == 1:
        input_img = cv2.cvtColor(input_img, cv2.COLOR_GRAY2RGB)
    elif input_img.shape[2] == 4:
        input_img = cv2.cvtColor(input_img, cv2.COLOR_RGBA2RGB)

    # Disable gradient tracking for inference.
    with torch.no_grad():
        masks, _boxes, _labels, _scores = get_prediction(input_img)

    if not masks:
        print("No person detected")
        output_img = input_img.copy()
        # Gradio displays the array as RGB, so (255, 0, 0) is red. The old BGR
        # value (0, 0, 255) rendered blue. No desktop popup (cv2.imshow/waitKey)
        # here: it would block this web request until a key press and fails on
        # headless servers.
        cv2.putText(output_img, "No person detected.", (30, 30), cv2.FONT_HERSHEY_SIMPLEX,
                    1, (255, 0, 0), 2)
        return output_img

    # Union of all person masks via the pixel-wise maximum. For a single person
    # this is just that person's mask.
    combined_mask = np.max(np.stack(masks), axis=0)
    return create_sticker(input_img, combined_mask)

# Build Mask R-CNN (ResNet-50 FPN backbone) with pre-trained weights, which
# torchvision fetches on first use, and switch it to evaluation mode so it
# returns detections rather than training losses. eval() returns the module
# itself, so it can be chained onto the constructor.
print("Loading Mask R-CNN model...")
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).eval()

# Stay on the CPU unless CUDA is present, in which case move the weights to
# the GPU for faster inference.
if not torch.cuda.is_available():
    print("Using CPU for inference")
else:
    model.cuda()
    print("Using GPU for inference")

# Text shown beneath the app title in the web UI.
description = """
Upload an image containing one or more persons. The model will detect the person(s) and convert them into a sticker with a transparent background and a white border.
"""

# Single-image-in / single-image-out web UI around process_image, with the
# flagging button disabled. fn, inputs and outputs are passed positionally in
# the order gr.Interface declares them.
iface = gr.Interface(
    process_image,
    gr.Image(type="numpy", label="Upload Image"),
    gr.Image(type="numpy", label="Sticker Output"),
    title="Person Sticker Maker",
    description=description,
    allow_flagging="never",
)

# Start the web server.
iface.launch()