wuhp commited on
Commit
90903c1
·
verified ·
1 Parent(s): 766e4d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -36
app.py CHANGED
@@ -5,10 +5,10 @@ Gradio app to compare object‑detection models:
5
  • Roboflow RF‑DETR (Base, Large)
6
  • Custom fine‑tuned checkpoints for either framework (upload .pt/.pth files)
7
 
8
- Changes in this revision (2025‑04‑19):
9
- Thinner, semi‑transparent bounding boxes for better visibility in crowded scenes.
10
- Legend now shows a clean dict of runtimes (or concise errors) instead of auto‑indexed JSON.
11
- File uploader is fully integrated for custom checkpoints.
12
  """
13
 
14
  from __future__ import annotations
@@ -31,7 +31,6 @@ from rfdetr.util.coco_classes import COCO_CLASSES
31
  ###############################################################################
32
 
33
  YOLO_MODEL_MAP = {
34
- # Ultralytics hub IDs — downloaded on first use
35
  "YOLOv12‑n": "yolov12n.pt",
36
  "YOLOv12‑s": "yolov12s.pt",
37
  "YOLOv12‑m": "yolov12m.pt",
@@ -57,13 +56,11 @@ ALL_MODELS = list(YOLO_MODEL_MAP.keys()) + list(RFDETR_MODEL_MAP.keys()) + [
57
  _loaded: Dict[str, object] = {}
58
 
59
  def load_model(choice: str, custom_file: Optional[Path] = None):
60
- """Lazy‑load and cache a detector. Returns a model instance or raises RuntimeError."""
61
  if choice in _loaded:
62
  return _loaded[choice]
63
-
64
  try:
65
  if choice in YOLO_MODEL_MAP:
66
- mdl = YOLO(YOLO_MODEL_MAP[choice]) # hub download if needed
67
  elif choice in RFDETR_MODEL_MAP:
68
  mdl = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
69
  elif choice.startswith("Custom YOLO"):
@@ -78,24 +75,21 @@ def load_model(choice: str, custom_file: Optional[Path] = None):
78
  raise ValueError(f"Unsupported model choice: {choice}")
79
  except Exception as e:
80
  raise RuntimeError(str(e)) from e
81
-
82
  _loaded[choice] = mdl
83
  return mdl
84
 
85
  ###############################################################################
86
- # Inference helpers — semi‑transparent, thin boxes
87
  ###############################################################################
88
 
89
- box_annotator = sv.BoxAnnotator(thickness=2) # thinner lines
90
  label_annotator = sv.LabelAnnotator()
91
 
92
  def blend_overlay(base_np: np.ndarray, overlay_np: np.ndarray, alpha: float = 0.6) -> np.ndarray:
93
- """Blend two BGR images with given alpha for overlay."""
94
  return cv2.addWeighted(overlay_np, alpha, base_np, 1 - alpha, 0)
95
 
96
  def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[Image.Image, float]:
97
  start = time.perf_counter()
98
-
99
  if isinstance(model, (RFDETRBase, RFDETRLarge)):
100
  detections = model.predict(image, threshold=threshold)
101
  label_source = COCO_CLASSES
@@ -103,17 +97,17 @@ def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[I
103
  result = model.predict(image, verbose=False)[0]
104
  detections = sv.Detections.from_ultralytics(result)
105
  label_source = model.names
106
-
107
  runtime = time.perf_counter() - start
108
-
109
  img_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
110
  overlay = img_np.copy()
111
  overlay = box_annotator.annotate(overlay, detections)
112
- overlay = label_annotator.annotate(overlay, detections, [f"{label_source[c]} {p:.2f}" for c, p in zip(detections.class_id, detections.confidence)])
113
- blended = blend_overlay(img_np, overlay, alpha=0.6) # semi‑transparent boxes
114
- annotated_pil = Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB))
115
-
116
- return annotated_pil, runtime
 
 
117
 
118
  ###############################################################################
119
  # Gradio callback
@@ -135,11 +129,9 @@ def compare_models(models: List[str], img: Image.Image, threshold: float, custom
135
  results.append(annotated)
136
  legends[m] = f"{t*1000:.1f} ms"
137
  except Exception as e:
138
- # show blank slate if model unavailable
139
  results.append(Image.new("RGB", img.size, (40, 40, 40)))
140
- err_msg = str(e)
141
- # Normalize common weight‑missing errors for clarity
142
- if "No such file or directory" in err_msg:
143
  legends[m] = "Unavailable (weights not found)"
144
  else:
145
  legends[m] = f"ERROR: {err_msg.splitlines()[0][:120]}"
@@ -147,27 +139,25 @@ def compare_models(models: List[str], img: Image.Image, threshold: float, custom
147
  return results, legends
148
 
149
  ###############################################################################
150
- # Build & launch Gradio UI
151
  ###############################################################################
152
 
153
  def build_demo():
154
  with gr.Blocks(title="CV Model Comparison") as demo:
155
- gr.Markdown("""# 🔍 Compare Object‑Detection Models\nUpload an image, choose detectors, and optionally add a custom checkpoint.\nBounding boxes are thin and 60 % opaque for clarity.""")
156
-
 
157
  with gr.Row():
158
- model_select = gr.CheckboxGroup(choices=ALL_MODELS, value=["YOLOv12‑n"], label="Select models")
159
- threshold_slider = gr.Slider(0.0, 1.0, 0.5, step=0.05, label="Confidence threshold")
160
-
161
- custom_checkpoint = gr.File(label="Upload custom checkpoint (.pt/.pth)", file_types=[".pt", ".pth"], interactive=True)
162
  image_in = gr.Image(type="pil", label="Image", sources=["upload", "webcam"])
