File size: 3,316 Bytes
dfdcd97
a3ee867
e9cd6fd
 
c95f3e0
 
e0d4d2f
c95f3e0
 
 
 
 
 
e0d4d2f
564688d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bee2b4
73989e5
 
 
 
99fdace
73989e5
 
 
 
 
 
 
99fdace
73989e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
564688d
73989e5
 
 
564688d
73989e5
564688d
73989e5
 
 
 
 
 
 
 
3ba1061
73989e5
 
e0d4d2f
e9cd6fd
 
e0d4d2f
e9cd6fd
3ba1061
7bee2b4
e9cd6fd
3ba1061
 
 
 
c95f3e0
7bee2b4
e0d4d2f
 
e9cd6fd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import torch
import cv2
import numpy as np
from transformers import SamModel, SamProcessor
from PIL import Image

# Run on the GPU whenever torch reports one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the "facebook/sam-vit-base" checkpoint and move it onto the chosen
# device, then load the processor published under the same checkpoint name.
model = SamModel.from_pretrained("facebook/sam-vit-base")
model = model.to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

def process_mask(mask, target_size):
    """Binarize a mask array and resize it, returning a boolean mask.

    Args:
        mask: NumPy array of mask values. Entries greater than 0.5 are
            treated as foreground. Arrays with more than two axes are
            reduced to 2-D: singleton axes are squeezed first, then index 0
            is taken along each remaining leading axis.
        target_size: Size handed straight to ``PIL.Image.resize``
            (PIL expects ``(width, height)``).

    Returns:
        A boolean NumPy array that is True wherever the resized mask is set.
    """
    # Drop singleton axes first so e.g. (1, H, W) collapses to (H, W).
    if mask.ndim > 2:
        mask = mask.squeeze()

    # Keep peeling off the first entry of the leading axis until the array is
    # 2-D. A single `mask[0]` (the previous behaviour) left arrays with more
    # than three non-singleton axes above 2-D, which Image.fromarray cannot
    # turn into a single-channel image.
    while mask.ndim > 2:
        mask = mask[0]

    # Threshold to {0, 255} uint8 so PIL builds an 8-bit image from it.
    binary = (mask > 0.5).astype(np.uint8) * 255

    # Nearest-neighbour resampling keeps the mask strictly two-valued.
    resized = Image.fromarray(binary).resize(target_size, Image.NEAREST)

    return np.array(resized) > 0

def segment_image(input_image, segment_anything):
    """Run SAM on an uploaded image and overlay the resulting mask in red.

    Args:
        input_image: Image as a NumPy array (the Gradio input is declared
            with ``type="numpy"``), or ``None`` when nothing was uploaded.
        segment_anything: When truthy, the model is run without any point
            prompt and all candidate masks it returns are merged. When falsy,
            a single prompt at the image centre is used and the candidate
            with the highest predicted IoU score is kept.

    Returns:
        A ``(image, status)`` tuple. ``image`` is the RGB overlay as a NumPy
        array, or ``None`` on failure; ``status`` is a human-readable message.
        Exceptions are not propagated: they are reported through ``status``.
    """
    if input_image is None:
        return None, "Please upload an image before submitting."

    try:
        # Normalise to a 3-channel RGB PIL image regardless of input mode.
        input_image = Image.fromarray(input_image).convert("RGB")

        # PIL reports size as (width, height).
        original_size = input_image.size
        if not original_size or 0 in original_size:
            return None, "Invalid image size. Please upload a different image."

        # Build model inputs, with or without a centre-point prompt.
        if segment_anything:
            inputs = processor(input_image, return_tensors="pt").to(device)
        else:
            width, height = original_size
            center_point = [[width // 2, height // 2]]
            inputs = processor(
                input_image, input_points=[center_point], return_tensors="pt"
            ).to(device)

        # Inference only; no gradients needed.
        with torch.no_grad():
            outputs = model(**inputs)

        # Rescale the predicted masks back to the original image geometry.
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )

        # masks[0] holds the masks for the only image in the batch.
        # NOTE(review): per the transformers SAM docs this is expected to be
        # (num_prompts, num_candidates, H, W) -- confirm. Flattening every
        # leading axis gives (n_candidates, H, W) whatever the exact rank is.
        candidates = masks[0].numpy()
        candidates = candidates.reshape(-1, *candidates.shape[-2:])

        if segment_anything:
            # Union over *all* candidates. Reducing only axis 0 of the
            # unflattened array merged nothing when that axis had length 1,
            # and the remaining candidates were then silently discarded.
            combined_mask = np.any(candidates > 0.5, axis=0)
        else:
            # Keep the candidate the model itself rates best instead of
            # always taking index 0. Fall back to index 0 if the score count
            # does not line up with the number of candidates.
            scores = outputs.iou_scores.detach().cpu().numpy().reshape(-1)
            if scores.size == len(candidates):
                best_index = int(np.argmax(scores))
            else:
                best_index = 0
            combined_mask = candidates[best_index]

        combined_mask = process_mask(combined_mask, original_size)

        # Paint masked pixels red on a black canvas, then add it to the
        # original at half weight (the original keeps full weight).
        result_image = np.array(input_image)
        mask_rgb = np.zeros_like(result_image)
        mask_rgb[combined_mask] = [255, 0, 0]  # Red color for the mask
        result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)

        return result_image, "Segmentation completed successfully."

    except Exception as e:
        # UI boundary: surface the failure in the status box rather than
        # letting the request crash.
        return None, f"An error occurred: {str(e)}"

# Input widgets: the image to segment and the mode toggle.
_input_widgets = [
    gr.Image(type="numpy", label="Upload an image"),
    gr.Checkbox(label="Segment Everything"),
]

# Output widgets: the overlay image and a status line.
_output_widgets = [
    gr.Image(type="numpy", label="Segmented Image"),
    gr.Textbox(label="Status"),
]

# Wire segment_image to the widgets above.
iface = gr.Interface(
    fn=segment_image,
    inputs=_input_widgets,
    outputs=_output_widgets,
    title="Segment Anything Model (SAM) Image Segmentation",
    description="Upload an image and choose whether to segment everything or use a center point.",
)

# Start serving the app.
iface.launch()