SauravMaheshkar's picture
feat: use util fn from library
6447433 unverified
raw
history blame contribute delete
2.9 kB
import functools
from typing import Any, Dict

import cv2
import gradio as gr
import numpy as np
from gradio_image_annotation import image_annotator
from sam2 import load_model
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.utils.visualization import show_masks
@functools.lru_cache(maxsize=1)
def _load_sam2_model(model_choice: str):
    """Load (and memoise) the SAM 2 checkpoint for ``model_choice`` on CPU.

    Previously the checkpoint was re-read from disk on every request. Caching
    the most recently used variant makes repeated predictions with the same
    dropdown value cheap, while ``maxsize=1`` bounds memory to one model.
    """
    return load_model(
        variant=model_choice,
        ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
        device="cpu",
    )


# @spaces.GPU()
def predict(model_choice, annotations: Dict[str, Any]):
    """Segment the annotated image with the selected SAM 2 checkpoint.

    Args:
        model_choice: Checkpoint variant name (e.g. ``"tiny"``); it selects
            ``assets/checkpoints/sam2_hiera_{model_choice}.pt``.
        annotations: Value produced by the ``image_annotator`` component. This
            function reads ``"image"`` and an optional ``"boxes"`` list whose
            items carry ``"xmin"``, ``"ymin"``, ``"xmax"`` and ``"ymax"``.

    Returns:
        Whatever ``show_masks`` renders for the predicted masks (presumably an
        image suitable for the ``gr.Image`` output -- confirm against the
        ``sam2`` utility).

    Raises:
        gr.Error: If no image has been provided, for example because the
            annotator was cleared.
    """
    # The annotator yields None (or no image) once the user clears it; surface
    # a readable message instead of a TypeError/KeyError.
    if not annotations or annotations.get("image") is None:
        raise gr.Error("Please upload an image first.")

    sam2_model = _load_sam2_model(model_choice)
    image = annotations["image"]
    boxes = annotations.get("boxes") or []

    if boxes:
        # Prompted mode: segment only inside the user-drawn boxes.
        predictor = SAM2ImagePredictor(sam2_model)  # type:ignore
        predictor.set_image(image)
        # One [xmin, ymin, xmax, ymax] row per drawn box, passed as a batch.
        coordinates = np.array(
            [
                [int(box[key]) for key in ("xmin", "ymin", "xmax", "ymax")]
                for box in boxes
            ]
        )
        masks, scores, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=coordinates,
            multimask_output=False,
        )
        multi_box = len(scores) > 1
        return show_masks(
            image=image,
            masks=masks,
            # Scores are only annotated when there is a single box.
            scores=None if multi_box else scores,
            only_best=not multi_box,
        )

    # No boxes drawn: fall back to automatic whole-image mask generation.
    mask_generator = SAM2AutomaticMaskGenerator(sam2_model)  # type: ignore
    masks = mask_generator.generate(image)
    return show_masks(
        image=image,
        masks=masks,  # type: ignore
        scores=None,
        only_best=False,
        autogenerated_mask=True,
    )
# cv2.imread decodes images in BGR channel order (and returns None if the file
# is missing). The annotator renders numpy arrays as RGB and SAM 2's
# ``set_image`` expects RGB, so convert the bundled example once up front.
_example_bgr = cv2.imread("assets/example.png")
example_image = (
    cv2.cvtColor(_example_bgr, cv2.COLOR_BGR2RGB) if _example_bgr is not None else None
)

with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(
        """
## To read more about the Segment Anything Project please refer to the [Lightly AI blogpost](https://www.lightly.ai/post/segment-anything-model-and-friends)
"""
    )
    gr.Markdown(
        """
# 1. Choose Model Checkpoint
"""
    )
    with gr.Row():
        model = gr.Dropdown(
            choices=["tiny", "small", "base_plus", "large"],
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )
    gr.Markdown(
        """
# 2. Upload your Image and draw bounding box(es)
"""
    )
    annotator = image_annotator(
        value={"image": example_image},
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )
    btn = gr.Button("Get Segmentation Mask(s)")
    # Rendered right after the button, exactly where the inline component was.
    mask_output = gr.Image(label="Mask(s)")
    btn.click(fn=predict, inputs=[model, annotator], outputs=[mask_output])

# Launch only when run as a script, so importing the module (e.g. for tests or
# `gradio` reload mode) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()