wuhp commited on
Commit
edd3af7
·
verified ·
1 Parent(s): 9503f8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -45
app.py CHANGED
@@ -5,18 +5,19 @@ Gradio app to compare object‑detection models:
5
  • Roboflow RF‑DETR (Base, Large)
6
  • Custom fine‑tuned checkpoints for either framework (upload .pt/.pth files)
7
 
8
- Python ≥3.9
9
- Install:
10
- pip install -r requirements.txt
11
- Optionally, add GPU‑specific PyTorch wheels or `rfdetr[onnxexport]` for ONNX export.
12
  """
13
 
14
  from __future__ import annotations
15
 
16
  import time
17
  from pathlib import Path
18
- from typing import List, Tuple, Optional
19
 
 
20
  import numpy as np
21
  from PIL import Image
22
  import gradio as gr
@@ -30,7 +31,7 @@ from rfdetr.util.coco_classes import COCO_CLASSES
30
  ###############################################################################
31
 
32
  YOLO_MODEL_MAP = {
33
- # Names follow Ultralytics hub convention; they will be auto‑downloaded
34
  "YOLOv12‑n": "yolov12n.pt",
35
  "YOLOv12‑s": "yolov12s.pt",
36
  "YOLOv12‑m": "yolov12m.pt",
@@ -53,21 +54,16 @@ ALL_MODELS = list(YOLO_MODEL_MAP.keys()) + list(RFDETR_MODEL_MAP.keys()) + [
53
  "Custom RF‑DETR (.pth)",
54
  ]
55
 
56
- _loaded = {} # cache of already‑instantiated models
57
 
58
  def load_model(choice: str, custom_file: Optional[Path] = None):
59
- """Return (and cache) a model instance for *choice*.
60
- custom_file is a Path object (uploaded file) used when choice is custom.
61
- Raises RuntimeError with helpful message if loading fails.
62
- """
63
- global _loaded
64
  if choice in _loaded:
65
  return _loaded[choice]
66
 
67
  try:
68
  if choice in YOLO_MODEL_MAP:
69
- weight_id = YOLO_MODEL_MAP[choice]
70
- mdl = YOLO(weight_id) # Ultralytics downloads if not local
71
  elif choice in RFDETR_MODEL_MAP:
72
  mdl = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
73
  elif choice.startswith("Custom YOLO"):
@@ -80,43 +76,47 @@ def load_model(choice: str, custom_file: Optional[Path] = None):
80
  mdl = RFDETRBase(pretrain_weights=str(custom_file))
81
  else:
82
  raise ValueError(f"Unsupported model choice: {choice}")
83
- except FileNotFoundError as e:
84
- raise RuntimeError(
85
- f"Weights for '{choice}' not found locally and could not be downloaded. "
86
- "Place the .pt file in the working directory, supply a custom checkpoint, "
87
- "or ensure the model is released on the Ultralytics hub.\n" + str(e)
88
- ) from e
89
 
90
  _loaded[choice] = mdl
91
  return mdl
92
 
93
  ###############################################################################
94
- # Inference helpers
95
  ###############################################################################
96
 
97
- box_annotator = sv.BoxAnnotator()
98
  label_annotator = sv.LabelAnnotator()
99
 
 
 
 
 
100
  def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[Image.Image, float]:
101
  start = time.perf_counter()
102
 
103
  if isinstance(model, (RFDETRBase, RFDETRLarge)):
104
  detections = model.predict(image, threshold=threshold)
105
  label_source = COCO_CLASSES
106
- else: # Ultralytics YOLO
107
  result = model.predict(image, verbose=False)[0]
108
  detections = sv.Detections.from_ultralytics(result)
109
  label_source = model.names
110
 
111
  runtime = time.perf_counter() - start
112
 
113
- labels = [f"{label_source[cid]} {conf:.2f}" for cid, conf in zip(detections.class_id, detections.confidence)]
114
- annotated = box_annotator.annotate(image.copy(), detections)
115
- annotated = label_annotator.annotate(annotated, detections, labels)
116
- return annotated, runtime
 
 
 
 
117
 
118
  ###############################################################################
119
- # Gradio UI logic
120
  ###############################################################################
121
 
122
  def compare_models(models: List[str], img: Image.Image, threshold: float, custom_file: Optional[Path]):
@@ -125,45 +125,44 @@ def compare_models(models: List[str], img: Image.Image, threshold: float, custom
125
  if img.mode != "RGB":
126
  img = img.convert("RGB")
127
 
128
- results, legends = [], []
 
 
129
  for m in models:
130
  try:
131
  model_obj = load_model(m, custom_file)
132
  annotated, t = run_single_inference(model_obj, img, threshold)
133
  results.append(annotated)
134
- legends.append(f"{m} {t*1000:.1f} ms")
135
  except Exception as e:
136
- # Append a blank image with the error message overlayed
137
- error_img = Image.new("RGB", img.size, color=(30, 30, 30))
138
- legends.append(f"{m} – ERROR: {e}")
139
- results.append(error_img)
 
140
  return results, legends
141
 
142
  ###############################################################################
143
- # Build & launch demo
144
  ###############################################################################
145
 
146
  def build_demo():
147
  with gr.Blocks(title="CV Model Comparison") as demo:
148
- gr.Markdown("""# 🔍 Compare Object‑Detection Models\nUpload an image, select detectors, and optionally upload a custom checkpoint.\nThe app annotates predictions and reports per‑model latency.""")
149
 
