import os
import tempfile

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import gradio as gr
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

MODEL_NAME = "facebook/mask2former-swin-small-coco-instance"

# Load processor and model from Hugging Face
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = Mask2FormerForUniversalSegmentation.from_pretrained(MODEL_NAME)
model.eval()

# Label map from the model config, keyed by integer class id.
# transformers' PretrainedConfig stores id2label with int keys; keys are coerced to
# int here as well so a config that supplies string keys still resolves.
COCO_INSTANCE_CATEGORY_NAMES = {
    int(k): v for k, v in (getattr(model.config, "id2label", None) or {}).items()
}


def _label_name(label_id):
    """Return the class name for *label_id*, falling back to the id as text."""
    return COCO_INSTANCE_CATEGORY_NAMES.get(int(label_id), str(label_id))


def _render(image_np, segmentation_map, segments_info, threshold):
    """Draw translucent masks, boxes and labels over *image_np*; return a PNG path.

    Args:
        image_np: RGB image as a (H, W, 3) uint8 array.
        segmentation_map: (H, W) array of segment ids as produced by
            post_process_instance_segmentation.
        segments_info: list of dicts with keys "id", "label_id", "score".
        threshold: segments with a score below this value are not drawn.

    Returns:
        Path of a freshly created PNG file containing the rendering.
    """
    overlay = image_np.copy()
    # Seeded per call so the same input always gets the same instance colors.
    rng = np.random.default_rng(0)

    # Use Figure directly instead of pyplot: pyplot keeps global "current figure"
    # state that is not thread-safe, and Gradio runs handlers on worker threads.
    fig = Figure(figsize=(10, 10))
    ax = fig.subplots()

    for segment in segments_info:
        score = segment.get("score", 1.0)
        if score < threshold:
            continue

        mask = segmentation_map == segment["id"]
        ys, xs = np.nonzero(mask)
        if xs.size == 0:
            continue

        # Blend a random color 50/50 into the masked pixels.
        color = rng.random(3)
        overlay[mask] = (overlay[mask] * 0.5 + color * 255 * 0.5).astype(np.uint8)

        # imshow places pixel centres on integer coordinates, so extend the box by
        # half a pixel on each side to enclose the full mask extent.
        x1, x2 = xs.min(), xs.max()
        y1, y2 = ys.min(), ys.max()
        ax.add_patch(
            plt.Rectangle(
                (x1 - 0.5, y1 - 0.5),
                x2 - x1 + 1,
                y2 - y1 + 1,
                fill=False,
                color=color,
                linewidth=2,
            )
        )
        ax.text(
            x1,
            y1,
            f"{_label_name(segment['label_id'])}: {score:.2f}",
            bbox=dict(facecolor="yellow", alpha=0.5),
            fontsize=10,
        )

    # Drawn after the loop so the blended masks are included; imshow copies the array.
    ax.imshow(overlay)
    ax.axis("off")

    # A unique file per request avoids concurrent requests overwriting each other's
    # output (Gradio copies returned files into its own cache).
    fd, output_path = tempfile.mkstemp(prefix="mask2former_", suffix=".png")
    os.close(fd)
    try:
        fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
    except Exception:
        os.remove(output_path)
        raise
    return output_path


def segment_image(image, threshold=0.5):
    """Run Mask2Former instance segmentation on *image* and render the result.

    Args:
        image: PIL image from the Gradio upload widget, or None when nothing
            has been uploaded.
        threshold: minimum confidence score for an instance to be kept and drawn.

    Returns:
        Path to a PNG with colored masks, bounding boxes and labels, or None
        when no image was given.
    """
    if image is None:
        return None
    # The color blend assumes 3 channels; grayscale ("L"), palette ("P") or RGBA
    # uploads would otherwise fail.
    image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # post_process_instance_segmentation filters by its own score threshold
    # (default 0.5); pass the slider value so thresholds below 0.5 take effect.
    results = processor.post_process_instance_segmentation(
        outputs, threshold=threshold, target_sizes=[image.size[::-1]]
    )[0]
    segmentation_map = results["segmentation"].cpu().numpy()  # shape: [H, W]
    segments_info = results["segments_info"]  # list of dicts: id, label_id, score

    return _render(np.array(image), segmentation_map, segments_info, threshold)


# Gradio interface
interface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence Threshold"),
    ],
    outputs=gr.Image(type="filepath", label="Segmented Output"),
    title="Mask2Former Instance Segmentation (Transformer)",
    description="Upload an image to segment objects using Facebook's transformer-based Mask2Former model (Swin-Small backbone).",
)

if __name__ == "__main__":
    interface.launch(debug=True, share=True)