import gradio as gr
import spaces
import argparse
from PIL import Image
import numpy as np
import warnings
import torch

warnings.filterwarnings("ignore")

# Use the Transformers implementation of Grounding DINO
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
# supervision is used to draw boxes and labels on the output image
import supervision as sv

# Hugging Face model ID
model_id = "IDEA-Research/grounding-dino-base"

# Load model and processor using Transformers
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)


@spaces.GPU
def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
    # Convert a numpy array to a PIL Image if needed (Gradio supplies RGB arrays,
    # so no channel-order conversion is required)
    if isinstance(input_image, np.ndarray):
        input_image = Image.fromarray(input_image)
    init_image = input_image.convert("RGB")

    # Preprocess the image and tokenize the text prompt
    inputs = processor(images=init_image, text=grounding_caption, return_tensors="pt").to(device)

    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process the raw outputs into boxes, scores, and phrase labels
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        target_sizes=[init_image.size[::-1]],
    )
    result = results[0]

    # Convert the image for supervision visualization
    image_np = np.array(init_image)

    # Collect detections for supervision
    boxes = []
    labels = []
    confidences = []
    class_ids = []
    for i, (box, score, label) in enumerate(
        zip(result["boxes"], result["scores"], result["labels"])
    ):
        boxes.append(box.tolist())  # boxes are already in xyxy format
        labels.append(label)
        confidences.append(float(score))
        class_ids.append(i)  # use the index as an integer class_id

    if boxes:
        # Create a Detections object for supervision
        detections = sv.Detections(
            xyxy=np.array(boxes),
            confidence=np.array(confidences),
            class_id=np.array(class_ids, dtype=np.int32),  # ensure an integer array
        )

        text_scale = sv.calculate_optimal_text_scale(resolution_wh=init_image.size)
        line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=init_image.size)

        # Create annotators
        box_annotator = sv.BoxAnnotator(
            thickness=2,
            color=sv.ColorPalette.DEFAULT,
        )
        label_annotator = sv.LabelAnnotator(
            color=sv.ColorPalette.DEFAULT,
            text_color=sv.Color.WHITE,
            text_scale=text_scale,
            text_thickness=line_thickness,
            text_padding=3,
        )

        # Format a "phrase: confidence" label for each detection
        formatted_labels = [
            f"{label}: {conf:.2f}" for label, conf in zip(labels, confidences)
        ]

        # Draw boxes and labels on the image
        annotated_image = box_annotator.annotate(scene=image_np, detections=detections)
        annotated_image = label_annotator.annotate(
            scene=annotated_image, detections=detections, labels=formatted_labels
        )
    else:
        annotated_image = image_np

    # Convert back to a PIL Image
    image_with_box = Image.fromarray(annotated_image)

    return image_with_box


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="run in debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    args = parser.parse_args()

    css = """
    #mkd {
        height: 500px;
        overflow: auto;
        border: 1px solid #ccc;
    }
    """

    with gr.Blocks(css=css) as demo:
        gr.Markdown("# Grounding DINO Base")
        gr.Markdown("### Open-World Detection with Grounding DINO")
") with gr.Row(): with gr.Column(): input_image = gr.Image(label="Input Image", type="pil") grounding_caption = gr.Textbox(label="Detection Prompt(VERY important: text queries need to be lowercased + end with a dot, example: a cat. a remote control.)", value="a person. a car.") run_button = gr.Button("Run") with gr.Accordion("Advanced options", open=False): box_threshold = gr.Slider( minimum=0.0, maximum=1.0, value=0.3, step=0.001, label="Box Threshold" ) text_threshold = gr.Slider( minimum=0.0, maximum=1.0, value=0.25, step=0.001, label="Text Threshold" ) with gr.Column(): gallery = gr.Image( label="Detection Result", type="pil" ) run_button.click( fn=run_grounding, inputs=[input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery] ) gr.Examples( examples=[ ["000000039769.jpg", "a cat. a remote control.", 0.3, 0.25], ["KakaoTalk_20250430_163200504.jpg", "cup. screen. hand.", 0.3, 0.25] ], inputs=[input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery], fn=run_grounding, cache_examples=True, ) demo.launch(share=args.share, debug=args.debug, show_error=True)