wuhp commited on
Commit
abc9620
·
verified ·
1 Parent(s): 90903c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -50
app.py CHANGED
@@ -3,19 +3,18 @@ Gradio app to compare object‑detection models:
3
  • Ultralytics YOLOv12 (n, s, m, l, x)
4
  • Ultralytics YOLOv11 (n, s, m, l, x)
5
  • Roboflow RF‑DETR (Base, Large)
6
- • Custom fine‑tuned checkpoints for either framework (upload .pt/.pth files)
7
 
8
- Revision 2025‑04‑19‑b:
9
- Fixed indentation error in errorhandling branch.
10
- Unavailable weights report as concise messages.
11
- • Bounding boxes: 2 px, 60 % opacity.
12
  """
13
 
14
  from __future__ import annotations
15
 
16
  import time
17
  from pathlib import Path
18
- from typing import List, Tuple, Dict, Optional
19
 
20
  import cv2
21
  import numpy as np
@@ -30,7 +29,7 @@ from rfdetr.util.coco_classes import COCO_CLASSES
30
  # Model registry & lazy loader
31
  ###############################################################################
32
 
33
- YOLO_MODEL_MAP = {
34
  "YOLOv12‑n": "yolov12n.pt",
35
  "YOLOv12‑s": "yolov12s.pt",
36
  "YOLOv12‑m": "yolov12m.pt",
@@ -56,64 +55,81 @@ ALL_MODELS = list(YOLO_MODEL_MAP.keys()) + list(RFDETR_MODEL_MAP.keys()) + [
56
  _loaded: Dict[str, object] = {}
57
 
58
  def load_model(choice: str, custom_file: Optional[Path] = None):
 
59
  if choice in _loaded:
60
  return _loaded[choice]
 
61
  try:
62
  if choice in YOLO_MODEL_MAP:
63
- mdl = YOLO(YOLO_MODEL_MAP[choice])
64
  elif choice in RFDETR_MODEL_MAP:
65
- mdl = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
66
  elif choice.startswith("Custom YOLO"):
67
- if not custom_file:
68
  raise ValueError("Upload a YOLO .pt/.pth checkpoint first.")
69
- mdl = YOLO(str(custom_file))
70
  elif choice.startswith("Custom RF‑DETR"):
71
- if not custom_file:
72
  raise ValueError("Upload an RF‑DETR .pth checkpoint first.")
73
- mdl = RFDETRBase(pretrain_weights=str(custom_file))
74
  else:
75
  raise ValueError(f"Unsupported model choice: {choice}")
76
- except Exception as e:
77
- raise RuntimeError(str(e)) from e
78
- _loaded[choice] = mdl
79
- return mdl
 
80
 
81
  ###############################################################################
82
  # Inference helpers
83
  ###############################################################################
84
 
85
- box_annotator = sv.BoxAnnotator(thickness=2)
 
 
 
86
  label_annotator = sv.LabelAnnotator()
87
 
88
- def blend_overlay(base_np: np.ndarray, overlay_np: np.ndarray, alpha: float = 0.6) -> np.ndarray:
89
- return cv2.addWeighted(overlay_np, alpha, base_np, 1 - alpha, 0)
90
 
91
  def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[Image.Image, float]:
92
  start = time.perf_counter()
 
93
  if isinstance(model, (RFDETRBase, RFDETRLarge)):
94
  detections = model.predict(image, threshold=threshold)
95
- label_source = COCO_CLASSES
96
  else:
97
- result = model.predict(image, verbose=False)[0]
98
- detections = sv.Detections.from_ultralytics(result)
99
- label_source = model.names
 
100
  runtime = time.perf_counter() - start
101
- img_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
102
- overlay = img_np.copy()
 
 
103
  overlay = box_annotator.annotate(overlay, detections)
104
  overlay = label_annotator.annotate(
105
  overlay,
106
  detections,
107
- [f"{label_source[c]} {p:.2f}" for c, p in zip(detections.class_id, detections.confidence)],
108
  )
109
- blended = blend_overlay(img_np, overlay, alpha=0.6)
110
- return Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)), runtime
 
 
111
 
112
  ###############################################################################
113
  # Gradio callback
114
  ###############################################################################
115
 
116
- def compare_models(models: List[str], img: Image.Image, threshold: float, custom_file: Optional[Path]):
 
 
 
 
 
117
  if img is None:
118
  raise gr.Error("Please upload an image first.")
119
  if img.mode != "RGB":
@@ -122,42 +138,46 @@ def compare_models(models: List[str], img: Image.Image, threshold: float, custom
122
  results: List[Image.Image] = []
123
  legends: Dict[str, str] = {}
124
 
125
- for m in models:
126
  try:
127
- model_obj = load_model(m, custom_file)
128
- annotated, t = run_single_inference(model_obj, img, threshold)
129
  results.append(annotated)
130
- legends[m] = f"{t*1000:.1f} ms"
131
- except Exception as e:
132
  results.append(Image.new("RGB", img.size, (40, 40, 40)))
133
- err_msg = str(e)
134
- if "No such file" in err_msg or "not found" in err_msg:
135
- legends[m] = "Unavailable (weights not found)"
136
  else:
137
- legends[m] = f"ERROR: {err_msg.splitlines()[0][:120]}"
138
 
139
  return results, legends
140
 
141
  ###############################################################################
142
- # UI
143
  ###############################################################################
144
 
145
  def build_demo():
146
  with gr.Blocks(title="CV Model Comparison") as demo:
147
  gr.Markdown(
148
- """# 🔍 Compare Object‑Detection Models\nUpload an image, choose detectors, and optionally add a custom checkpoint.\nBounding boxes are thin and 60 % opaque for clarity."""
149
  )
 
