import gradio as gr
import torch
import cv2
import numpy as np
from transformers import SamModel, SamProcessor
from PIL import Image

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and processor
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
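# Note: the first call to from_pretrained downloads the checkpoint from the
# Hugging Face Hub and caches it locally for later runs.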

def segment_image(input_image, x, y):
    # Convert input_image to PIL Image
    input_image = Image.fromarray(input_image)
    
    # Prepare inputs; input_points is nested as [image, points per image, (x, y)]
    inputs = processor(input_image, input_points=[[[x, y]]], return_tensors="pt").to(device)
    
    # Generate masks
    with torch.no_grad():
        outputs = model(**inputs)
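    # pred_masks holds several candidate masks per prompt point (multimask
    # output), ranked by the IoU scores read out below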
    
    # Post-process masks
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )
    scores = outputs.iou_scores

    # Keep the highest-scoring candidate mask and convert it to numpy
    best_idx = scores[0, 0].argmax().item()
    mask = masks[0][0, best_idx].numpy()
    
    # Overlay the mask on the original image
    result_image = np.array(input_image)
    mask_rgb = np.zeros_like(result_image)
    mask_rgb[mask] = [255, 0, 0]  # red overlay; post-processed masks are boolean
    result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)
    
    return result_image

# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy"),
        # Prompt point in pixel coordinates; the 0-1000 range assumes images no larger than 1000 px
        gr.Slider(minimum=0, maximum=1000, step=1, label="X coordinate"),
        gr.Slider(minimum=0, maximum=1000, step=1, label="Y coordinate")
    ],
    outputs=gr.Image(type="numpy"),
    title="Segment Anything Model (SAM) Image Segmentation",
    description="Upload an image, then set the X and Y pixel coordinates of the object you want to segment."
)

# Launch the interface
iface.launch()
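
# A minimal smoke test for segment_image outside the Gradio UI, as a sketch:
# "example.jpg" is a placeholder filename, not part of this app. Run these
# lines instead of iface.launch() to check the model wiring:
#
#     img = np.array(Image.open("example.jpg").convert("RGB"))
#     out = segment_image(img, img.shape[1] // 2, img.shape[0] // 2)
#     Image.fromarray(out).save("segmented.png")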