# Segment / app.py
# Last commit: 8ac1bff "Update app.py" by WensongSong (verified)
import os
import subprocess
import sys


def _run_best_effort(cmd):
    """Run the argv list *cmd* and return True if it exited with status 0.

    Start-up stays best-effort (as with plain ``os.system``): a failure is
    reported on stderr instead of aborting the app, so the real error shows
    up later at the point where the missing package/file is actually used.
    """
    try:
        completed = subprocess.run(cmd, check=False)
    except OSError as err:  # e.g. the executable (wget) is not installed
        print(f"[setup] could not run {cmd[0]!r}: {err}", file=sys.stderr)
        return False
    if completed.returncode != 0:
        print(
            f"[setup] {' '.join(cmd)} exited with status {completed.returncode}",
            file=sys.stderr,
        )
        return False
    return True


def _download_once(url, filename):
    """Download *url* to *filename* (working directory) unless it already exists.

    A bare ``wget URL`` downloads again on every start, and when the file is
    already present it saves the new copy under another name (``<name>.1``).
    Downloading into ``<name>.part`` and renaming only after wget succeeds
    also keeps an interrupted download from being mistaken for a complete
    checkpoint on the next start.
    """
    if os.path.exists(filename):
        return
    partial = filename + ".part"
    if _run_best_effort(["wget", "-O", partial, url]):
        os.replace(partial, filename)


# Install the vendored packages with the interpreter that is running this app;
# a bare "python" on PATH is not guaranteed to be the same environment.
_run_best_effort([sys.executable, "-m", "pip", "install", "-e", "segment_anything"])
_run_best_effort([sys.executable, "-m", "pip", "install", "-e", "GroundingDINO"])
sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
sys.path.append(os.path.join(os.getcwd(), "segment_anything"))

# Model checkpoints. The local file names must match GROUNDING_DINO_CHECKPOINT_PATH
# and SAM_CHECKPOINT_PATH, which are defined further down in this file.
_download_once(
    "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth",
    "groundingdino_swinb_cogcoor.pth",
)
_download_once(
    "https://huggingface.co/spaces/mrtlive/segment-anything-model/resolve/main/sam_vit_h_4b8939.pth",
    "sam_vit_h_4b8939.pth",
)
import cv2
import numpy as np
import torch
import torchvision
import gradio as gr
from PIL import Image
from GroundingDINO.groundingdino.util.inference import load_model
from segment_anything import build_sam, SamPredictor
import spaces
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
# Run on CUDA when a GPU is visible, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# GroundingDINO (turns the text prompt into candidate boxes):
# Swin-B config file and its matching checkpoint.
GROUNDING_DINO_CONFIG_PATH = "./GroundingDINO_SwinB.cfg.py"
GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swinb_cogcoor.pth"

# Segment Anything: encoder variant label and checkpoint.
# NOTE(review): SAM_ENCODER_VERSION is not read anywhere in this file;
# build_sam() below is called without it, so the architecture is whatever
# build_sam constructs — presumably ViT-H to match the checkpoint; confirm.
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"

# Both models are built once at import time and reused by every request.
groundingdino_model = load_model(
    model_config_path=GROUNDING_DINO_CONFIG_PATH,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
    device=DEVICE,
)

