developer0hye committed
Commit 014228c · verified · 1 Parent(s): 8ee296a

Update app.py

Files changed (1)
  1. app.py +97 -63
app.py CHANGED
@@ -1,65 +1,27 @@
  import gradio as gr
 
  import argparse
- from functools import partial
  import cv2
- import requests
- import os
- from io import BytesIO
  from PIL import Image
  import numpy as np
- from pathlib import Path
 
 
  import warnings
-
  import torch
  warnings.filterwarnings("ignore")
 
- from groundingdino.models import build_model
- from groundingdino.util.slconfig import SLConfig
- from groundingdino.util.utils import clean_state_dict
- from groundingdino.util.inference import annotate, load_image, predict
- import groundingdino.datasets.transforms as T
-
- from huggingface_hub import hf_hub_download
-
-
- # Use this command for evaluate the GLIP-T model
- config_file = "groundingdino/config/GroundingDINO_SwinB_cfg.py"
- ckpt_repo_id = "ShilongLiu/GroundingDINO"
- ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
-
-
- def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
-     args = SLConfig.fromfile(model_config_path)
-     model = build_model(args)
-     args.device = device
-
-     cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
-     checkpoint = torch.load(cache_file, map_location=device)
-     log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
-     print("Model loaded from {} \n => {}".format(cache_file, log))
-     _ = model.eval()
-     return model
-
- def image_transform_grounding(init_image):
-     transform = T.Compose([
-         T.RandomResize([800], max_size=1333),
-         T.ToTensor(),
-         T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-     ])
-     image, _ = transform(init_image, None) # 3, h, w
-     return init_image, image
+ # Replace custom imports with Transformers
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
+ # Add supervision for better visualization
+ import supervision as sv
 
- def image_transform_grounding_for_vis(init_image):
-     transform = T.Compose([
-         T.RandomResize([800], max_size=1333),
-     ])
-     image, _ = transform(init_image, None) # 3, h, w
-     return image
+ # Model ID for Hugging Face
+ model_id = "IDEA-Research/grounding-dino-base"
 
- model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
+ # Load model and processor using Transformers
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ processor = AutoProcessor.from_pretrained(model_id)
+ model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
 
  def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
      # Convert numpy array to PIL Image if needed
@@ -69,16 +31,86 @@ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold)
          input_image = Image.fromarray(input_image)
 
      init_image = input_image.convert("RGB")
-     original_size = init_image.size
-
-     _, image_tensor = image_transform_grounding(init_image)
-     image_pil: Image = image_transform_grounding_for_vis(init_image)
-
-     # run grounidng
-     boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
-     annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
-     image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
-
+
+     # Process input using transformers
+     inputs = processor(images=init_image, text=grounding_caption, return_tensors="pt").to(device)
+
+     # Run inference
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # Post-process results
+     results = processor.post_process_grounded_object_detection(
+         outputs,
+         inputs.input_ids,
+         box_threshold=box_threshold,
+         text_threshold=text_threshold,
+         target_sizes=[init_image.size[::-1]]
+     )
+
+     result = results[0]
+
+     # Convert image for supervision visualization
+     image_np = np.array(init_image)
+
+     # Create detections for supervision
+     boxes = []
+     labels = []
+     confidences = []
+     class_ids = []
+
+     for i, (box, score, label) in enumerate(zip(result["boxes"], result["scores"], result["labels"])):
+         # Convert box to xyxy format
+         xyxy = box.tolist()
+         boxes.append(xyxy)
+         labels.append(label)
+         confidences.append(float(score))
+         class_ids.append(i)  # Use index as class_id (integer)
+
+     # Create Detections object for supervision
+     if boxes:
+         detections = sv.Detections(
+             xyxy=np.array(boxes),
+             confidence=np.array(confidences),
+             class_id=np.array(class_ids, dtype=np.int32),  # Ensure it's an integer array
+         )
+
+         text_scale = sv.calculate_optimal_text_scale(resolution_wh=init_image.size)
+         line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=init_image.size)
+
+         # Create annotators
+         box_annotator = sv.BoxAnnotator(
+             thickness=2,
+             color=sv.ColorPalette.DEFAULT,
+         )
+
+         label_annotator = sv.LabelAnnotator(
+             color=sv.ColorPalette.DEFAULT,
+             text_color=sv.Color.WHITE,
+             text_scale=text_scale,
+             text_thickness=line_thickness,
+             text_padding=3
+         )
+
+         # Create formatted labels for each detection
+         formatted_labels = [
+             f"{label}: {conf:.2f}"
+             for label, conf in zip(labels, confidences)
+         ]
+
+         # Apply annotations to the image
+         annotated_image = box_annotator.annotate(scene=image_np, detections=detections)
+         annotated_image = label_annotator.annotate(
+             scene=annotated_image,
+             detections=detections,
+             labels=formatted_labels
+         )
+     else:
+         annotated_image = image_np
+
+     # Convert back to PIL Image
+     image_with_box = Image.fromarray(annotated_image)
+
      return image_with_box
 
  if __name__ == "__main__":
@@ -98,17 +130,16 @@ if __name__ == "__main__":
      with gr.Blocks(css=css) as demo:
          gr.Markdown("<h1><center>Grounding DINO<h1><center>")
          gr.Markdown("<h3><center>Open-World Detection with <a href='https://github.com/IDEA-Research/GroundingDINO'>Grounding DINO</a><h3><center>")
-         gr.Markdown("<h3><center>Running on CPU, so it may take a while to run the model.<h3><center>")
 
          with gr.Row():
              with gr.Column():
                  input_image = gr.Image(label="Input Image", type="pil")
-                 grounding_caption = gr.Textbox(label="Detection Prompt")
+                 grounding_caption = gr.Textbox(label="Detection Prompt(VERY important: text queries need to be lowercased + end with a dot, example: a cat. a remote control.)", value="a person. a car.")
                  run_button = gr.Button("Run")
 
                  with gr.Accordion("Advanced options", open=False):
                      box_threshold = gr.Slider(
-                         minimum=0.0, maximum=1.0, value=0.25, step=0.001,
+                         minimum=0.0, maximum=1.0, value=0.3, step=0.001,
                          label="Box Threshold"
                      )
                      text_threshold = gr.Slider(
@@ -129,11 +160,14 @@ if __name__ == "__main__":
         )
 
         gr.Examples(
-             examples=[["this_is_fine.png", "coffee cup", 0.25, 0.25]],
+             examples=[
+                 ["000000039769.jpg", "a cat. a remote control.", 0.3, 0.25],
+                 ["KakaoTalk_20250430_163200504.jpg", "cup. screen. hand.", 0.3, 0.25]
+             ],
             inputs=[input_image, grounding_caption, box_threshold, text_threshold],
             outputs=[gallery],
             fn=run_grounding,
             cache_examples=True,
         )
 
-     demo.launch(share=args.share, debug=args.debug, show_error=True)
+     demo.launch(share=args.share, debug=args.debug, show_error=True)
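
For anyone who wants to try the new inference path outside the Gradio UI, below is a minimal standalone sketch of the Transformers API the updated app.py relies on. The checkpoint id, thresholds, example file name, and text query are copied from or modeled on the diff above; treat the exact values as illustrative assumptions rather than requirements.

# Minimal sketch (not part of the commit): zero-shot detection with the
# Transformers Grounding DINO API used in the updated app.py.
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

model_id = "IDEA-Research/grounding-dino-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

image = Image.open("000000039769.jpg").convert("RGB")  # example image from the diff
text = "a cat. a remote control."  # lowercased queries, each ending with a dot

inputs = processor(images=image, text=text, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

results = processor.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    box_threshold=0.3,    # as used in the commit; newer transformers releases may name this `threshold`
    text_threshold=0.25,
    target_sizes=[image.size[::-1]],
)
for box, score, label in zip(results[0]["boxes"], results[0]["scores"], results[0]["labels"]):
    print(label, round(float(score), 3), [round(v, 1) for v in box.tolist()])

The query format matters: Grounding DINO expects lowercased phrases, each terminated with a dot, which is exactly what the new prompt label warns about.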
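
The drawing half of the new run_grounding uses the supervision package instead of GroundingDINO's own annotate helper. Here is a small self-contained sketch of that pattern; the boxes, scores, and labels are invented placeholder values, and the annotator settings loosely mirror the ones chosen in the commit.

# Minimal sketch (not part of the commit): drawing detections with supervision,
# using placeholder boxes/scores/labels instead of real model output.
import numpy as np
import supervision as sv
from PIL import Image

image_np = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder RGB image
boxes = np.array([[50.0, 60.0, 200.0, 220.0], [300.0, 100.0, 420.0, 260.0]])
scores = np.array([0.82, 0.47])
labels = ["a cat", "a remote control"]

detections = sv.Detections(
    xyxy=boxes,
    confidence=scores,
    class_id=np.arange(len(boxes), dtype=np.int32),  # index as class_id, as in the commit
)

box_annotator = sv.BoxAnnotator(thickness=2, color=sv.ColorPalette.DEFAULT)
label_annotator = sv.LabelAnnotator(
    color=sv.ColorPalette.DEFAULT,
    text_color=sv.Color.WHITE,
    text_padding=3,
)

formatted = [f"{label}: {conf:.2f}" for label, conf in zip(labels, scores)]
annotated = box_annotator.annotate(scene=image_np, detections=detections)
annotated = label_annotator.annotate(scene=annotated, detections=detections, labels=formatted)
Image.fromarray(annotated).save("annotated.png")

Using the detection index as class_id, as the commit does, simply gives every box its own palette color; it is not a semantic class mapping.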