Spaces:
Sleeping
Sleeping
import functools
import pathlib
import zipfile
from typing import Any, Dict, List

import cv2
import gradio as gr
import numpy as np
import torch
from gradio_image_annotation import image_annotator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from src.plot_utils import render_masks
# Dropdown option -> [config file name, checkpoint path]. predict() unpacks the
# pair in that order to choose what build_sam2 loads. Key order is the order
# the options were declared in.
choice_mapping: Dict[str, List[str]] = dict(
    tiny=["sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"],
    small=["sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"],
    base_plus=["sam2_hiera_b+.yaml", "assets/checkpoints/sam2_hiera_base_plus.pt"],
    large=["sam2_hiera_l.yaml", "assets/checkpoints/sam2_hiera_large.pt"],
)
@functools.lru_cache(maxsize=1)
def _load_sam2_model(model_choice: str) -> Any:
    """Build the SAM2 model for *model_choice*, reusing it across calls.

    Only the most recently used model is kept (maxsize=1) to bound memory use
    while still avoiding a checkpoint reload on every button click.

    Raises:
        KeyError: if *model_choice* is not a key of ``choice_mapping``.
    """
    config_file, ckpt_path = choice_mapping[model_choice]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return build_sam2(config_file, ckpt_path, device=device)


def _boxes_to_xyxy(boxes: List[Dict[str, Any]]) -> np.ndarray:
    """Convert annotator box dicts into an (N, 4) int array of [xmin, ymin, xmax, ymax]."""
    return np.array(
        [
            [int(box["xmin"]), int(box["ymin"]), int(box["xmax"]), int(box["ymax"])]
            for box in boxes
        ]
    )


def _write_masks_zip(masks: np.ndarray, out_dir: pathlib.Path) -> pathlib.Path:
    """Save each mask as ``out_dir/mask_<i>.png`` and zip exactly those files.

    Only the files written by this call go into ``out_dir/masks.zip``. Globbing
    the directory would also pick up stale ``mask_*.png`` files left by an
    earlier request that had more boxes.

    Raises:
        OSError: if OpenCV fails to write a mask image.
    """
    written: List[pathlib.Path] = []
    for count, mask in enumerate(masks):
        # NOTE(review): SAM2 appears to yield (1, H, W) per mask for several boxes
        # but (H, W) for a single box. Keeping only the trailing two axes handles
        # both and produces the same grayscale PNG either way.
        mask_2d = mask.reshape(mask.shape[-2:])
        mask_image = (mask_2d * 255).astype(np.uint8)  # 0/1 mask -> 0/255 uint8
        mask_path = out_dir / f"mask_{count}.png"
        # cv2.imwrite reports failure via its return value rather than raising.
        if not cv2.imwrite(str(mask_path), mask_image):
            raise OSError(f"Could not write mask image to {mask_path}")
        written.append(mask_path)

    zip_path = out_dir / "masks.zip"
    with zipfile.ZipFile(zip_path, "w") as archive:
        for mask_path in written:
            archive.write(mask_path, arcname=mask_path.name)
    return zip_path


def predict(model_choice, annotations: Dict[str, Any]):
    """Segment every drawn bounding box with SAM2 and package the masks.

    Args:
        model_choice: Key into ``choice_mapping`` selecting the model checkpoint.
        annotations: Value of the image_annotator component. Reads its
            ``"image"`` entry (given to SAM2 and to ``render_masks``) and its
            ``"boxes"`` entry (dicts with xmin/ymin/xmax/ymax keys).

    Returns:
        A two-element list: the output of ``render_masks(image, masks)`` and a
        visible ``gr.DownloadButton`` pointing at ``assets/masks.zip``.

    Raises:
        gr.Error: if no image is loaded or no bounding box has been drawn.
    """
    if not annotations or annotations.get("image") is None:
        raise gr.Error("Please upload an image first.")
    boxes = annotations.get("boxes") or []
    if not boxes:
        raise gr.Error("Please draw at least one bounding box.")

    # A fresh predictor per call keeps per-image state (set_image) local to
    # this request. The underlying model is shared via the cache.
    predictor = SAM2ImagePredictor(_load_sam2_model(str(model_choice)))
    predictor.set_image(annotations["image"])
    masks, _scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=_boxes_to_xyxy(boxes),
        multimask_output=False,
    )

    zip_path = _write_masks_zip(masks, pathlib.Path("assets"))
    return [
        render_masks(annotations["image"], masks),
        gr.DownloadButton("Download Mask(s)", value=str(zip_path), visible=True),
    ]
# cv2.imread returns pixels in BGR order, but this image is handed to the
# annotator and from there to SAM2's set_image, which expects RGB. Convert so
# the example's colours (and its segmentation) match an uploaded image.
# imread returns None for a missing file; keep passing None in that case.
_example_bgr = cv2.imread("assets/example.png")
_example_rgb = (
    cv2.cvtColor(_example_bgr, cv2.COLOR_BGR2RGB) if _example_bgr is not None else None
)

with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(
        """
# 1. Choose Model Checkpoint
"""
    )
    with gr.Row():
        model = gr.Dropdown(
            # Derived from choice_mapping so the UI cannot offer an option
            # that predict() has no checkpoint for.
            choices=list(choice_mapping),
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )
    gr.Markdown(
        """
# 2. Upload your Image and draw bounding box(es)
"""
    )
    annotator = image_annotator(
        value={"image": _example_rgb},
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )
    btn = gr.Button("Get Segmentation Mask(s)")
    # Hidden until predict() returns a visible button pointing at fresh masks.
    download_btn = gr.DownloadButton(
        "Download Mask(s)", value="assets/masks.zip", visible=False
    )
    mask_plot = gr.Plot()
    btn.click(fn=predict, inputs=[model, annotator], outputs=[mask_plot, download_btn])
demo.launch()