wuhp committed on
Commit 3b9517f · verified · 1 Parent(s): abc9620

Update app.py

Files changed (1)
  1. app.py +59 -48
app.py CHANGED
@@ -5,9 +5,10 @@ Gradio app to compare object‑detection models:
 • Roboflow RF‑DETR (Base, Large)
 • Custom fine‑tuned checkpoints (.pt/.pth upload)

-Revision 2025‑04‑19‑c:
-Reindented entire file with 4‑space consistency to remove `IndentationError`.
-Thin, semi‑transparent 60 % boxes; concise error labels.
+Revision 2025‑04‑19‑d:
+Preloads all selected models before running detections, with a visible progress bar.
+Progress shows two phases: *Loading weights* and *Running inference*.
+• Keeps thin, semi‑transparent boxes and concise error labels.
 """

 from __future__ import annotations
@@ -26,7 +27,7 @@ from rfdetr import RFDETRBase, RFDETRLarge
 from rfdetr.util.coco_classes import COCO_CLASSES

 ###############################################################################
-# Model registry & lazy loader
+# Model registry & cache
 ###############################################################################

 YOLO_MODEL_MAP: Dict[str, str] = {
@@ -55,27 +56,24 @@ ALL_MODELS = list(YOLO_MODEL_MAP.keys()) + list(RFDETR_MODEL_MAP.keys()) + [
 _loaded: Dict[str, object] = {}

 def load_model(choice: str, custom_file: Optional[Path] = None):
-    """Return and cache a detector matching *choice*."""
+    """Fetch and cache a detector instance for *choice*."""
     if choice in _loaded:
         return _loaded[choice]

-    try:
-        if choice in YOLO_MODEL_MAP:
-            model = YOLO(YOLO_MODEL_MAP[choice])
-        elif choice in RFDETR_MODEL_MAP:
-            model = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
-        elif choice.startswith("Custom YOLO"):
-            if custom_file is None:
-                raise ValueError("Upload a YOLO .pt/.pth checkpoint first.")
-            model = YOLO(str(custom_file))
-        elif choice.startswith("Custom RF‑DETR"):
-            if custom_file is None:
-                raise ValueError("Upload an RF‑DETR .pth checkpoint first.")
-            model = RFDETRBase(pretrain_weights=str(custom_file))
-        else:
-            raise ValueError(f"Unsupported model choice: {choice}")
-    except Exception as exc:
-        raise RuntimeError(str(exc)) from exc
+    if choice in YOLO_MODEL_MAP:
+        model = YOLO(YOLO_MODEL_MAP[choice])  # Ultralytics auto‑downloads if missing
+    elif choice in RFDETR_MODEL_MAP:
+        model = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
+    elif choice.startswith("Custom YOLO"):
+        if custom_file is None:
+            raise RuntimeError("Upload a YOLO .pt/.pth checkpoint first.")
+        model = YOLO(str(custom_file))
+    elif choice.startswith("Custom RF‑DETR"):
+        if custom_file is None:
+            raise RuntimeError("Upload an RF‑DETR .pth checkpoint first.")
+        model = RFDETRBase(pretrain_weights=str(custom_file))
+    else:
+        raise RuntimeError(f"Unsupported model choice: {choice}")

     _loaded[choice] = model
     return model
@@ -84,8 +82,8 @@ def load_model(choice: str, custom_file: Optional[Path] = None):
 # Inference helpers
 ###############################################################################

-BOX_THICKNESS = 2
-BOX_ALPHA = 0.6
+BOX_THICKNESS = 2  # thinner boxes
+BOX_ALPHA = 0.6  # 60 % opacity

 box_annotator = sv.BoxAnnotator(thickness=BOX_THICKNESS)
 label_annotator = sv.LabelAnnotator()
@@ -100,28 +98,25 @@ def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[I
         detections = model.predict(image, threshold=threshold)
         label_src = COCO_CLASSES
     else:
-        ul_result = model.predict(image, verbose=False)[0]
-        detections = sv.Detections.from_ultralytics(ul_result)
+        ul_res = model.predict(image, verbose=False)[0]
+        detections = sv.Detections.from_ultralytics(ul_res)
         label_src = model.names  # type: ignore

     runtime = time.perf_counter() - start

-    base_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
-    overlay = base_bgr.copy()
-
+    img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+    overlay = img_bgr.copy()
     overlay = box_annotator.annotate(overlay, detections)
     overlay = label_annotator.annotate(
         overlay,
         detections,
-        [f"{label_src[cid]} {conf:.2f}" for cid, conf in zip(detections.class_id, detections.confidence)],
+        [f"{label_src[c]} {p:.2f}" for c, p in zip(detections.class_id, detections.confidence)],
     )
-
-    blended = _blend(base_bgr, overlay)
-    out_pil = Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB))
-    return out_pil, runtime
+    blended = _blend(img_bgr, overlay)
+    return Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)), runtime

 ###############################################################################
-# Gradio callback
+# Gradio generator callback with progress
 ###############################################################################

 def compare_models(
@@ -135,24 +130,42 @@ def compare_models(
     if img.mode != "RGB":
         img = img.convert("RGB")

+    total_steps = len(models) * 2  # phase 1: load, phase 2: inference
+    progress = gr.Progress(total=total_steps)
+
+    # ----- Phase 1: preload weights -----
+    detectors: Dict[str, object] = {}
+    for i, name in enumerate(models, 1):
+        try:
+            detectors[name] = load_model(name, custom_file)
+        except Exception as exc:
+            detectors[name] = exc  # store exception for later reporting
+        progress.update(i, desc=f"Loading {name}")
+
+    # ----- Phase 2: run inference -----
     results: List[Image.Image] = []
     legends: Dict[str, str] = {}

-    for model_name in models:
+    for j, name in enumerate(models, 1):
+        detector_or_err = detectors[name]
+        step_index = len(models) + j
+        if isinstance(detector_or_err, Exception):
+            # model failed to load
+            results.append(Image.new("RGB", img.size, (40, 40, 40)))
+            emsg = str(detector_or_err)
+            legends[name] = "Unavailable (weights not found)" if "No such file" in emsg or "not found" in emsg else f"ERROR: {emsg.splitlines()[0][:120]}"
+            progress.update(step_index, desc=f"Skipped {name}")
+            continue
         try:
-            detector = load_model(model_name, custom_file)
-            annotated, latency = run_single_inference(detector, img, threshold)
+            annotated, latency = run_single_inference(detector_or_err, img, threshold)
             results.append(annotated)
-            legends[model_name] = f"{latency*1000:.1f} ms"
+            legends[name] = f"{latency*1000:.1f} ms"
         except Exception as exc:
            results.append(Image.new("RGB", img.size, (40, 40, 40)))
-            emsg = str(exc)
-            if "No such file" in emsg or "not found" in emsg:
-                legends[model_name] = "Unavailable (weights not found)"
-            else:
-                legends[model_name] = f"ERROR: {emsg.splitlines()[0][:120]}"
+            legends[name] = f"ERROR: {str(exc).splitlines()[0][:120]}"
+        progress.update(step_index, desc=f"Inference {name}")

-    return results, legends
+    yield results, legends  # final output

 ###############################################################################
 # Gradio UI
@@ -160,9 +173,7 @@ def compare_models(

 def build_demo():
     with gr.Blocks(title="CV Model Comparison") as demo:
-        gr.Markdown(
-            """# 🔍 Compare Object‑Detection Models\nUpload an image, choose detectors, and optionally add a custom checkpoint.\nBounding boxes are thin (2 px) and 60 % transparent for clarity."""
-        )
+        gr.Markdown("""# 🔍 Compare Object‑Detection Models\nUpload an image, select detectors, then click **Run Inference**.\nThin, semi‑transparent boxes highlight detections.""")

         with gr.Row():
             sel_models = gr.CheckboxGroup(ALL_MODELS, value=["YOLOv12‑n"], label="Models")
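Note on `_blend`: run_single_inference composites the annotated overlay back onto the original frame through a `_blend` helper that lives elsewhere in app.py and is not touched by this commit. Purely as an illustration of what a 60 % overlay blend of this shape usually looks like (an assumption about the helper, not the file's actual implementation), a minimal OpenCV sketch:

import cv2
import numpy as np

BOX_ALPHA = 0.6  # 60 % opacity, mirroring the constant defined in the diff

def _blend(base_bgr: np.ndarray, overlay_bgr: np.ndarray, alpha: float = BOX_ALPHA) -> np.ndarray:
    """Hypothetical helper: composite the annotated overlay onto the original frame."""
    # cv2.addWeighted returns overlay * alpha + base * (1 - alpha) + 0
    return cv2.addWeighted(overlay_bgr, alpha, base_bgr, 1.0 - alpha, 0.0)

With alpha at 0.6, boxes and labels stay readable while the underlying pixels remain visible, which matches the "semi‑transparent boxes" described in the module docstring.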
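Note on progress reporting: the new compare_models constructs gr.Progress(total=total_steps) inside the function body and drives it with progress.update(...). For comparison, the pattern Gradio documents for progress tracking is to declare the tracker as a default-valued keyword parameter (so the framework injects it for the running event) and to call it with a completed fraction between 0 and 1. The self-contained sketch below shows that documented two-phase pattern; the function name, component names, sleeps, and fake outputs are placeholders for illustration and are not taken from app.py.

import time
import gradio as gr

def fake_compare(models, progress: gr.Progress = gr.Progress()):
    """Illustrative two-phase progress: 'load' every model, then 'run' every model."""
    total = len(models) * 2 or 1                # avoid division by zero if nothing is selected
    for i, name in enumerate(models):           # phase 1: loading weights
        progress((i + 1) / total, desc=f"Loading {name}")
        time.sleep(0.3)                          # stand-in for load_model()
    legends = {}
    for j, name in enumerate(models):            # phase 2: running inference
        progress((len(models) + j + 1) / total, desc=f"Running {name}")
        time.sleep(0.3)                          # stand-in for run_single_inference()
        legends[name] = "ok"
    return legends

with gr.Blocks() as demo:
    choices = gr.CheckboxGroup(["YOLOv12-n", "RF-DETR Base"], value=["YOLOv12-n"], label="Models")
    legends_out = gr.JSON(label="Legends")
    gr.Button("Run").click(fake_compare, inputs=choices, outputs=legends_out)

if __name__ == "__main__":
    demo.launch()

Because the progress parameter has a gr.Progress default, it is not listed in inputs; Gradio supplies it automatically and renders the description string next to the bar while the event runs.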