venkyvicky commited on
Commit
8eed584
·
verified ·
1 Parent(s): 5ab45cd

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +78 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
+ import gradio as gr
6
+ from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
7
+
8
+ # Load processor and model from Hugging Face
9
+ processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance")
10
+ model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-small-coco-instance")
11
+ model.eval()
12
+
13
+ # Load label map from model config
14
+ COCO_INSTANCE_CATEGORY_NAMES = model.config.id2label if hasattr(model.config, "id2label") else [str(i) for i in range(133)]
15
+
16
+ def segment_image(image, threshold=0.5):
17
+ inputs = processor(images=image, return_tensors="pt")
18
+ with torch.no_grad():
19
+ outputs = model(**inputs)
20
+
21
+ results = processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
22
+
23
+ segmentation_map = results["segmentation"].cpu().numpy() # shape: [H, W]
24
+ segments_info = results["segments_info"] # list of dicts with keys: id, label_id, score
25
+
26
+ image_np = np.array(image).copy()
27
+ overlay = image_np.copy()
28
+ fig, ax = plt.subplots(1, figsize=(10, 10))
29
+ ax.imshow(image_np)
30
+
31
+ for segment in segments_info:
32
+ score = segment.get("score", 1.0)
33
+ if score < threshold:
34
+ continue
35
+
36
+ segment_id = segment["id"]
37
+ label_id = segment["label_id"]
38
+ mask = segmentation_map == segment_id
39
+
40
+ # Random color per object
41
+ color = np.random.rand(3)
42
+ overlay[mask] = (overlay[mask] * 0.5 + np.array(color) * 255 * 0.5).astype(np.uint8)
43
+
44
+ # Draw bounding box
45
+ y_indices, x_indices = np.where(mask)
46
+ if len(x_indices) == 0 or len(y_indices) == 0:
47
+ continue
48
+ x1, x2 = x_indices.min(), x_indices.max()
49
+ y1, y2 = y_indices.min(), y_indices.max()
50
+
51
+ label_name = COCO_INSTANCE_CATEGORY_NAMES.get(str(label_id), str(label_id))
52
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color=color, linewidth=2))
53
+ ax.text(x1, y1, f"{label_name}: {score:.2f}",
54
+ bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)
55
+
56
+ ax.imshow(overlay)
57
+ ax.axis('off')
58
+ output_path = "mask2former_output.png"
59
+ plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
60
+ plt.close()
61
+ return output_path
62
+
63
+
64
+
65
+ # Gradio interface
66
+ interface = gr.Interface(
67
+ fn=segment_image,
68
+ inputs=[
69
+ gr.Image(type="pil", label="Upload Image"),
70
+ gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
71
+ ],
72
+ outputs=gr.Image(type="filepath", label="Segmented Output"),
73
+ title="Mask2Former Instance Segmentation (Transformer)",
74
+ description="Upload an image to segment objects using Facebook's transformer-based Mask2Former model (Swin-Small backbone)."
75
+ )
76
+
77
+ if __name__ == "__main__":
78
+ interface.launch(debug=True,share=True)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.40.0
3
+ gradio>=4.24.0
4
+ matplotlib
5
+ Pillow
6
+ numpy
7
+ scipy