150
  with gr.Row():
151
  model_select = gr.CheckboxGroup(choices=ALL_MODELS, value=["YOLOv12‑n"], label="Select models")
152
- threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence threshold")
153
 
154
- custom_checkpoint = gr.File(label="Upload custom YOLO / RF‑DETR checkpoint", file_types=[".pt", ".pth"], interactive=True)
155
- image_in = gr.Image(type="pil", label="Upload image", sources=["upload", "webcam"], show_label=True)
156
 
157
  with gr.Row():
158
  gallery = gr.Gallery(label="Annotated results", columns=2, height="auto")
159
- legends_out = gr.JSON(label="Runtime (ms) or error messages")
160
 
161
  run_btn = gr.Button("Run Inference", variant="primary")
162
- run_btn.click(
163
- fn=compare_models,
164
- inputs=[model_select, image_in, threshold_slider, custom_checkpoint],
165
- outputs=[gallery, legends_out],
166
- )
167
 
168
  return demo
169
 
 
5
  • Roboflow RF‑DETR (Base, Large)
6
  • Custom fine‑tuned checkpoints for either framework (upload .pt/.pth files)
7
 
8
+ Changes in this revision (2025‑04‑19):
9
+ • Thinner, semi‑transparent bounding boxes for better visibility in crowded scenes.
10
+ Legend now shows a clean dict of runtimes (or concise errors) instead of auto‑indexed JSON.
11
+ File uploader is fully integrated for custom checkpoints.
12
  """
13
 
14
  from __future__ import annotations
15
 
16
  import time
17
  from pathlib import Path
18
+ from typing import List, Tuple, Dict, Optional
19
 
20
+ import cv2
21
  import numpy as np
22
  from PIL import Image
23
  import gradio as gr
 
31
  ###############################################################################
32
 
33
  YOLO_MODEL_MAP = {
34
+ # Ultralytics hub IDs downloaded on first use
35
  "YOLOv12‑n": "yolov12n.pt",
36
  "YOLOv12‑s": "yolov12s.pt",
37
  "YOLOv12‑m": "yolov12m.pt",
 
54
  "Custom RF‑DETR (.pth)",
55
  ]
56
 
57
+ _loaded: Dict[str, object] = {}
58
 
59
  def load_model(choice: str, custom_file: Optional[Path] = None):
60
+ """Lazy‑load and cache a detector. Returns a model instance or raises RuntimeError."""
 
 
 
 
61
  if choice in _loaded:
62
  return _loaded[choice]
63
 
64
  try:
65
  if choice in YOLO_MODEL_MAP:
66
+ mdl = YOLO(YOLO_MODEL_MAP[choice]) # hub download if needed
 
67
  elif choice in RFDETR_MODEL_MAP:
68
  mdl = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
69
  elif choice.startswith("Custom YOLO"):
 
76
  mdl = RFDETRBase(pretrain_weights=str(custom_file))
77
  else:
78
  raise ValueError(f"Unsupported model choice: {choice}")
79
+ except Exception as e:
80
+ raise RuntimeError(str(e)) from e
 
 
 
 
81
 
82
  _loaded[choice] = mdl
83
  return mdl
84
 
85
  ###############################################################################
86
+ # Inference helpers — semi‑transparent, thin boxes
87
  ###############################################################################
88
 
89
+ box_annotator = sv.BoxAnnotator(thickness=2) # thinner lines
90
  label_annotator = sv.LabelAnnotator()
91
 
92
+ def blend_overlay(base_np: np.ndarray, overlay_np: np.ndarray, alpha: float = 0.6) -> np.ndarray:
93
+ """Blend two BGR images with given alpha for overlay."""
94
+ return cv2.addWeighted(overlay_np, alpha, base_np, 1 - alpha, 0)
95
+
96
  def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[Image.Image, float]:
97
  start = time.perf_counter()
98
 
99
  if isinstance(model, (RFDETRBase, RFDETRLarge)):
100
  detections = model.predict(image, threshold=threshold)
101
  label_source = COCO_CLASSES
102
+ else:
103
  result = model.predict(image, verbose=False)[0]
104
  detections = sv.Detections.from_ultralytics(result)
105
  label_source = model.names
106
 
107
  runtime = time.perf_counter() - start
108
 
109
+ img_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
110
+ overlay = img_np.copy()
111
+ overlay = box_annotator.annotate(overlay, detections)
112
+ overlay = label_annotator.annotate(overlay, detections, [f"{label_source[c]} {p:.2f}" for c, p in zip(detections.class_id, detections.confidence)])
113
+ blended = blend_overlay(img_np, overlay, alpha=0.6) # semi‑transparent boxes
114
+ annotated_pil = Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB))
115
+
116
+ return annotated_pil, runtime
117
 
118
  ###############################################################################
119
+ # Gradio callback
120
  ###############################################################################
121
 
122
  def compare_models(models: List[str], img: Image.Image, threshold: float, custom_file: Optional[Path]):
 
125
  if img.mode != "RGB":
126
  img = img.convert("RGB")
127
 
128
+ results: List[Image.Image] = []
129
+ legends: Dict[str, str] = {}
130
+
131
  for m in models:
132
  try:
133
  model_obj = load_model(m, custom_file)
134
  annotated, t = run_single_inference(model_obj, img, threshold)
135
  results.append(annotated)
136
+ legends[m] = f"{t*1000:.1f} ms"
137
  except Exception as e:
138
+ # show blank slate if model unavailable
139
+ results.append(Image.new("RGB", img.size, (40, 40, 40)))
140
+ err = str(e).split("\n")[0][:120] # shorten
141
+ legends[m] = f"ERROR: {err}"
142
+
143
  return results, legends
144
 
145
  ###############################################################################
146
+ # Build & launch Gradio UI
147
  ###############################################################################
148
 
149
  def build_demo():
150
  with gr.Blocks(title="CV Model Comparison") as demo:
151
+ gr.Markdown("""# 🔍 Compare Object‑Detection Models\nUpload an image, choose detectors, and optionally add a custom checkpoint.\nBounding boxes are thin and 60 % opaque for clarity.""")
152
 
153
  with gr.Row():
154
  model_select = gr.CheckboxGroup(choices=ALL_MODELS, value=["YOLOv12‑n"], label="Select models")
155
+ threshold_slider = gr.Slider(0.0, 1.0, 0.5, step=0.05, label="Confidence threshold")
156
 
157
+ custom_checkpoint = gr.File(label="Upload custom checkpoint (.pt/.pth)", file_types=[".pt", ".pth"], interactive=True)
158
+ image_in = gr.Image(type="pil", label="Image", sources=["upload", "webcam"])
159
 
160
  with gr.Row():
161
  gallery = gr.Gallery(label="Annotated results", columns=2, height="auto")
162
+ legends_out = gr.JSON(label="Latency / status by model")
163
 
164
  run_btn = gr.Button("Run Inference", variant="primary")
165
+ run_btn.click(compare_models, [model_select, image_in, threshold_slider, custom_checkpoint], [gallery, legends_out])
 
 
 
 
166
 
167
  return demo
168