import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import gradio as gr
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
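

# Load the Mask2Former image processor and pretrained weights (Swin-Small backbone, COCO instance segmentation).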
processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance")
model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-small-coco-instance")
model.eval()
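
# Map label ids to human-readable class names; fall back to numeric labels if the config has no id2label mapping.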
COCO_INSTANCE_CATEGORY_NAMES = (
    model.config.id2label
    if getattr(model.config, "id2label", None)
    else {i: str(i) for i in range(133)}
)


def segment_image(image, threshold=0.5):
    # Ensure a 3-channel RGB image so the overlay blending below works for RGBA or grayscale uploads.
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
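
    # Post-process raw outputs into per-instance masks; target_sizes expects (height, width).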
    results = processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]

    segmentation_map = results["segmentation"].cpu().numpy()
    segments_info = results["segments_info"]
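
    # Copy the image for color blending and set up a Matplotlib figure for drawing boxes and labels.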
    image_np = np.array(image).copy()
    overlay = image_np.copy()
    fig, ax = plt.subplots(1, figsize=(10, 10))
    ax.imshow(image_np)
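
    # Draw every detected instance that passes the confidence threshold.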
    for segment in segments_info:
        score = segment.get("score", 1.0)
        if score < threshold:
            continue

        segment_id = segment["id"]
        label_id = segment["label_id"]
        mask = segmentation_map == segment_id
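
        # Blend a random color into the overlay wherever this instance's mask is set.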
        color = np.random.rand(3)
        overlay[mask] = (overlay[mask] * 0.5 + color * 255 * 0.5).astype(np.uint8)
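
        # Derive a bounding box from the mask; skip instances whose mask is empty.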
        y_indices, x_indices = np.where(mask)
        if len(x_indices) == 0 or len(y_indices) == 0:
            continue
        x1, x2 = x_indices.min(), x_indices.max()
        y1, y2 = y_indices.min(), y_indices.max()
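
        # Annotate the box with the class name and confidence score.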
        label_name = COCO_INSTANCE_CATEGORY_NAMES.get(label_id, str(label_id))
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color=color, linewidth=2))
        ax.text(x1, y1, f"{label_name}: {score:.2f}",
                bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)
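
    # Show the blended overlay, hide the axes, and save the annotated figure to disk.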
    ax.imshow(overlay)
    ax.axis('off')
    output_path = "mask2former_output.png"
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    return output_path
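

# Gradio UI: upload an image, pick a confidence threshold, and view the saved segmentation overlay.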
interface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
    ],
    outputs=gr.Image(type="filepath", label="Segmented Output"),
    title="Mask2Former Instance Segmentation (Transformer)",
    description="Upload an image to segment objects using Facebook's transformer-based Mask2Former model (Swin-Small backbone)."
)


if __name__ == "__main__":
    interface.launch(debug=True, share=True)