150
  with gr.Row():
151
- model_select = gr.CheckboxGroup(ALL_MODELS, value=["YOLOv12‑n"], label="Select models")
152
- threshold_slider = gr.Slider(0.0, 1.0, 0.5, 0.05, label="Confidence threshold")
153
- custom_file = gr.File(label="Upload custom checkpoint (.pt/.pth)", file_types=[".pt", ".pth"], interactive=True)
154
- image_in = gr.Image(type="pil", label="Image", sources=["upload", "webcam"])
 
 
155
  with gr.Row():
156
- gallery = gr.Gallery(label="Annotated results", columns=2, height="auto")
157
- legends_out = gr.JSON(label="Latency / status by model")
158
- gr.Button("Run Inference", variant="primary").click(
159
- compare_models, [model_select, image_in, threshold_slider, custom_file], [gallery, legends_out]
160
- )
 
161
  return demo
162
 
163
  if __name__ == "__main__":
 
3
  • Ultralytics YOLOv12 (n, s, m, l, x)
4
  • Ultralytics YOLOv11 (n, s, m, l, x)
5
  • Roboflow RF‑DETR (Base, Large)
6
+ • Custom fine‑tuned checkpoints (.pt/.pth upload)
7
 
8
+ Revision 2025‑04‑19‑c:
9
+ Re‑indented entire file with 4space consistency to remove `IndentationError`.
10
+ Thin, semi‑transparent 60 % boxes; concise error labels.
 
