# Segment / app.py
# Header recovered from the hosting page (kept as a comment so the file parses):
# author Harshithtd, commit 6f359a3 ("Create app.py", verified), original size 1.9 kB.
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
from gradio_image_prompter import ImagePrompter
import spaces
# Checkpoint shared by the model and its processor, so they always match.
_SLIMSAM_CHECKPOINT = "nielsr/slimsam-50-uniform"

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Loaded once at import time and reused by every request.
slimsam_processor = SamProcessor.from_pretrained(_SLIMSAM_CHECKPOINT)
slimsam_model = SamModel.from_pretrained(_SLIMSAM_CHECKPOINT).to(device)
@spaces.GPU
def sam_box_inference(image, x_min, y_min, x_max, y_max):
    """Segment the object inside a bounding box with SlimSAM.

    Args:
        image: Input image as a NumPy array (it is passed to
            ``Image.fromarray``).
        x_min, y_min, x_max, y_max: Box corners, expected in pixel
            coordinates of ``image``.

    Returns:
        A one-element list ``[(mask, "mask")]``. ``mask`` is a 2-D float32
        array, at the resolution restored by ``post_process_masks``. It is
        1.0 inside the predicted object and 0.0 elsewhere. This is the
        annotation format ``gr.AnnotatedImage`` accepts.
    """
    inputs = slimsam_processor(
        Image.fromarray(image),
        # Documented SamProcessor layout: [image][box][x_min, y_min, x_max, y_max].
        input_boxes=[[[x_min, y_min, x_max, y_max]]],
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = slimsam_model(**inputs)

    # Upscale the low-resolution logits back to the original image size and
    # binarize. The result is indexed as [image][box][candidate mask].
    candidates = slimsam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0][0]

    # SAM proposes several candidate masks per prompt. Keep the one the model
    # rates highest instead of always taking the first.
    best = int(outputs.iou_scores[0, 0].argmax())

    # AnnotatedImage wants a 2-D mask with values in [0, 1], not a boolean
    # array with an extra leading axis.
    mask = candidates[best].numpy().astype(np.float32)
    return [(mask, "mask")]
def infer_box(prompts):
    """Run SlimSAM on the first box drawn in the ImagePrompter widget.

    Args:
        prompts: Value of the ``ImagePrompter`` component, a dict with an
            ``"image"`` entry and a ``"points"`` list. Each entry is assumed
            to be ``[x1, y1, t1, x2, y2, t2]``, where ``t1 == 2`` and
            ``t2 == 3`` mark a box (gradio_image_prompter convention).
            NOTE(review): confirm against the installed gradio_image_prompter.

    Returns:
        ``(image, annotations)``, the ``(base_image, [(mask, label), ...])``
        tuple that ``gr.AnnotatedImage`` renders.

    Raises:
        gr.Error: If no image was uploaded or no box was drawn. A gr.Error
            only reaches the UI when it is raised, not merely constructed.
    """
    image = prompts.get("image") if prompts else None
    if image is None:
        raise gr.Error("Please upload an image and draw a box before submitting.")

    # Keep only box prompts; a single click (point prompt) is not a box.
    boxes = [
        p
        for p in (prompts.get("points") or [])
        if len(p) >= 6 and p[2] == 2 and p[5] == 3
    ]
    if not boxes:
        raise gr.Error("Please draw a box before submitting.")

    x1, y1, _, x2, y2 = boxes[0][:5]
    # A box may be dragged in any direction; normalize so min <= max.
    annotations = sam_box_inference(
        image, min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)
    )
    # A bare tuple, not a list wrapping it: AnnotatedImage reads value[0] as
    # the base image.
    return (image, annotations)
# UI: an ImagePrompter (upload + draw a box) and a Submit button on the left,
# and the annotated segmentation result on the right.
with gr.Blocks(title="SlimSAM Box Prompt") as demo:
    gr.Markdown("# SlimSAM Box Prompt")
    gr.Markdown("In this demo, you can upload an image and draw a box for SlimSAM to process.")
    with gr.Row():
        with gr.Column():
            im = ImagePrompter()
            btn = gr.Button("Submit")
        with gr.Column():
            output_box_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
    # Event listeners must be attached inside the Blocks context.
    btn.click(infer_box, inputs=im, outputs=[output_box_slimsam])

# Launch (which blocks when debug=True) only when run as a script, so that
# importing this module, e.g. from tests or tooling, has no server side effect.
if __name__ == "__main__":
    demo.launch(debug=True)