|
import torch |
|
import torchvision |
|
from PIL import Image |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import gradio as gr |
|
|
|
|
|
# Mask R-CNN (ResNet-50 + FPN backbone) with pretrained weights, built once at
# import time and reused by every call to segment_objects().
# NOTE(review): recent torchvision releases deprecate ``pretrained=`` in favour
# of ``weights=`` — migrate once the minimum supported version is confirmed.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

# Inference only: switch layers such as batch-norm/dropout to eval behaviour.
model.eval()
|
|
|
|
|
# Human-readable class names, indexed directly by the integer label the model
# emits. Index 0 is the background class. The 'N/A' entries are placeholders
# that keep every name at the position matching its label id (presumably ids
# the COCO release leaves unused — confirm against the torchvision docs).
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella',
    'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
    'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
    'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork',
    'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
    'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop',
    'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
    'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush',
]
|
|
|
|
|
def segment_objects(image, threshold=0.5):
    """Run Mask R-CNN on ``image`` and save an annotated visualisation.

    Every detection whose score is at least ``threshold`` gets a randomly
    coloured, half-transparent mask blended into the picture, a bounding box
    in the same colour and a ``"<label>: <score>"`` caption.

    Args:
        image: PIL image to analyse. ``None`` is accepted (e.g. the caller
            has no image to offer) and short-circuits to ``None``.
        threshold: Minimum detection score for an instance to be drawn.

    Returns:
        Path of the PNG file that was written, or ``None`` when no image
        was supplied.
    """
    if image is None:
        return None

    # The blending below indexes three colour channels, so normalise
    # grayscale / RGBA / palette input to RGB before doing anything else.
    image = image.convert("RGB")

    to_tensor = torchvision.transforms.ToTensor()
    batch = to_tensor(image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        output = model(batch)[0]

    # np.array() already returns a fresh array, so it is safe to paint on.
    image_np = np.array(image)

    # Create the figure only after inference succeeded, and always release it:
    # pyplot keeps figures alive in a global registry until they are closed.
    fig, ax = plt.subplots(1, figsize=(10, 10))
    try:
        detections = zip(
            output['masks'], output['boxes'], output['labels'], output['scores']
        )
        for mask_t, box, label_id, score in detections:
            score = float(score)
            if score < threshold:
                continue

            # Soft mask -> boolean mask.
            mask = mask_t[0].cpu().numpy() > 0.5

            # One random colour per instance, shared by mask and box.
            color = np.random.rand(3)
            rgb = (color * 255).astype(int)

            # 50/50 blend of picture and colour, restricted to masked pixels.
            image_np[mask] = (0.5 * image_np[mask] + 0.5 * rgb).astype(np.uint8)

            x1, y1, x2, y2 = box.tolist()
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       fill=False, color=color, linewidth=2))
            label = COCO_INSTANCE_CATEGORY_NAMES[int(label_id)]
            ax.text(x1, y1, f"{label}: {score:.2f}",
                    bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)

        # Show the picture once, after all masks have been blended in.
        ax.imshow(image_np)
        ax.axis('off')

        # NOTE(review): fixed file name — concurrent calls overwrite each
        # other's output; switch to a temp file if that matters here.
        output_path = "output_maskrcnn_with_masks.png"
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

    return output_path
|
|
|
|
|
# Gradio front end: an image upload plus a confidence slider feed
# segment_objects(); the PNG path it returns is displayed as the result.
interface = gr.Interface(
    fn=segment_objects,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence Threshold"),
    ],
    outputs=gr.Image(type="filepath", label="Segmented Output"),
    title="Mask R-CNN Instance Segmentation",
    description="Upload an image to detect and segment objects using a pretrained Mask R-CNN model (TorchVision).",
)
|
|
|
# Start the web UI only when the file is executed as a script, not on import.
if __name__ == "__main__":
    interface.launch(debug=True)
|
|