import os
import sys

# Install the local packages and download the model checkpoints at startup.
os.system("python -m pip install -e segment_anything")
os.system("python -m pip install -e GroundingDINO")
sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
os.system("wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth")
os.system("wget https://huggingface.co/spaces/mrtlive/segment-anything-model/resolve/main/sam_vit_h_4b8939.pth")

import cv2
import numpy as np
import torch
import torchvision
import gradio as gr
from PIL import Image
from GroundingDINO.groundingdino.util.inference import load_model
from segment_anything import build_sam, SamPredictor
import spaces
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# GroundingDINO config and checkpoint
GROUNDING_DINO_CONFIG_PATH = "./GroundingDINO_SwinB.cfg.py"
GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swinb_cogcoor.pth"

# Segment-Anything checkpoint
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"

# Build the GroundingDINO inference model
groundingdino_model = load_model(
    model_config_path=GROUNDING_DINO_CONFIG_PATH,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
    device=DEVICE,
)

# Build the SAM model and predictor
sam = build_sam(checkpoint=SAM_CHECKPOINT_PATH)
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)


def transform_image(image_pil):
    """Apply the standard GroundingDINO preprocessing to a PIL image."""
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image, _ = transform(image_pil, None)  # (3, h, w)
    return image


def get_grounding_output(model, image, caption, box_threshold=0.25, text_threshold=0.25, with_logits=True):
    """Run GroundingDINO on a preprocessed image and return boxes, scores, and phrases."""
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption = caption + "."
    # GroundingDINO expects the image on the same device as the model.
    image = image.to(DEVICE)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # Keep only queries whose best token logit exceeds the box threshold.
    logits_filt = logits.clone()
    boxes_filt = boxes.clone()
    filt_mask = logits_filt.max(dim=1)[0] > box_threshold
    logits_filt = logits_filt[filt_mask]  # (num_filt, 256)
    boxes_filt = boxes_filt[filt_mask]  # (num_filt, 4)

    # Recover the phrase each remaining query grounds to.
    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    pred_phrases = []
    scores = []
    for logit, box in zip(logits_filt, boxes_filt):
        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
        else:
            pred_phrases.append(pred_phrase)
        scores.append(logit.max().item())

    return boxes_filt, torch.Tensor(scores), pred_phrases


@spaces.GPU
def run_local(image, label):
    global groundingdino_model, sam_predictor

    image_pil = image.convert("RGB")
    transformed_image = transform_image(image_pil)

    boxes_filt, scores, pred_phrases = get_grounding_output(
        groundingdino_model, transformed_image, label
    )

    # Convert boxes from normalized (cx, cy, w, h) to absolute (x0, y0, x1, y1).
    size = image_pil.size
    H, W = size[1], size[0]
    for i in range(boxes_filt.size(0)):
        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
        boxes_filt[i][2:] += boxes_filt[i][:2]
    boxes_filt = boxes_filt.cpu()

    # Non-maximum suppression to drop overlapping detections.
    nms_idx = torchvision.ops.nms(boxes_filt, scores, 0.8).numpy().tolist()
    boxes_filt = boxes_filt[nms_idx]
    pred_phrases = [pred_phrases[idx] for idx in nms_idx]

    # Prompt SAM with the surviving boxes.
    image = np.array(image_pil)
    sam_predictor.set_image(image)
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(DEVICE)
    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    # Return only the mask of the first detected box.
    result_mask = masks[0][0].cpu().numpy()
    result_mask = Image.fromarray(result_mask)
    return [result_mask]


with gr.Blocks() as demo:
    gr.Markdown("# Segment")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources='upload', type="pil", height=512)
            text_prompt = gr.Textbox(label="Label")
        with gr.Column():
            gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", height=512)
    run_local_button = gr.Button(value="Run")
    run_local_button.click(fn=run_local, inputs=[input_image, text_prompt], outputs=[gallery])

demo.launch()