import gradio as gr
import torch
import cv2
import numpy as np
from transformers import SamModel, SamProcessor
from PIL import Image

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and processor
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
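# Note: "facebook/sam-vit-base" is the smallest SAM checkpoint; "facebook/sam-vit-large"
# and "facebook/sam-vit-huge" are drop-in alternatives that trade speed for mask quality.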

def process_mask(mask, target_size):
    # Ensure mask is 2D
    if mask.ndim > 2:
        mask = mask.squeeze()
    # If mask is still not 2D, take the first 2D slice
    if mask.ndim > 2:
        mask = mask[0]
    # Convert to binary
    mask = (mask > 0.5).astype(np.uint8) * 255
    # Resize mask to match original image size using PIL
    mask_image = Image.fromarray(mask)
    mask_image = mask_image.resize(target_size, Image.NEAREST)
    return np.array(mask_image) > 0
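# Illustrative contract (shapes hypothetical): a (1, 3, 256, 256) array of logits is
# squeezed to (3, 256, 256), reduced to its first slice, binarized, and resized. Note
# that target_size follows PIL's (width, height) order, so the returned array has
# NumPy's (height, width) shape:
#   process_mask(np.random.rand(1, 3, 256, 256), (640, 480))  # -> bool array, shape (480, 640)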

def segment_image(input_image, segment_anything):
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."
        # Convert input_image to PIL Image and ensure it's RGB
        input_image = Image.fromarray(input_image).convert("RGB")
        # Store original size (PIL order: width, height)
        original_size = input_image.size
        if not original_size or 0 in original_size:
            return None, "Invalid image size. Please upload a different image."
        # Process the image
        if segment_anything:
            # No point prompt is given; the model predicts masks from the image alone
            inputs = processor(input_image, return_tensors="pt").to(device)
        else:
            # Use the image center as a single point prompt
            width, height = original_size
            center_point = [[width // 2, height // 2]]
            inputs = processor(input_image, input_points=[center_point], return_tensors="pt").to(device)
        # Generate masks
        with torch.no_grad():
            outputs = model(**inputs)
        # Post-process masks back to the original image resolution
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )
        # Combine the predicted masks into a single 2D mask
        if segment_anything:
            # Union of all returned masks; flatten any leading batch/prompt dims first
            m = masks[0].numpy()
            combined_mask = np.any(m.reshape(-1, *m.shape[-2:]) > 0.5, axis=0)
        else:
            # Keep the first candidate mask returned for the point prompt
            combined_mask = masks[0][0].numpy()
        combined_mask = process_mask(combined_mask, original_size)
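        # Alternative (not used here): rank the candidate masks by predicted quality
        # via outputs.iou_scores and keep the highest-scoring one instead of the first.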
        # Overlay the mask on the original image
        result_image = np.array(input_image)
        mask_rgb = np.zeros_like(result_image)
        mask_rgb[combined_mask] = [255, 0, 0]  # Red color for the mask
        result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)
        return result_image, "Segmentation completed successfully."
    except Exception as e:
        return None, f"An error occurred: {str(e)}"

# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
        gr.Checkbox(label="Segment Everything")
    ],
    outputs=[
        gr.Image(type="numpy", label="Segmented Image"),
        gr.Textbox(label="Status")
    ],
    title="Segment Anything Model (SAM) Image Segmentation",
    description="Upload an image and choose whether to segment everything or use a center point."
)

# Launch the interface
iface.launch()
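
# Quick smoke test without the UI (a sketch; "example.jpg" is a hypothetical local file):
#   img = np.array(Image.open("example.jpg").convert("RGB"))
#   result, status = segment_image(img, segment_anything=False)
#   print(status)
# On Hugging Face Spaces, launch() without arguments is sufficient; when running
# locally, iface.launch(share=True) also creates a temporary public link.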