sam = build_sam(checkpoint=SAM_CHECKPOINT_PATH)
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)
def transform_image(image_pil):
    """Turn a PIL image into the normalized tensor GroundingDINO is fed with.

    Steps: ``RandomResize([800], max_size=1333)`` (presumably shorter side to
    800 px with the longer side capped at 1333 px — confirm in
    ``groundingdino/datasets/transforms``), ``ToTensor``, then normalization
    with the standard ImageNet mean/std. The transforms take an
    ``(image, target)`` pair; no target is used here.
    """
    steps = [
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    pipeline = T.Compose(steps)
    tensor, _unused_target = pipeline(image_pil, None)
    return tensor  # 3, h, w
def get_grounding_output(model, image, caption, box_threshold=0.25, text_threshold=0.25, with_logits=True):
    """Run GroundingDINO on a single image and keep the boxes matching *caption*.

    Args:
        model: GroundingDINO model. It must expose a ``tokenizer`` attribute and,
            when called as ``model(images, captions=[...])``, return a dict with
            ``"pred_logits"`` and ``"pred_boxes"``.
        image: tensor for ONE image without a batch dimension (one is added
            here), e.g. the output of ``transform_image``.
        caption: text prompt. It is lower-cased, stripped and terminated with
            "." before being sent to the model.
        box_threshold: a query is kept when its highest token score is greater
            than this value.
        text_threshold: token-score cutoff used to pick the phrase tokens of
            each kept query.
        with_logits: if True, append the query's max score (truncated to 4
            characters) to its phrase, e.g. ``"cat(0.87)"``.

    Returns:
        ``(boxes, scores, phrases)``:
        - the kept boxes (CPU tensor, one row of 4 values per query, in the
          model's box format; the caller treats them as normalized cx, cy, w, h),
        - a 1-D float tensor with each kept query's max token score,
        - the list of phrase strings, in the same order.
    """
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption += "."

    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    # Scores and boxes for the single image in the batch.
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # Keep queries whose best token score clears box_threshold. Boolean-mask
    # indexing already returns new tensors, so no defensive clone() is needed.
    keep = logits.max(dim=1)[0] > box_threshold
    logits_filt = logits[keep]  # (num_filt, 256)
    boxes_filt = boxes[keep]  # (num_filt, 4)

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)

    pred_phrases = []
    scores = []
    for logit in logits_filt:
        phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        score = logit.max().item()
        if with_logits:
            # str(...)[:4] truncates (does not round) the score, e.g. "0.87".
            phrase = phrase + f"({str(score)[:4]})"
        pred_phrases.append(phrase)
        scores.append(score)
    return boxes_filt, torch.Tensor(scores), pred_phrases
@spaces.GPU
def run_local(image, label):
    """Segment the object named by *label* in *image* and return its mask.

    Pipeline: GroundingDINO proposes boxes for the text prompt. Overlapping
    boxes are pruned with NMS (IoU threshold 0.8). SAM then turns the
    surviving boxes into masks. Only the mask of the top-ranked box is
    returned.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing has been uploaded.
        label: text prompt naming the object to segment.

    Returns:
        A one-element list holding the mask as a PIL image, for the
        ``gr.Gallery`` output.

    Raises:
        gr.Error: when the image or the prompt is missing, or when nothing in
            the image matched the prompt (shown to the user by Gradio instead
            of a generic failure).
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if not label or not label.strip():
        raise gr.Error("Please enter a label describing what to segment.")

    image_pil = image.convert("RGB")
    boxes, scores, _phrases = get_grounding_output(
        groundingdino_model, transform_image(image_pil), label
    )
    # With no detections SAM gets an empty box batch and masks[0] fails.
    if boxes.size(0) == 0:
        raise gr.Error(f"Nothing matching {label.strip()!r} was found in the image.")

    # The detector's boxes are normalized (cx, cy, w, h). Scale them to pixels
    # and convert to corner form (x0, y0, x1, y1) for NMS and SAM.
    W, H = image_pil.size
    boxes_px = boxes * torch.tensor([W, H, W, H], dtype=boxes.dtype)
    centers, sizes = boxes_px[:, :2], boxes_px[:, 2:]
    boxes_xyxy = torch.cat([centers - sizes / 2, centers + sizes / 2], dim=1)

    # nms returns kept indices sorted by decreasing score, so row 0 below is
    # the most confident detection that survived suppression.
    keep = torchvision.ops.nms(boxes_xyxy, scores, 0.8)
    boxes_xyxy = boxes_xyxy[keep]

    image_np = np.array(image_pil)
    sam_predictor.set_image(image_np)
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(
        boxes_xyxy, image_np.shape[:2]
    ).to(DEVICE)
    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    # Mask of the top-ranked box only.
    result_mask = Image.fromarray(masks[0][0].cpu().numpy())
    return [result_mask]
# Gradio UI: image + text prompt on the left, resulting mask gallery on the
# right, and a Run button that calls run_local.
with gr.Blocks() as demo:
    gr.Markdown("# Segment")

    with gr.Row():
        # Inputs.
        with gr.Column():
            input_image = gr.Image(sources='upload', type="pil", height=512)
            text_prompt = gr.Textbox(label="Label")
        # Output.
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated images",
                show_label=False,
                elem_id="gallery",
                height=512,
            )

    run_local_button = gr.Button(value="Run")
    # The event listener has to be registered inside the Blocks context.
    run_local_button.click(
        run_local,
        inputs=[input_image, text_prompt],
        outputs=[gallery],
    )

demo.launch()