11
  """
12
 
13
  from __future__ import annotations
14
 
15
  import time
16
  from pathlib import Path
17
+ from typing import Dict, List, Optional, Tuple
18
 
19
  import cv2
20
  import numpy as np
 
29
  # Model registry & lazy loader
30
  ###############################################################################
31
 
32
+ YOLO_MODEL_MAP: Dict[str, str] = {
33
  "YOLOv12‑n": "yolov12n.pt",
34
  "YOLOv12‑s": "yolov12s.pt",
35
  "YOLOv12‑m": "yolov12m.pt",
 
55
  _loaded: Dict[str, object] = {}
56
 
57
  def load_model(choice: str, custom_file: Optional[Path] = None):
58
+ """Return and cache a detector matching *choice*."""
59
  if choice in _loaded:
60
  return _loaded[choice]
61
+
62
  try:
63
  if choice in YOLO_MODEL_MAP:
64
+ model = YOLO(YOLO_MODEL_MAP[choice])
65
  elif choice in RFDETR_MODEL_MAP:
66
+ model = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
67
  elif choice.startswith("Custom YOLO"):
68
+ if custom_file is None:
69
  raise ValueError("Upload a YOLO .pt/.pth checkpoint first.")
70
+ model = YOLO(str(custom_file))
71
  elif choice.startswith("Custom RF‑DETR"):
72
+ if custom_file is None:
73
  raise ValueError("Upload an RF‑DETR .pth checkpoint first.")
74
+ model = RFDETRBase(pretrain_weights=str(custom_file))
75
  else:
76
  raise ValueError(f"Unsupported model choice: {choice}")
77
+ except Exception as exc:
78
+ raise RuntimeError(str(exc)) from exc
79
+
80
+ _loaded[choice] = model
81
+ return model
82
 
83
  ###############################################################################
84
  # Inference helpers
85
  ###############################################################################
86
 
87
+ BOX_THICKNESS = 2
88
+ BOX_ALPHA = 0.6
89
+
90
+ box_annotator = sv.BoxAnnotator(thickness=BOX_THICKNESS)
91
  label_annotator = sv.LabelAnnotator()
92
 
93
+ def _blend(base: np.ndarray, overlay: np.ndarray, alpha: float = BOX_ALPHA) -> np.ndarray:
94
+ return cv2.addWeighted(overlay, alpha, base, 1 - alpha, 0)
95
 
96
  def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[Image.Image, float]:
97
  start = time.perf_counter()
98
+
99
  if isinstance(model, (RFDETRBase, RFDETRLarge)):
100
  detections = model.predict(image, threshold=threshold)
101
+ label_src = COCO_CLASSES
102
  else:
103
+ ul_result = model.predict(image, verbose=False)[0]
104
+ detections = sv.Detections.from_ultralytics(ul_result)
105
+ label_src = model.names # type: ignore
106
+
107
  runtime = time.perf_counter() - start
108
+
109
+ base_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
110
+ overlay = base_bgr.copy()
111
+
112
  overlay = box_annotator.annotate(overlay, detections)
113
  overlay = label_annotator.annotate(
114
  overlay,
115
  detections,
116
+ [f"{label_src[cid]} {conf:.2f}" for cid, conf in zip(detections.class_id, detections.confidence)],
117
  )
118
+
119
+ blended = _blend(base_bgr, overlay)
120
+ out_pil = Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB))
121
+ return out_pil, runtime
122
 
123
  ###############################################################################
124
  # Gradio callback
125
  ###############################################################################
126
 
127
+ def compare_models(
128
+ models: List[str],
129
+ img: Image.Image,
130
+ threshold: float,
131
+ custom_file: Optional[Path],
132
+ ):
133
  if img is None:
134
  raise gr.Error("Please upload an image first.")
135
  if img.mode != "RGB":
 
138
  results: List[Image.Image] = []
139
  legends: Dict[str, str] = {}
140
 
141
+ for model_name in models:
142
  try:
143
+ detector = load_model(model_name, custom_file)
144
+ annotated, latency = run_single_inference(detector, img, threshold)
145
  results.append(annotated)
146
+ legends[model_name] = f"{latency*1000:.1f} ms"
147
+ except Exception as exc:
148
  results.append(Image.new("RGB", img.size, (40, 40, 40)))
149
+ emsg = str(exc)
150
+ if "No such file" in emsg or "not found" in emsg:
151
+ legends[model_name] = "Unavailable (weights not found)"
152
  else:
153
+ legends[model_name] = f"ERROR: {emsg.splitlines()[0][:120]}"
154
 
155
  return results, legends
156
 
157
  ###############################################################################
158
+ # Gradio UI
159
  ###############################################################################
160
 
161
  def build_demo():
162
  with gr.Blocks(title="CV Model Comparison") as demo:
163
  gr.Markdown(
164
+ """# 🔍 Compare Object‑Detection Models\nUpload an image, choose detectors, and optionally add a custom checkpoint.\nBounding boxes are thin (2 px) and 60 % transparent for clarity."""
165
  )
166
+
167
  with gr.Row():
168
+ sel_models = gr.CheckboxGroup(ALL_MODELS, value=["YOLOv12‑n"], label="Models")
169
+ conf_slider = gr.Slider(0.0, 1.0, 0.5, 0.05, label="Confidence")
170
+
171
+ ckpt_file = gr.File(label="Custom checkpoint (.pt/.pth)", file_types=[".pt", ".pth"], interactive=True)
172
+ img_in = gr.Image(type="pil", label="Image", sources=["upload", "webcam"])
173
+
174
  with gr.Row():
175
+ gallery = gr.Gallery(label="Results", columns=2, height="auto")
176
+ legend_out = gr.JSON(label="Latency / status by model")
177
+
178
+ run_btn = gr.Button("Run Inference", variant="primary")
179
+ run_btn.click(compare_models, [sel_models, img_in, conf_slider, ckpt_file], [gallery, legend_out])
180
+
181
  return demo
182
 
183
  if __name__ == "__main__":