ariG23498 (HF Staff) committed
Commit dff4c96 · Parent(s): afdee69

adding video logic
Files changed (6):
  1. .gitattributes +1 -0
  2. .gitignore +3 -0
  3. README.md +1 -1
  4. app.py +423 -126
  5. requirements.txt +4 -1
  6. video.mp4 +3 -0
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 *.jpg filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+.ruff_cache
+.venv
+static
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: D-Fine Object Detection
+title: Real Time Object Detection with D-Fine
 emoji: 🌖
 colorFrom: red
 colorTo: indigo
app.py CHANGED
@@ -1,45 +1,103 @@
+import logging
+import os
+from typing import Tuple, List, Optional
+from pathlib import Path
+import shutil
+import tempfile
+import numpy as np
+import cv2
 import gradio as gr
+from PIL import Image
 from transformers import pipeline
 from transformers.image_utils import load_image
+import tqdm

-checkpoints = [
-    'ustc-community/dfine_n_coco',
-    'ustc-community/dfine_s_coco',
-    'ustc-community/dfine_m_coco',
-    'ustc-community/dfine_l_coco',
-    'ustc-community/dfine_x_coco',
-    'ustc-community/dfine_s_obj365',
-    'ustc-community/dfine_m_obj365',
-    'ustc-community/dfine_l_obj365',
-    'ustc-community/dfine_x_obj365',
-    'ustc-community/dfine_s_obj2coco',
-    'ustc-community/dfine_m_obj2coco',
-    'ustc-community/dfine_l_obj2coco_e25',
-    'ustc-community/dfine_x_obj2coco',
+# Configuration constants
+CHECKPOINTS = [
+    "ustc-community/dfine_m_obj365",
+    "ustc-community/dfine_n_coco",
+    "ustc-community/dfine_s_coco",
+    "ustc-community/dfine_m_coco",
+    "ustc-community/dfine_l_coco",
+    "ustc-community/dfine_x_coco",
+    "ustc-community/dfine_s_obj365",
+    "ustc-community/dfine_l_obj365",
+    "ustc-community/dfine_x_obj365",
+    "ustc-community/dfine_s_obj2coco",
+    "ustc-community/dfine_m_obj2coco",
+    "ustc-community/dfine_l_obj2coco_e25",
+    "ustc-community/dfine_x_obj2coco",
 ]
+MAX_NUM_FRAMES = 300
+DEFAULT_CHECKPOINT = CHECKPOINTS[0]
+DEFAULT_CONFIDENCE_THRESHOLD = 0.3
+IMAGE_EXAMPLES = [
+    {"path": "./image.jpg", "use_url": False, "url": "", "label": "Local Image"},
+    {
+        "path": None,
+        "use_url": True,
+        "url": "https://live.staticflickr.com/65535/33021460783_1646d43c54_b.jpg",
+        "label": "Flickr Image",
+    },
+]
+VIDEO_EXAMPLES = [
+    {"path": "./video.mp4", "label": "Local Video"},
+]
+ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov"}
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+VIDEO_OUTPUT_DIR = Path("static/videos")
+VIDEO_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

-def detect_objects(image, checkpoint, confidence_threshold=0.3, use_url=False, url=""):
-    pipe = pipeline(
-        "object-detection",
-        model=checkpoint,
-        image_processor=checkpoint,
-        device="cpu",
-    )

+def detect_objects(
+    image: Optional[Image.Image],
+    checkpoint: str,
+    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
+    use_url: bool = False,
+    url: str = "",
+) -> Tuple[
+    Optional[Tuple[Image.Image, List[Tuple[Tuple[int, int, int, int], str]]]],
+    gr.Markdown,
+]:
     if use_url and url:
-        input_image = load_image(url)
+        try:
+            input_image = load_image(url)
+        except Exception as e:
+            logger.error(f"Failed to load image from URL {url}: {str(e)}")
+            return None, gr.Markdown(
+                f"**Error**: Failed to load image from URL: {str(e)}", visible=True
+            )
     elif image is not None:
+        if not isinstance(image, Image.Image):
+            logger.error("Input image is not a PIL Image")
+            return None, gr.Markdown("**Error**: Invalid image format.", visible=True)
         input_image = image
     else:
-        return None, gr.Markdown("**Error**: Please provide an image or URL.", visible=True)
+        return None, gr.Markdown(
+            "**Error**: Please provide an image or URL.", visible=True
+        )

-    # Run detection
-    results = pipe(input_image, threshold=confidence_threshold)
+    try:
+        pipe = pipeline(
+            "object-detection",
+            model=checkpoint,
+            image_processor=checkpoint,
+            device="cpu",
+        )
+    except Exception as e:
+        logger.error(f"Failed to initialize model pipeline for {checkpoint}: {str(e)}")
+        return None, gr.Markdown(
+            f"**Error**: Failed to load model: {str(e)}", visible=True
+        )

-    # Get image dimensions for validation
+    results = pipe(input_image, threshold=confidence_threshold)
     img_width, img_height = input_image.size

-    # Prepare annotations in the format: list of (bounding_box, label)
     annotations = []
     for result in results:
         score = result["score"]
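The rewritten `detect_objects` above still centers on the same `transformers` object-detection pipeline; for readers who want to exercise it outside Gradio, here is a minimal standalone sketch (the checkpoint, URL, and threshold are illustrative values taken from the app's defaults):

```python
from transformers import pipeline
from transformers.image_utils import load_image

# Build the detector once; DEFAULT_CHECKPOINT in the app is "ustc-community/dfine_m_obj365"
pipe = pipeline("object-detection", model="ustc-community/dfine_m_obj365", device="cpu")

image = load_image("https://live.staticflickr.com/65535/33021460783_1646d43c54_b.jpg")
# Each result is a dict: {"score": float, "label": str, "box": {"xmin", "ymin", "xmax", "ymax"}}
for result in pipe(image, threshold=0.3):
    print(f"{result['label']}: {result['score']:.2f} at {result['box']}")
```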
@@ -47,107 +105,315 @@ def detect_objects(image, checkpoint, confidence_threshold=0.3, use_url=False, url=""):
             continue
         label = f"{result['label']} ({score:.2f})"
         box = result["box"]
-        # Validate and convert box to (x1, y1, x2, y2)
-        x1 = max(0, int(box["xmin"]))
-        y1 = max(0, int(box["ymin"]))
-        x2 = min(img_width, int(box["xmax"]))
-        y2 = min(img_height, int(box["ymax"]))
-        # Ensure valid box
-        if x2 <= x1 or y2 <= y1:
+        # Validate and convert box to (xmin, ymin, xmax, ymax)
+        bbox_xmin = max(0, int(box["xmin"]))
+        bbox_ymin = max(0, int(box["ymin"]))
+        bbox_xmax = min(img_width, int(box["xmax"]))
+        bbox_ymax = min(img_height, int(box["ymax"]))
+        if bbox_xmax <= bbox_xmin or bbox_ymax <= bbox_ymin:
             continue
-        bounding_box = (x1, y1, x2, y2)
+        bounding_box = (bbox_xmin, bbox_ymin, bbox_xmax, bbox_ymax)
         annotations.append((bounding_box, label))

-    # Handle empty annotations
     if not annotations:
         return (input_image, []), gr.Markdown(
             "**Warning**: No objects detected above the confidence threshold. Try lowering the threshold.",
-            visible=True
+            visible=True,
         )

-    # Return base image and annotations
     return (input_image, annotations), gr.Markdown(visible=False)

+
+def annotate_frame(
+    image: Image.Image, annotations: List[Tuple[Tuple[int, int, int, int], str]]
+) -> np.ndarray:
+    image_np = np.array(image)
+    image_bgr = image_np[:, :, ::-1].copy()  # RGB to BGR
+
+    for (xmin, ymin, xmax, ymax), label in annotations:
+        cv2.rectangle(image_bgr, (xmin, ymin), (xmax, ymax), (255, 255, 255), 2)
+        text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
+        cv2.rectangle(
+            image_bgr,
+            (xmin, ymin - text_size[1] - 4),
+            (xmin + text_size[0], ymin),
+            (255, 255, 255),
+            -1,
+        )
+        cv2.putText(
+            image_bgr,
+            label,
+            (xmin, ymin - 4),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.5,
+            (0, 0, 0),
+            1,
+        )
+
+    return image_bgr
+
+
+def process_video(
+    video_path: str,
+    checkpoint: str,
+    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
+    progress: gr.Progress = gr.Progress(track_tqdm=True),
+) -> Tuple[Optional[str], gr.Markdown]:
+    if not video_path or not os.path.isfile(video_path):
+        logger.error(f"Invalid video path: {video_path}")
+        return None, gr.Markdown(
+            "**Error**: Please provide a valid video file.", visible=True
+        )
+
+    ext = os.path.splitext(video_path)[1].lower()
+    if ext not in ALLOWED_VIDEO_EXTENSIONS:
+        logger.error(f"Unsupported video format: {ext}")
+        return None, gr.Markdown(
+            f"**Error**: Unsupported video format. Use MP4, AVI, or MOV.", visible=True
+        )
+
+    try:
+        cap = cv2.VideoCapture(video_path)
+        if not cap.isOpened():
+            logger.error(f"Failed to open video: {video_path}")
+            return None, gr.Markdown(
+                "**Error**: Failed to open video file.", visible=True
+            )
+
+        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+        # Use H.264 codec for browser compatibility
+        fourcc = cv2.VideoWriter_fourcc(*"H264")
+        temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
+        writer = cv2.VideoWriter(temp_file.name, fourcc, fps, (width, height))
+        if not writer.isOpened():
+            logger.error("Failed to initialize video writer")
+            cap.release()
+            temp_file.close()
+            os.unlink(temp_file.name)
+            return None, gr.Markdown(
+                "**Error**: Failed to initialize video writer.", visible=True
+            )
+
+        frame_count = 0
+        for _ in tqdm.tqdm(
+            range(min(MAX_NUM_FRAMES, num_frames)), desc="Processing video"
+        ):
+            ok, frame = cap.read()
+            if not ok:
+                break
+            rgb_frame = frame[:, :, ::-1]  # BGR to RGB
+            pil_image = Image.fromarray(rgb_frame)
+            (annotated_image, annotations), _ = detect_objects(
+                pil_image, checkpoint, confidence_threshold, use_url=False, url=""
+            )
+            if annotated_image is None:
+                continue
+            annotated_frame = annotate_frame(annotated_image, annotations)
+            writer.write(annotated_frame)
+            frame_count += 1
+
+        writer.release()
+        cap.release()
+
+        if frame_count == 0:
+            logger.warning("No valid frames processed in video")
+            temp_file.close()
+            os.unlink(temp_file.name)
+            return None, gr.Markdown(
+                "**Warning**: No valid frames processed. Try a different video or threshold.",
+                visible=True,
+            )
+
+        temp_file.close()
+
+        # Copy to persistent directory for Gradio access
+        output_filename = f"output_{os.path.basename(temp_file.name)}"
+        output_path = VIDEO_OUTPUT_DIR / output_filename
+        shutil.copy(temp_file.name, output_path)
+        os.unlink(temp_file.name)  # Remove temporary file
+        logger.info(f"Video saved to {output_path}")
+
+        return str(output_path), gr.Markdown(visible=False)
+
+    except Exception as e:
+        logger.error(f"Video processing failed: {str(e)}")
+        if "temp_file" in locals():
+            temp_file.close()
+            if os.path.exists(temp_file.name):
+                os.unlink(temp_file.name)
+        return None, gr.Markdown(
+            f"**Error**: Video processing failed: {str(e)}", visible=True
+        )
+
+
+def create_image_inputs() -> List[gr.components.Component]:
+    return [
+        gr.Image(
+            label="Upload Image",
+            type="pil",
+            sources=["upload", "webcam"],
+            interactive=True,
+            elem_classes="input-component",
+        ),
+        gr.Checkbox(label="Use Image URL Instead", value=False),
+        gr.Textbox(
+            label="Image URL",
+            placeholder="https://example.com/image.jpg",
+            visible=False,
+            elem_classes="input-component",
+        ),
+        gr.Dropdown(
+            choices=CHECKPOINTS,
+            label="Select Model Checkpoint",
+            value=DEFAULT_CHECKPOINT,
+            elem_classes="input-component",
+        ),
+        gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=DEFAULT_CONFIDENCE_THRESHOLD,
+            step=0.1,
+            label="Confidence Threshold",
+            elem_classes="input-component",
+        ),
+    ]
+
+
+def create_video_inputs() -> List[gr.components.Component]:
+    return [
+        gr.Video(
+            label="Upload Video",
+            sources=["upload"],
+            interactive=True,
+            format="mp4",  # Ensure MP4 format
+            elem_classes="input-component",
+        ),
+        gr.Dropdown(
+            choices=CHECKPOINTS,
+            label="Select Model Checkpoint",
+            value=DEFAULT_CHECKPOINT,
+            elem_classes="input-component",
+        ),
+        gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=DEFAULT_CONFIDENCE_THRESHOLD,
+            step=0.1,
+            label="Confidence Threshold",
+            elem_classes="input-component",
+        ),
+    ]
+
+
+def create_button_row(is_image: bool) -> List[gr.Button]:
+    prefix = "Image" if is_image else "Video"
+    return [
+        gr.Button(
+            f"{prefix} Detect Objects", variant="primary", elem_classes="action-button"
+        ),
+        gr.Button(f"{prefix} Clear", variant="secondary", elem_classes="action-button"),
+    ]
+
+
 # Gradio interface
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown(
         """
         # Real-Time Object Detection Demo
-        Experience state-of-the-art object detection with USTC's Dfine models. Upload an image, provide a URL, or try an example below. Select a model and adjust the confidence threshold to see detections in real time!
-
-        **Instructions**:
-        - Upload an image or enter a URL.
-        - Choose a model checkpoint from the dropdown.
-        - Adjust the confidence threshold (0.1 to 1.0).
-        - Click "Detect Objects" to view results, or select an example.
-        - Use "Clear" to reset inputs and outputs.
+        Experience state-of-the-art object detection with USTC's Dfine models. Upload an image or video,
+        provide a URL, or try an example below. Select a model and adjust the confidence threshold to see detections in real time!
         """,
-        elem_classes="header-text"
+        elem_classes="header-text",
     )
-
-    with gr.Row():
-        with gr.Column(scale=1, min_width=300):
-            with gr.Group():
-                image_input = gr.Image(
-                    label="Upload Image",
-                    type="pil",
-                    sources=["upload", "webcam"],
-                    interactive=True,
-                    elem_classes="input-component",
-                )
-                use_url = gr.Checkbox(label="Use Image URL Instead", value=False)
-                url_input = gr.Textbox(
-                    label="Image URL",
-                    placeholder="https://example.com/image.jpg",
-                    visible=False,
-                    elem_classes="input-component",
-                )
-                checkpoint = gr.Dropdown(
-                    choices=checkpoints,
-                    label="Select Model Checkpoint",
-                    value=checkpoints[0],
-                    elem_classes="input-component",
-                )
-                confidence_threshold = gr.Slider(
-                    minimum=0.1,
-                    maximum=1.0,
-                    value=0.3,
-                    step=0.1,
-                    label="Confidence Threshold",
-                    elem_classes="input-component",
-                )
-                with gr.Row():
-                    detect_button = gr.Button(
-                        "Detect Objects",
-                        variant="primary",
-                        elem_classes="action-button",
-                    )
-                    clear_button = gr.Button(
-                        "Clear",
-                        variant="secondary",
-                        elem_classes="action-button",
-                    )
-
-        with gr.Column(scale=2):
-            output_annotated = gr.AnnotatedImage(
-                label="Detection Results",
-                show_label=True,
-                color_map=None,  # Let Gradio assign colors
-                elem_classes="output-component",
-            )
-            error_message = gr.Markdown(visible=False, elem_classes="error-text")
-
-    gr.Examples(
-        examples=[
-            ["./image.jpg", False, "", checkpoints[0], 0.3],
-            [None, True, "https://live.staticflickr.com/65535/33021460783_1646d43c54_b.jpg", checkpoints[0], 0.3],
-        ],
-        inputs=[image_input, use_url, url_input, checkpoint, confidence_threshold],
-        outputs=[output_annotated, error_message],
-        fn=detect_objects,
-        cache_examples=False,  # Avoid caching due to model size
-        label="Select an example to run the model",
-    )
+
+    with gr.Tabs():
+        with gr.Tab("Image"):
+            with gr.Row():
+                with gr.Column(scale=1, min_width=300):
+                    with gr.Group():
+                        (
+                            image_input,
+                            use_url,
+                            url_input,
+                            image_checkpoint,
+                            image_confidence_threshold,
+                        ) = create_image_inputs()
+                        image_detect_button, image_clear_button = create_button_row(
+                            is_image=True
+                        )
+                with gr.Column(scale=2):
+                    image_output = gr.AnnotatedImage(
+                        label="Detection Results",
+                        show_label=True,
+                        color_map=None,
+                        elem_classes="output-component",
+                    )
+                    image_error_message = gr.Markdown(
+                        visible=False, elem_classes="error-text"
+                    )
+
+            gr.Examples(
+                examples=[
+                    [
+                        example["path"],
+                        example["use_url"],
+                        example["url"],
+                        DEFAULT_CHECKPOINT,
+                        DEFAULT_CONFIDENCE_THRESHOLD,
+                    ]
+                    for example in IMAGE_EXAMPLES
+                ],
+                inputs=[
+                    image_input,
+                    use_url,
+                    url_input,
+                    image_checkpoint,
+                    image_confidence_threshold,
+                ],
+                outputs=[image_output, image_error_message],
+                fn=detect_objects,
+                cache_examples=False,
+                label="Select an image example to populate inputs",
+            )
+
+        with gr.Tab("Video"):
+            gr.Markdown(
+                f"The input video will be truncated to {MAX_NUM_FRAMES} frames."
+            )
+            with gr.Row():
+                with gr.Column(scale=1, min_width=300):
+                    with gr.Group():
+                        video_input, video_checkpoint, video_confidence_threshold = (
+                            create_video_inputs()
+                        )
+                        video_detect_button, video_clear_button = create_button_row(
+                            is_image=False
+                        )
+                with gr.Column(scale=2):
+                    video_output = gr.Video(
+                        label="Detection Results",
+                        format="mp4",  # Explicit MP4 format
+                        elem_classes="output-component",
+                    )
+                    video_error_message = gr.Markdown(
+                        visible=False, elem_classes="error-text"
+                    )
+
+            gr.Examples(
+                examples=[
+                    [example["path"], DEFAULT_CHECKPOINT, DEFAULT_CONFIDENCE_THRESHOLD]
+                    for example in VIDEO_EXAMPLES
+                ],
+                inputs=[video_input, video_checkpoint, video_confidence_threshold],
+                outputs=[video_output, video_error_message],
+                fn=process_video,
+                cache_examples=False,
+                label="Select a video example to populate inputs",
+            )

     # Dynamic visibility for URL input
     use_url.change(
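One caveat on `process_video` above: the `"H264"` FourCC is not guaranteed to be available in pip-installed `opencv-python` wheels (the encoder is often omitted for licensing reasons), in which case the `writer.isOpened()` guard trips and the function returns an error. A hedged fallback sketch, not part of this commit; the `"mp4v"` codec choice and the `open_mp4_writer` helper name are assumptions:

```python
from typing import Tuple

import cv2


def open_mp4_writer(path: str, fps: float, size: Tuple[int, int]) -> cv2.VideoWriter:
    # Prefer H.264 for in-browser playback; fall back to MPEG-4 Part 2 ("mp4v")
    # when this OpenCV build ships without an H.264 encoder.
    for codec in ("H264", "mp4v"):
        writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*codec), fps, size)
        if writer.isOpened():
            return writer
        writer.release()
    raise RuntimeError("No usable MP4 codec found in this OpenCV build")
```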
@@ -156,34 +422,65 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         outputs=url_input,
     )

-    # Clear button functionality
-    clear_button.click(
+    # Image clear button
+    image_clear_button.click(
         fn=lambda: (
-            None,  # image_input
-            False,  # use_url
-            "",  # url_input
-            checkpoints[0],  # checkpoint
-            0.3,  # confidence_threshold
-            None,  # output_annotated
-            gr.Markdown(visible=False),  # error_message
+            None,
+            False,
+            "",
+            DEFAULT_CHECKPOINT,
+            DEFAULT_CONFIDENCE_THRESHOLD,
+            None,
+            gr.Markdown(visible=False),
         ),
         outputs=[
             image_input,
             use_url,
             url_input,
-            checkpoint,
-            confidence_threshold,
-            output_annotated,
-            error_message,
+            image_checkpoint,
+            image_confidence_threshold,
+            image_output,
+            image_error_message,
+        ],
+    )
+
+    # Video clear button
+    video_clear_button.click(
+        fn=lambda: (
+            None,
+            DEFAULT_CHECKPOINT,
+            DEFAULT_CONFIDENCE_THRESHOLD,
+            None,
+            gr.Markdown(visible=False),
+        ),
+        outputs=[
+            video_input,
+            video_checkpoint,
+            video_confidence_threshold,
+            video_output,
+            video_error_message,
         ],
     )

-    # Detect button event
-    detect_button.click(
+    # Image detect button
+    image_detect_button.click(
         fn=detect_objects,
-        inputs=[image_input, checkpoint, confidence_threshold, use_url, url_input],
-        outputs=[output_annotated, error_message],
+        inputs=[
+            image_input,
+            image_checkpoint,
+            image_confidence_threshold,
+            use_url,
+            url_input,
+        ],
+        outputs=[image_output, image_error_message],
+    )
+
+    # Video detect button
+    video_detect_button.click(
+        fn=process_video,
+        inputs=[video_input, video_checkpoint, video_confidence_threshold],
+        outputs=[video_output, video_error_message],
     )

 if __name__ == "__main__":
-    demo.launch()
+    demo.queue(max_size=20).launch()
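The final line is also worth a note: switching from `demo.launch()` to `demo.queue(max_size=20).launch()` routes requests through Gradio's queue, which keeps multi-minute video jobs from hitting request timeouts while capping the backlog at 20 waiting requests. A minimal sketch of the same pattern in isolation (the `slow_job` function is illustrative, not from this commit):

```python
import time

import gradio as gr


def slow_job(text: str) -> str:
    time.sleep(5)  # stand-in for long-running inference
    return text.upper()


with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    inp.submit(slow_job, inputs=inp, outputs=out)

# Queue requests (at most 20 waiting) before exposing the app
demo.queue(max_size=20).launch()
```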
requirements.txt CHANGED
@@ -1,4 +1,7 @@
 gradio
 transformers @ git+https://github.com/huggingface/transformers
 torch
-torchvision
+torchvision
+opencv-python
+tqdm
+pillow
video.mp4 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:747f9c2f9d19e4955603e1a13b69663187882d4c6a8fbcad18ddbd04ee792d4d
+size 1972564