"""Gradio demo: instance segmentation with a pretrained TorchVision Mask R-CNN.

Running this module as a script launches a small web UI. The user uploads an
image and picks a confidence threshold; the app overlays the predicted masks,
bounding boxes and class labels and shows the rendered result.
"""

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image

# Figures are only ever saved to disk, never shown. A non-interactive backend
# avoids GUI-toolkit problems when the handler is invoked outside the main
# thread (as web frameworks typically do).
matplotlib.use("Agg")

# Load pretrained Mask R-CNN model.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in
# favour of the `weights=` argument; kept as-is for compatibility with older
# installs.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()

# COCO labels, indexed by the integer class id the model predicts.
# 'N/A' entries are placeholders for ids that have no category.
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
    'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator',
    'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush',
]


# Detection and segmentation function
def segment_objects(image: Image.Image, threshold: float = 0.5) -> str:
    """Run Mask R-CNN on *image* and render the detections to a PNG file.

    Every detection whose score is at least *threshold* is drawn as a
    semi-transparent, randomly coloured mask plus a bounding box and a
    ``"<label>: <score>"`` caption.

    Args:
        image: The uploaded PIL image. Any mode is accepted; it is converted
            to RGB before inference.
        threshold: Minimum detection score for a detection to be drawn.

    Returns:
        Path of the PNG file the annotated image was written to.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image before running segmentation.")

    # RGBA / grayscale / palette uploads would otherwise yield a tensor or
    # array that is not 3-channel and break inference or mask blending.
    rgb_image = image.convert("RGB")

    img_tensor = torchvision.transforms.ToTensor()(rgb_image).unsqueeze(0)
    with torch.no_grad():
        output = model(img_tensor)[0]

    # Filter once by score instead of testing each detection inside the loop.
    keep = output['scores'] >= threshold
    # output['masks'] has shape [N, 1, H, W]; drop the channel axis and
    # binarize the soft masks at 0.5.
    masks = output['masks'][keep][:, 0].cpu().numpy() > 0.5
    boxes = output['boxes'][keep].cpu().numpy()
    label_ids = output['labels'][keep].tolist()
    scores = output['scores'][keep].tolist()

    image_np = np.array(rgb_image)

    fig, ax = plt.subplots(1, figsize=(10, 10))
    try:
        for mask, box, label_id, score in zip(masks, boxes, label_ids, scores):
            # Random color for each mask
            color = np.random.rand(3)
            tint = (color * 255).astype(np.uint8)

            # Blend the mask onto the image (50/50) only where the mask is set.
            image_np[mask] = (0.5 * image_np[mask] + 0.5 * tint).astype(np.uint8)

            # Draw bounding box
            x1, y1, x2, y2 = box
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       fill=False, color=color, linewidth=2))
            label = COCO_INSTANCE_CATEGORY_NAMES[label_id]
            ax.text(x1, y1, f"{label}: {score:.2f}",
                    bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)

        # Boxes and captions sit above the image in z-order, so the blended
        # image only needs to be drawn once, after all masks are applied.
        ax.imshow(image_np)
        ax.axis('off')

        output_path = "output_maskrcnn_with_masks.png"
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Close this specific figure even when rendering fails, so figures do
        # not accumulate across requests.
        plt.close(fig)

    return output_path


# Gradio interface
interface = gr.Interface(
    fn=segment_objects,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
    ],
    outputs=gr.Image(type="filepath", label="Segmented Output"),
    title="Mask R-CNN Instance Segmentation",
    description="Upload an image to detect and segment objects using a pretrained Mask R-CNN model (TorchVision)."
)

if __name__ == "__main__":
    interface.launch(debug=True)