# Mask2Former instance segmentation demo (Gradio app).
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import gradio as gr
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
# Load processor and model from Hugging Face
processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance")
model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-small-coco-instance")
model.eval()  # inference only: disable dropout/batch-norm training behavior

# Label map from the model config. transformers exposes `id2label` as a dict
# keyed by int. BUG FIX: the previous fallback was a *list*, but the lookup
# code calls `.get(...)` on this object — a list would crash; use a dict.
COCO_INSTANCE_CATEGORY_NAMES = (
    model.config.id2label
    if getattr(model.config, "id2label", None)
    else {i: str(i) for i in range(133)}
)
def segment_image(image, threshold=0.5):
    """Run Mask2Former instance segmentation on a PIL image.

    Args:
        image: input PIL.Image (any mode; converted to RGB internally).
        threshold: minimum confidence score for a segment to be drawn.

    Returns:
        Path to a saved PNG with colored masks, bounding boxes and labels.
    """
    # Uploads may be RGBA or grayscale; the overlay blend below needs RGB.
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # target_sizes expects (height, width); PIL's .size is (width, height).
    results = processor.post_process_instance_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    segmentation_map = results["segmentation"].cpu().numpy()  # [H, W] of segment ids
    segments_info = results["segments_info"]  # dicts with keys: id, label_id, score

    image_np = np.array(image)
    overlay = image_np.copy()

    fig, ax = plt.subplots(1, figsize=(10, 10))
    try:
        for segment in segments_info:
            score = segment.get("score", 1.0)
            if score < threshold:
                continue
            mask = segmentation_map == segment["id"]
            y_indices, x_indices = np.where(mask)
            if len(x_indices) == 0 or len(y_indices) == 0:
                continue  # empty mask: nothing to color or box

            # Random color per object (RGB in [0, 1]); blend 50/50 with image.
            color = np.random.rand(3)
            overlay[mask] = (overlay[mask] * 0.5 + color * 255 * 0.5).astype(np.uint8)

            x1, x2 = x_indices.min(), x_indices.max()
            y1, y2 = y_indices.min(), y_indices.max()

            # BUG FIX: transformers `id2label` is keyed by int, so the old
            # `.get(str(label_id))` always missed and labels rendered as raw
            # numbers. Try the int key, then the str key, then fall back.
            label_id = segment["label_id"]
            if isinstance(COCO_INSTANCE_CATEGORY_NAMES, dict):
                label_name = COCO_INSTANCE_CATEGORY_NAMES.get(
                    label_id,
                    COCO_INSTANCE_CATEGORY_NAMES.get(str(label_id), str(label_id)),
                )
            else:
                label_name = (
                    COCO_INSTANCE_CATEGORY_NAMES[label_id]
                    if 0 <= label_id < len(COCO_INSTANCE_CATEGORY_NAMES)
                    else str(label_id)
                )

            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       fill=False, color=color, linewidth=2))
            ax.text(x1, y1, f"{label_name}: {score:.2f}",
                    bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)

        # Single imshow of the blended overlay (it fully covers the base
        # image, so the previous extra imshow of image_np was redundant).
        ax.imshow(overlay)
        ax.axis('off')
        output_path = "mask2former_output.png"
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if drawing/saving raises,
        # to avoid leaking matplotlib figures across Gradio requests.
        plt.close(fig)
    return output_path
# Gradio interface: image upload + confidence slider in, annotated image out.
image_input = gr.Image(type="pil", label="Upload Image")
threshold_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence Threshold")

interface = gr.Interface(
    fn=segment_image,
    inputs=[image_input, threshold_slider],
    outputs=gr.Image(type="filepath", label="Segmented Output"),
    title="Mask2Former Instance Segmentation (Transformer)",
    description="Upload an image to segment objects using Facebook's transformer-based Mask2Former model (Swin-Small backbone).",
)

if __name__ == "__main__":
    interface.launch(debug=True, share=True)