163
-
164
  with gr.Row():
165
  gallery = gr.Gallery(label="Annotated results", columns=2, height="auto")
166
  legends_out = gr.JSON(label="Latency / status by model")
167
-
168
- run_btn = gr.Button("Run Inference", variant="primary")
169
- run_btn.click(compare_models, [model_select, image_in, threshold_slider, custom_checkpoint], [gallery, legends_out])
170
-
171
  return demo
172
 
173
  if __name__ == "__main__":
 
5
  • Roboflow RF‑DETR (Base, Large)
6
  • Custom fine‑tuned checkpoints for either framework (upload .pt/.pth files)
7
 
8
+ Revision 2025‑04‑19‑b:
9
+ Fixed indentation error in error‑handling branch.
10
+ Unavailable weights report as concise messages.
11
+ Bounding boxes: 2 px, 60 % opacity.
12
  """
13
 
14
  from __future__ import annotations
 
31
  ###############################################################################
32
 
33
  YOLO_MODEL_MAP = {
 
34
  "YOLOv12‑n": "yolov12n.pt",
35
  "YOLOv12‑s": "yolov12s.pt",
36
  "YOLOv12‑m": "yolov12m.pt",
 
56
  _loaded: Dict[str, object] = {}
57
 
58
  def load_model(choice: str, custom_file: Optional[Path] = None):
 
59
  if choice in _loaded:
60
  return _loaded[choice]
 
61
  try:
62
  if choice in YOLO_MODEL_MAP:
63
+ mdl = YOLO(YOLO_MODEL_MAP[choice])
64
  elif choice in RFDETR_MODEL_MAP:
65
  mdl = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
66
  elif choice.startswith("Custom YOLO"):
 
75
  raise ValueError(f"Unsupported model choice: {choice}")
76
  except Exception as e:
77
  raise RuntimeError(str(e)) from e
 
78
  _loaded[choice] = mdl
79
  return mdl
80
 
81
  ###############################################################################
82
+ # Inference helpers
83
  ###############################################################################
84
 
85
+ box_annotator = sv.BoxAnnotator(thickness=2)
86
  label_annotator = sv.LabelAnnotator()
87
 
88
  def blend_overlay(base_np: np.ndarray, overlay_np: np.ndarray, alpha: float = 0.6) -> np.ndarray:
 
89
  return cv2.addWeighted(overlay_np, alpha, base_np, 1 - alpha, 0)
90
 
91
  def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[Image.Image, float]:
92
  start = time.perf_counter()
 
93
  if isinstance(model, (RFDETRBase, RFDETRLarge)):
94
  detections = model.predict(image, threshold=threshold)
95
  label_source = COCO_CLASSES
 
97
  result = model.predict(image, verbose=False)[0]
98
  detections = sv.Detections.from_ultralytics(result)
99
  label_source = model.names
 
100
  runtime = time.perf_counter() - start
 
101
  img_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
102
  overlay = img_np.copy()
103
  overlay = box_annotator.annotate(overlay, detections)
104
+ overlay = label_annotator.annotate(
105
+ overlay,
106
+ detections,
107
+ [f"{label_source[c]} {p:.2f}" for c, p in zip(detections.class_id, detections.confidence)],
108
+ )
109
+ blended = blend_overlay(img_np, overlay, alpha=0.6)
110
+ return Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)), runtime
111
 
112
  ###############################################################################
113
  # Gradio callback
 
129
  results.append(annotated)
130
  legends[m] = f"{t*1000:.1f} ms"
131
  except Exception as e:
 
132
  results.append(Image.new("RGB", img.size, (40, 40, 40)))
133
+ err_msg = str(e)
134
+ if "No such file" in err_msg or "not found" in err_msg:
 
135
  legends[m] = "Unavailable (weights not found)"
136
  else:
137
  legends[m] = f"ERROR: {err_msg.splitlines()[0][:120]}"
 
139
  return results, legends
140
 
141
  ###############################################################################
142
+ # UI
143
  ###############################################################################
144
 
145
  def build_demo():
146
  with gr.Blocks(title="CV Model Comparison") as demo:
147
+ gr.Markdown(
148
+ """# 🔍 Compare Object‑Detection Models\nUpload an image, choose detectors, and optionally add a custom checkpoint.\nBounding boxes are thin and 60 % opaque for clarity."""
149
+ )
150
  with gr.Row():
151
+ model_select = gr.CheckboxGroup(ALL_MODELS, value=["YOLOv12‑n"], label="Select models")
152
+ threshold_slider = gr.Slider(0.0, 1.0, 0.5, 0.05, label="Confidence threshold")
153
+ custom_file = gr.File(label="Upload custom checkpoint (.pt/.pth)", file_types=[".pt", ".pth"], interactive=True)
 
154
  image_in = gr.Image(type="pil", label="Image", sources=["upload", "webcam"])
 
155
  with gr.Row():
156
  gallery = gr.Gallery(label="Annotated results", columns=2, height="auto")
157
  legends_out = gr.JSON(label="Latency / status by model")
158
+ gr.Button("Run Inference", variant="primary").click(
159
+ compare_models, [model_select, image_in, threshold_slider, custom_file], [gallery, legends_out]
160
+ )
 
161
  return demo
162
 
163
  if __name__ == "__main__":