wuhp committed
Commit edc92bb · verified · 1 Parent(s): a0fdf82

Update app.py

Files changed (1):
  1. app.py +36 -33
app.py CHANGED
@@ -5,10 +5,10 @@ Gradio app to compare object‑detection models:
 • Roboflow RF‑DETR (Base, Large)
 • Custom fine‑tuned checkpoints (.pt/.pth upload)
 
-Revision 2025‑04‑19‑d:
-Pre‑loads all selected models before running detections, with a visible progress bar.
-Progress shows two phases: *Loading weights* and *Running inference*.
-Keeps thin, semi‑transparent boxes and concise error labels.
+Revision 2025‑04‑19‑e:
+Gallery items now carry captions so you can see which model produced which image (and latency).
+Captions display as "Model (xx ms)" or error status.
+No other behaviour changed: pre‑loading, progress bar, thin semi‑transparent boxes, concise error labels.
 """
 
 from __future__ import annotations
@@ -31,7 +31,7 @@ from rfdetr.util.coco_classes import COCO_CLASSES
 ###############################################################################
 
 YOLO_MODEL_MAP: Dict[str, str] = {
-    # NOTE: Ultralytics filenames do NOT include the "v" character
+    # Ultralytics filenames omit the "v"
     "YOLOv12‑n": "yolo12n.pt",
     "YOLOv12‑s": "yolo12s.pt",
     "YOLOv12‑m": "yolo12m.pt",
@@ -44,7 +44,6 @@ YOLO_MODEL_MAP: Dict[str, str] = {
     "YOLOv11‑x": "yolo11x.pt",
 }
 
-
 RFDETR_MODEL_MAP = {
     "RF‑DETR‑Base (29M)": "base",
     "RF‑DETR‑Large (128M)": "large",
@@ -58,12 +57,11 @@ ALL_MODELS = list(YOLO_MODEL_MAP.keys()) + list(RFDETR_MODEL_MAP.keys()) + [
 _loaded: Dict[str, object] = {}
 
 def load_model(choice: str, custom_file: Optional[Path] = None):
-    """Fetch and cache a detector instance for *choice*."""
     if choice in _loaded:
         return _loaded[choice]
 
     if choice in YOLO_MODEL_MAP:
-        model = YOLO(YOLO_MODEL_MAP[choice])  # Ultralytics auto‑downloads if missing
+        model = YOLO(YOLO_MODEL_MAP[choice])
     elif choice in RFDETR_MODEL_MAP:
         model = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
     elif choice.startswith("Custom YOLO"):
@@ -84,8 +82,8 @@ def load_model(choice: str, custom_file: Optional[Path] = None):
 # Inference helpers
 ###############################################################################
 
-BOX_THICKNESS = 2  # thinner boxes
-BOX_ALPHA = 0.6  # 60 % opacity
+BOX_THICKNESS = 2
+BOX_ALPHA = 0.6
 
 box_annotator = sv.BoxAnnotator(thickness=BOX_THICKNESS)
 label_annotator = sv.LabelAnnotator()
@@ -118,7 +116,7 @@ def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[I
     return Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)), runtime
 
 ###############################################################################
-# Gradio generator callback with progress
+# Callback with progress & captions
 ###############################################################################
 
 def compare_models(
@@ -132,50 +130,54 @@
     if img.mode != "RGB":
         img = img.convert("RGB")
 
-    total_steps = len(models) * 2  # phase 1: load, phase 2: inference
+    total_steps = len(models) * 2
     progress = gr.Progress()
 
-    # ----- Phase 1: preload weights -----
     detectors: Dict[str, object] = {}
     for i, name in enumerate(models, 1):
         try:
             detectors[name] = load_model(name, custom_file)
         except Exception as exc:
-            detectors[name] = exc  # store exception for later reporting
+            detectors[name] = exc
         progress(i, total=total_steps, desc=f"Loading {name}")
 
-    # ----- Phase 2: run inference -----
-    results: List[Image.Image] = []
+    results: List[Tuple[Image.Image, str]] = []
     legends: Dict[str, str] = {}
 
     for j, name in enumerate(models, 1):
-        detector_or_err = detectors[name]
-        step_index = len(models) + j
-        if isinstance(detector_or_err, Exception):
-            # model failed to load
-            results.append(Image.new("RGB", img.size, (40, 40, 40)))
-            emsg = str(detector_or_err)
-            legends[name] = "Unavailable (weights not found)" if "No such file" in emsg or "not found" in emsg else f"ERROR: {emsg.splitlines()[0][:120]}"
-            progress(step_index, total=total_steps, desc=f"Skipped {name}")
+        item = detectors[name]
+        step = len(models) + j
+        if isinstance(item, Exception):
+            placeholder = Image.new("RGB", img.size, (40, 40, 40))
+            emsg = str(item)
+            caption = f"{name} – Unavailable" if "No such file" in emsg or "not found" in emsg else f"{name} – ERROR"
+            results.append((placeholder, caption))
+            legends[name] = caption
+            progress(step, total=total_steps, desc=f"Skipped {name}")
             continue
         try:
-            annotated, latency = run_single_inference(detector_or_err, img, threshold)
-            results.append(annotated)
+            annotated, latency = run_single_inference(item, img, threshold)
+            caption = f"{name} ({latency*1000:.1f} ms)"
+            results.append((annotated, caption))
             legends[name] = f"{latency*1000:.1f} ms"
         except Exception as exc:
-            results.append(Image.new("RGB", img.size, (40, 40, 40)))
+            placeholder = Image.new("RGB", img.size, (40, 40, 40))
+            caption = f"{name} – ERROR"
+            results.append((placeholder, caption))
             legends[name] = f"ERROR: {str(exc).splitlines()[0][:120]}"
-        progress(step_index, total=total_steps, desc=f"Inference {name}")
+        progress(step, total=total_steps, desc=f"Inference {name}")
 
-    yield results, legends  # final output
+    yield results, legends
 
 ###############################################################################
-# Gradio UI
+# UI
 ###############################################################################
 
 def build_demo():
     with gr.Blocks(title="CV Model Comparison") as demo:
-        gr.Markdown("""# 🔍 Compare Object‑Detection Models\nUpload an image, select detectors, then click **Run Inference**.\nThin, semi‑transparent boxes highlight detections.""")
+        gr.Markdown(
+            """# 🔍 Compare Object‑Detection Models\nUpload an image, select detectors, and click **Run Inference**.\nCaptions beneath each result show which model (and latency) generated it."""
+        )
 
         with gr.Row():
             sel_models = gr.CheckboxGroup(ALL_MODELS, value=["YOLOv12‑n"], label="Models")
@@ -188,8 +190,9 @@ def build_demo():
         gallery = gr.Gallery(label="Results", columns=2, height="auto")
         legend_out = gr.JSON(label="Latency / status by model")
 
-        run_btn = gr.Button("Run Inference", variant="primary")
-        run_btn.click(compare_models, [sel_models, img_in, conf_slider, ckpt_file], [gallery, legend_out])
+        gr.Button("Run Inference", variant="primary").click(
+            compare_models, [sel_models, img_in, conf_slider, ckpt_file], [gallery, legend_out]
+        )
 
     return demo
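
A note on the mechanism this revision leans on: gr.Gallery accepts (image, caption) tuples in place of bare images, which is what lets compare_models yield per-model captions without any new UI component. A minimal, self-contained sketch of that pattern follows; the grey placeholder frames and latency numbers are illustrative, not taken from the app:

import gradio as gr
from PIL import Image

def fake_results():
    # Hypothetical stand-in for compare_models: dummy frames paired
    # with captions in the app's "Model (xx ms)" format.
    items = []
    for name, ms in [("YOLOv12‑n", 12.3), ("RF‑DETR‑Base (29M)", 48.9)]:
        frame = Image.new("RGB", (320, 240), (40, 40, 40))  # placeholder image
        items.append((frame, f"{name} ({ms:.1f} ms)"))      # (image, caption) tuple
    return items

with gr.Blocks() as demo:
    gallery = gr.Gallery(label="Results", columns=2)
    gr.Button("Run").click(fake_results, outputs=gallery)

if __name__ == "__main__":
    demo.launch()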
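
The progress arithmetic is easy to misread in diff form, so here it is in isolation: with n selected models, load steps occupy 1..n and inference steps n+1..2n, hence total_steps = len(models) * 2. The sketch below assumes the same progress(step, total=..., desc=...) call convention the app uses, and injects the tracker as a default parameter (Gradio's documented way to attach it to an event handler):

import time
import gradio as gr

def two_phase(models: list[str], progress: gr.Progress = gr.Progress()):
    total_steps = len(models) * 2
    for i, name in enumerate(models, 1):       # phase 1: weights
        time.sleep(0.1)                        # stand-in for load_model(name)
        progress(i, total=total_steps, desc=f"Loading {name}")
    for j, name in enumerate(models, 1):       # phase 2: detections
        time.sleep(0.1)                        # stand-in for run_single_inference(...)
        progress(len(models) + j, total=total_steps, desc=f"Inference {name}")
    return "done"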
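
Finally, the BOX_THICKNESS / BOX_ALPHA constants only make sense next to the blended frame that run_single_inference returns, and that function's body sits outside every hunk here. The usual OpenCV pattern behind a semi-transparent overlay is sketched below; treat it as an assumption about the unshown code, not a copy of it:

import cv2
import numpy as np
import supervision as sv
from PIL import Image

BOX_THICKNESS = 2
BOX_ALPHA = 0.6  # overlay opacity

box_annotator = sv.BoxAnnotator(thickness=BOX_THICKNESS)

def blend_boxes(image: Image.Image, detections: sv.Detections) -> Image.Image:
    # Draw boxes on a copy, then alpha-blend the copy over the original,
    # so boxes render semi-transparent rather than opaque.
    frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    overlay = box_annotator.annotate(scene=frame.copy(), detections=detections)
    blended = cv2.addWeighted(overlay, BOX_ALPHA, frame, 1 - BOX_ALPHA, 0)
    return Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB))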