# File size: 2,903 Bytes
# 8eed584
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import gradio as gr
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

# Load processor and model from Hugging Face (downloads weights on first run).
processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance")
model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-small-coco-instance")
model.eval()  # inference mode (affects layers such as dropout); no training happens here

# Class-id -> display-name mapping used when labelling detections.
# transformers configs normally carry ``id2label`` (a dict keyed by int). If it is
# missing or empty, fall back to an id -> "id" dict sized from the config rather
# than a hard-coded class count. Keeping the fallback a dict (not a list) matters
# because callers look names up with ``.get``.
COCO_INSTANCE_CATEGORY_NAMES = getattr(model.config, "id2label", None) or {
    i: str(i) for i in range(getattr(model.config, "num_labels", 0))
}

def _label_name(names, label_id):
    """Return the display name for ``label_id`` from ``names``, else ``str(label_id)``.

    ``names`` may be a dict keyed by int (how transformers configs normally store
    ``id2label``), a dict keyed by str, or a list indexed by class id.
    """
    if isinstance(names, dict):
        if label_id in names:
            return names[label_id]
        return names.get(str(label_id), str(label_id))
    # Explicit bounds check so a negative id cannot wrap around a list.
    if 0 <= label_id < len(names):
        return names[label_id]
    return str(label_id)


def _mask_bbox(mask):
    """Return inclusive ``(x1, y1, x2, y2)`` pixel bounds of a boolean mask, or None if empty."""
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return None
    return xs.min(), ys.min(), xs.max(), ys.max()


def segment_image(image, threshold=0.5, output_path="mask2former_output.png"):
    """Run Mask2Former instance segmentation on ``image`` and save a visualization.

    Each instance scoring at least ``threshold`` is tinted with a random colour,
    outlined with its bounding box and labelled ``"<class>: <score>"``.

    Args:
        image: input PIL image (from the Gradio ``type="pil"`` component), or None
            when nothing was uploaded.
        threshold: minimum instance score to keep. It is also forwarded to the
            post-processor, whose own default cut-off would otherwise discard
            lower-scoring instances before they reach this function.
        output_path: file the rendered PNG is written to.

    Returns:
        ``output_path``, or None when ``image`` is None.
    """
    if image is None:
        # e.g. the Gradio form was submitted without an upload.
        return None
    # Force 3 channels: RGBA / greyscale / palette images would break the
    # per-pixel RGB colour blend below.
    image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_instance_segmentation(
        outputs,
        threshold=threshold,
        target_sizes=[image.size[::-1]],  # PIL size is (W, H); the processor wants (H, W)
    )[0]

    segmentation = results["segmentation"]
    if segmentation is None:
        # Documented to be None when no mask passes the threshold.
        segmentation_map, segments_info = None, []
    else:
        segmentation_map = segmentation.cpu().numpy()  # [H, W] map of segment ids
        segments_info = results["segments_info"]  # dicts with keys: id, label_id, score

    overlay = np.array(image)  # fresh, writable uint8 copy of the RGB pixels
    fig, ax = plt.subplots(1, figsize=(10, 10))
    try:
        for segment in segments_info:
            score = segment.get("score", 1.0)
            if score < threshold:
                continue

            mask = segmentation_map == segment["id"]
            bbox = _mask_bbox(mask)
            if bbox is None:
                continue
            x1, y1, x2, y2 = bbox

            color = np.random.rand(3)  # random RGB in [0, 1) per instance
            # 50/50 blend of the original pixels with the instance colour.
            overlay[mask] = (overlay[mask] * 0.5 + color * 255 * 0.5).astype(np.uint8)

            label_name = _label_name(COCO_INSTANCE_CATEGORY_NAMES, segment["label_id"])
            ax.add_patch(
                plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color=color, linewidth=2)
            )
            ax.text(
                x1,
                y1,
                f"{label_name}: {score:.2f}",
                bbox=dict(facecolor="yellow", alpha=0.5),
                fontsize=10,
            )

        ax.imshow(overlay)
        ax.axis("off")
        fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Close this specific figure, even on error, so repeated requests don't
        # accumulate open pyplot figures.
        plt.close(fig)
    return output_path



# Gradio front end: upload an image and choose a confidence cut-off; the app
# displays the annotated PNG whose path segment_image returns.
interface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold"),
    ],
    outputs=gr.Image(label="Segmented Output", type="filepath"),
    title="Mask2Former Instance Segmentation (Transformer)",
    description=(
        "Upload an image to segment objects using Facebook's "
        "transformer-based Mask2Former model (Swin-Small backbone)."
    ),
)

if __name__ == "__main__":
    # share=True requests a public link; debug=True keeps the process attached for logs.
    interface.launch(share=True, debug=True)