# NOTE: The lines below are residue from the web file viewer this source was
# copied from; they are not part of the program.
# Spaces: Sleeping
# File size: 1,806 Bytes
# Per-line commit hashes: dfdcd97 a3ee867 e9cd6fd c95f3e0 e0d4d2f c95f3e0 e0d4d2f ac51df9 c95f3e0 ac51df9 e0d4d2f c95f3e0 e0d4d2f c95f3e0 e0d4d2f c95f3e0 e9cd6fd c95f3e0 e9cd6fd e0d4d2f e9cd6fd e0d4d2f e9cd6fd ac51df9 e9cd6fd c95f3e0 ac51df9 e0d4d2f e9cd6fd
# Line-number gutter: 1-59
import gradio as gr
import torch
import cv2
import numpy as np
from transformers import SamModel, SamProcessor
from PIL import Image
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the SAM checkpoint and its matching processor once at import time,
# then move the model weights onto the selected device.
model = SamModel.from_pretrained("facebook/sam-vit-base")
model = model.to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
def segment_image(input_image, x, y):
    """Segment the object under a prompt point and tint its mask red.

    Args:
        input_image: Image as a NumPy array (the UI passes it via
            ``gr.Image(type="numpy")``), or ``None`` when no image was given.
        x: Horizontal pixel coordinate of the prompt point.
        y: Vertical pixel coordinate of the prompt point.

    Returns:
        The RGB image as a NumPy array with the selected mask blended in as
        half-opacity red, or ``None`` when ``input_image`` is ``None``.
    """
    # Nothing to segment until an image has been supplied.
    if input_image is None:
        return None

    # Force three channels so the red overlay below always matches the image.
    pil_image = Image.fromarray(input_image).convert("RGB")
    width, height = pil_image.size

    # The coordinate controls are not tied to the image size, so keep the
    # prompt point inside the picture instead of prompting off-canvas.
    point_x = float(min(max(x, 0), width - 1))
    point_y = float(min(max(y, 0), height - 1))

    # SamProcessor expects nested Python lists shaped (image, point, xy);
    # a NumPy array is rejected.
    inputs = processor(
        pil_image,
        input_points=[[[point_x, point_y]]],
        return_tensors="pt",
    ).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Resize the low-resolution predictions back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # The model returns several candidate masks per prompt point; keep the
    # one with the highest predicted IoU so the mask is a single (H, W) map.
    best = int(outputs.iou_scores[0, 0].argmax())
    mask = masks[0][0][best].numpy().astype(bool)

    # Paint the mask red on a black canvas and blend it over the image.
    result_image = np.array(pil_image)
    mask_rgb = np.zeros_like(result_image)
    mask_rgb[mask] = [255, 0, 0]  # Red color for the mask
    return cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)
# Create Gradio interface: one image input plus two sliders for the prompt
# point, wired to segment_image, with a single image output.
# NOTE(review): the sliders are fixed to 0-1000 regardless of the uploaded
# image's dimensions - confirm this range is adequate for expected inputs.
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy"),
        gr.Slider(minimum=0, maximum=1000, step=1, label="X coordinate"),
        gr.Slider(minimum=0, maximum=1000, step=1, label="Y coordinate"),
    ],
    outputs=gr.Image(type="numpy"),
    title="Segment Anything Model (SAM) Image Segmentation",
    description="Enter X and Y coordinates of the object you want to segment.",
)

# Launch the interface
iface.launch()