File size: 1,901 Bytes
6f359a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
from gradio_image_prompter import ImagePrompter
import spaces

# Prefer GPU when available; inference falls back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# SlimSAM-50 is a pruned SAM checkpoint; model weights and the matching
# pre/post-processing pipeline are loaded once at import time.
slimsam_model = SamModel.from_pretrained("nielsr/slimsam-50-uniform").to(device)
slimsam_processor = SamProcessor.from_pretrained("nielsr/slimsam-50-uniform")

@spaces.GPU
def sam_box_inference(image, x_min, y_min, x_max, y_max):
    """Run SlimSAM on an image with a single box prompt.

    Args:
        image: H x W x C uint8 numpy array (as delivered by the Gradio
            image component).
        x_min, y_min, x_max, y_max: box corners in pixel coordinates.

    Returns:
        A one-element list ``[(mask, "mask")]`` where ``mask`` is a
        ``(1, H, W)`` boolean numpy array — the annotation format expected
        by ``gr.AnnotatedImage``.
    """
    # The processor expects boxes nested as batch -> image -> boxes -> coords.
    inputs = slimsam_processor(
        Image.fromarray(image),
        input_boxes=[[[[x_min, y_min, x_max, y_max]]]],
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = slimsam_model(**inputs)

    # Rescale the low-resolution predicted masks back to the original image
    # size, then pick first image / first box / first mask proposal.
    mask = slimsam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )[0][0][0].numpy()
    # AnnotatedImage accepts masks with a leading channel axis.
    mask = mask[np.newaxis, ...]
    return [(mask, "mask")]

def infer_box(prompts):
    """Validate the ImagePrompter payload and run box-prompted SlimSAM.

    Args:
        prompts: dict from ``ImagePrompter`` with keys "image" (numpy array
            or None) and "points" (list of prompt rows; a box row looks like
            ``[x_min, y_min, type, x_max, y_max, type]``).

    Returns:
        A one-element list holding the ``(image, annotations)`` value for
        ``gr.AnnotatedImage``.

    Raises:
        gr.Error: if no image was uploaded or no box was drawn.
    """
    image = prompts["image"]
    if image is None:
        # BUG FIX: gr.Error must be *raised*, not merely instantiated,
        # for the message to surface in the UI.
        raise gr.Error("Please upload an image and draw a box before submitting.")
    # BUG FIX: guard the empty-points case before indexing; the original
    # indexed [0] first, which raised IndexError before the check ran.
    if not prompts["points"]:
        raise gr.Error("Please draw a box before submitting.")
    points = prompts["points"][0]
    # Box rows carry coordinates at indices 0, 1 (top-left) and 3, 4
    # (bottom-right); indices 2 and 5 are point-type markers.
    return [(image, sam_box_inference(image, points[0], points[1], points[3], points[4]))]

# --- UI layout: prompter + submit on the left, annotated result on the right.
with gr.Blocks(title="SlimSAM Box Prompt") as demo:
    gr.Markdown("# SlimSAM Box Prompt")
    gr.Markdown("In this demo, you can upload an image and draw a box for SlimSAM to process.")

    with gr.Row():
        with gr.Column():
            prompter = ImagePrompter()
            submit_btn = gr.Button("Submit")
        with gr.Column():
            annotated_output = gr.AnnotatedImage(label="SlimSAM Output")

    # Wire the button to the inference callback.
    submit_btn.click(infer_box, inputs=prompter, outputs=[annotated_output])

demo.launch(debug=True)