wuhp commited on
Commit
9503f8d
·
verified ·
1 Parent(s): 57b8ebf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -53
app.py CHANGED
@@ -3,17 +3,19 @@ Gradio app to compare object‑detection models:
3
  • Ultralytics YOLOv12 (n, s, m, l, x)
4
  • Ultralytics YOLOv11 (n, s, m, l, x)
5
  • Roboflow RF‑DETR (Base, Large)
6
- • Custom fine‑tuned checkpoints for either framework
7
- Requires Python ≥3.9 plus:
8
- pip install gradio ultralytics rfdetr supervision pillow numpy torch torchvision
9
- If you need ONNX export for RF‑DETR, also: pip install rfdetr[onnxexport]
 
 
10
  """
11
 
12
  from __future__ import annotations
13
 
14
  import time
15
  from pathlib import Path
16
- from typing import List, Tuple
17
 
18
  import numpy as np
19
  from PIL import Image
@@ -23,18 +25,17 @@ from ultralytics import YOLO
23
  from rfdetr import RFDETRBase, RFDETRLarge
24
  from rfdetr.util.coco_classes import COCO_CLASSES
25
 
26
- # -----------------------------------------------------------------------------
27
  # Model registry & lazy loader
28
- # -----------------------------------------------------------------------------
29
 
30
  YOLO_MODEL_MAP = {
31
- # YOLOv12 sizes
32
  "YOLOv12‑n": "yolov12n.pt",
33
  "YOLOv12‑s": "yolov12s.pt",
34
  "YOLOv12‑m": "yolov12m.pt",
35
  "YOLOv12‑l": "yolov12l.pt",
36
  "YOLOv12‑x": "yolov12x.pt",
37
- # YOLOv11 sizes
38
  "YOLOv11‑n": "yolov11n.pt",
39
  "YOLOv11‑s": "yolov11s.pt",
40
  "YOLOv11‑m": "yolov11m.pt",
@@ -43,7 +44,7 @@ YOLO_MODEL_MAP = {
43
  }
44
 
45
  RFDETR_MODEL_MAP = {
46
- "RF‑DETR‑Base (29M)": "base", # handled explicitly
47
  "RF‑DETR‑Large (128M)": "large",
48
  }
49
 
@@ -52,35 +53,46 @@ ALL_MODELS = list(YOLO_MODEL_MAP.keys()) + list(RFDETR_MODEL_MAP.keys()) + [
52
  "Custom RF‑DETR (.pth)",
53
  ]
54
 
55
- _loaded = {}
56
 
57
- def load_model(choice: str, custom_path: str | None = None):
58
- """Lazy‑load and cache models to avoid re‑download between inferences."""
 
 
 
59
  global _loaded
60
  if choice in _loaded:
61
  return _loaded[choice]
62
 
63
- if choice in YOLO_MODEL_MAP:
64
- mdl = YOLO(YOLO_MODEL_MAP[choice])
65
- elif choice in RFDETR_MODEL_MAP:
66
- mdl = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
67
- elif choice.startswith("Custom YOLO"):
68
- if not custom_path:
69
- raise ValueError("Please provide a path to your YOLO checkpoint.")
70
- mdl = YOLO(custom_path)
71
- elif choice.startswith("Custom RF‑DETR"):
72
- if not custom_path:
73
- raise ValueError("Please provide a path to your RF‑DETR checkpoint.")
74
- mdl = RFDETRBase(pretrain_weights=custom_path)
75
- else:
76
- raise ValueError(f"Unsupported model choice: {choice}")
 
 
 
 
 
 
 
 
77
 
78
  _loaded[choice] = mdl
79
  return mdl
80
 
81
- # -----------------------------------------------------------------------------
82
  # Inference helpers
83
- # -----------------------------------------------------------------------------
84
 
85
  box_annotator = sv.BoxAnnotator()
86
  label_annotator = sv.LabelAnnotator()
@@ -88,15 +100,14 @@ label_annotator = sv.LabelAnnotator()
88
  def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[Image.Image, float]:
89
  start = time.perf_counter()
90
 
91
- # RF‑DETR already returns sv.Detections
92
  if isinstance(model, (RFDETRBase, RFDETRLarge)):
93
  detections = model.predict(image, threshold=threshold)
94
  label_source = COCO_CLASSES
95
- else:
96
- # Ultralytics YOLO inference: returns list of Results
97
  result = model.predict(image, verbose=False)[0]
98
  detections = sv.Detections.from_ultralytics(result)
99
- label_source = model.names # list of class names
 
100
  runtime = time.perf_counter() - start
101
 
102
  labels = [f"{label_source[cid]} {conf:.2f}" for cid, conf in zip(detections.class_id, detections.confidence)]
@@ -104,46 +115,57 @@ def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[I
104
  annotated = label_annotator.annotate(annotated, detections, labels)
105
  return annotated, runtime
106
 
107
- # -----------------------------------------------------------------------------
108
  # Gradio UI logic
109
- # -----------------------------------------------------------------------------
110
 
111
- def compare_models(models: List[str], img: Image.Image, threshold: float, custom_path: str | None):
 
 
112
  if img.mode != "RGB":
113
  img = img.convert("RGB")
114
- results = []
115
- legends = []
116
  for m in models:
117
- model_obj = load_model(m, custom_path)
118
- annotated, t = run_single_inference(model_obj, img, threshold)
119
- results.append(annotated)
120
- legends.append(f"{m} – {t*1000:.1f} ms")
 
 
 
 
 
 
121
  return results, legends
122
 
123
- # -----------------------------------------------------------------------------
124
- # Launch Gradio Interface
125
- # -----------------------------------------------------------------------------
126
 
127
  def build_demo():
128
  with gr.Blocks(title="CV Model Comparison") as demo:
129
- gr.Markdown("""# 🔍 Compare Object‑Detection Models\nUpload an image and select one or more models to see their predictions sideby‑side.""")
130
 
131
  with gr.Row():
132
  model_select = gr.CheckboxGroup(choices=ALL_MODELS, value=["YOLOv12‑n"], label="Select models")
133
  threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence threshold")
134
- custom_weight_path = gr.Textbox(label="Path to custom checkpoint (if selected)")
135
- image_in = gr.Image(type="pil", label="Upload image")
 
136
 
137
  with gr.Row():
138
  gallery = gr.Gallery(label="Annotated results", columns=2, height="auto")
 
139
 
140
- legends_out = gr.JSON(label="Runtime (ms)")
141
-
142
- run_btn = gr.Button("Run Inference")
143
- run_btn.click(compare_models, inputs=[model_select, image_in, threshold_slider, custom_weight_path], outputs=[gallery, legends_out])
 
 
144
 
145
  return demo
146
 
147
- # Execute when running directly
148
  if __name__ == "__main__":
149
  build_demo().launch()
 
3
  • Ultralytics YOLOv12 (n, s, m, l, x)
4
  • Ultralytics YOLOv11 (n, s, m, l, x)
5
  • Roboflow RF‑DETR (Base, Large)
6
+ • Custom fine‑tuned checkpoints for either framework (upload .pt/.pth files)
7
+
8
+ Python ≥3.9
9
+ Install:
10
+ pip install -r requirements.txt
11
+ Optionally, add GPU‑specific PyTorch wheels or `rfdetr[onnxexport]` for ONNX export.
12
  """
13
 
14
  from __future__ import annotations
15
 
16
  import time
17
  from pathlib import Path
18
+ from typing import List, Tuple, Optional
19
 
20
  import numpy as np
21
  from PIL import Image
 
25
  from rfdetr import RFDETRBase, RFDETRLarge
26
  from rfdetr.util.coco_classes import COCO_CLASSES
27
 
28
+ ###############################################################################
29
  # Model registry & lazy loader
30
+ ###############################################################################
31
 
32
  YOLO_MODEL_MAP = {
33
+ # Names follow Ultralytics hub convention; they will be auto‑downloaded
34
  "YOLOv12‑n": "yolov12n.pt",
35
  "YOLOv12‑s": "yolov12s.pt",
36
  "YOLOv12‑m": "yolov12m.pt",
37
  "YOLOv12‑l": "yolov12l.pt",
38
  "YOLOv12‑x": "yolov12x.pt",
 
39
  "YOLOv11‑n": "yolov11n.pt",
40
  "YOLOv11‑s": "yolov11s.pt",
41
  "YOLOv11‑m": "yolov11m.pt",
 
44
  }
45
 
46
  RFDETR_MODEL_MAP = {
47
+ "RF‑DETR‑Base (29M)": "base",
48
  "RF‑DETR‑Large (128M)": "large",
49
  }
50
 
 
53
  "Custom RF‑DETR (.pth)",
54
  ]
55
 
56
+ _loaded = {} # cache of already‑instantiated models
57
 
58
+ def load_model(choice: str, custom_file: Optional[Path] = None):
59
+ """Return (and cache) a model instance for *choice*.
60
+ custom_file is a Path object (uploaded file) used when choice is custom.
61
+ Raises RuntimeError with helpful message if loading fails.
62
+ """
63
  global _loaded
64
  if choice in _loaded:
65
  return _loaded[choice]
66
 
67
+ try:
68
+ if choice in YOLO_MODEL_MAP:
69
+ weight_id = YOLO_MODEL_MAP[choice]
70
+ mdl = YOLO(weight_id) # Ultralytics downloads if not local
71
+ elif choice in RFDETR_MODEL_MAP:
72
+ mdl = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
73
+ elif choice.startswith("Custom YOLO"):
74
+ if not custom_file:
75
+ raise ValueError("Upload a YOLO .pt/.pth checkpoint first.")
76
+ mdl = YOLO(str(custom_file))
77
+ elif choice.startswith("Custom RF‑DETR"):
78
+ if not custom_file:
79
+ raise ValueError("Upload an RF‑DETR .pth checkpoint first.")
80
+ mdl = RFDETRBase(pretrain_weights=str(custom_file))
81
+ else:
82
+ raise ValueError(f"Unsupported model choice: {choice}")
83
+ except FileNotFoundError as e:
84
+ raise RuntimeError(
85
+ f"Weights for '{choice}' not found locally and could not be downloaded. "
86
+ "Place the .pt file in the working directory, supply a custom checkpoint, "
87
+ "or ensure the model is released on the Ultralytics hub.\n" + str(e)
88
+ ) from e
89
 
90
  _loaded[choice] = mdl
91
  return mdl
92
 
93
+ ###############################################################################
94
  # Inference helpers
95
+ ###############################################################################
96
 
97
  box_annotator = sv.BoxAnnotator()
98
  label_annotator = sv.LabelAnnotator()
 
100
  def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[Image.Image, float]:
101
  start = time.perf_counter()
102
 
 
103
  if isinstance(model, (RFDETRBase, RFDETRLarge)):
104
  detections = model.predict(image, threshold=threshold)
105
  label_source = COCO_CLASSES
106
+ else: # Ultralytics YOLO
 
107
  result = model.predict(image, verbose=False)[0]
108
  detections = sv.Detections.from_ultralytics(result)
109
+ label_source = model.names
110
+
111
  runtime = time.perf_counter() - start
112
 
113
  labels = [f"{label_source[cid]} {conf:.2f}" for cid, conf in zip(detections.class_id, detections.confidence)]
 
115
  annotated = label_annotator.annotate(annotated, detections, labels)
116
  return annotated, runtime
117
 
118
+ ###############################################################################
119
  # Gradio UI logic
120
+ ###############################################################################
121
 
122
+ def compare_models(models: List[str], img: Image.Image, threshold: float, custom_file: Optional[Path]):
123
+ if img is None:
124
+ raise gr.Error("Please upload an image first.")
125
  if img.mode != "RGB":
126
  img = img.convert("RGB")
127
+
128
+ results, legends = [], []
129
  for m in models:
130
+ try:
131
+ model_obj = load_model(m, custom_file)
132
+ annotated, t = run_single_inference(model_obj, img, threshold)
133
+ results.append(annotated)
134
+ legends.append(f"{m} – {t*1000:.1f} ms")
135
+ except Exception as e:
136
+ # Append a blank image with the error message overlayed
137
+ error_img = Image.new("RGB", img.size, color=(30, 30, 30))
138
+ legends.append(f"{m} – ERROR: {e}")
139
+ results.append(error_img)
140
  return results, legends
141
 
142
+ ###############################################################################
143
+ # Build & launch demo
144
+ ###############################################################################
145
 
146
  def build_demo():
147
  with gr.Blocks(title="CV Model Comparison") as demo:
148
+ gr.Markdown("""# 🔍 Compare Object‑Detection Models\nUpload an image, select detectors, and optionally upload a custom checkpoint.\nThe app annotates predictions and reports permodel latency.""")
149
 
150
  with gr.Row():
151
  model_select = gr.CheckboxGroup(choices=ALL_MODELS, value=["YOLOv12‑n"], label="Select models")
152
  threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence threshold")
153
+
154
+ custom_checkpoint = gr.File(label="Upload custom YOLO / RF‑DETR checkpoint", file_types=[".pt", ".pth"], interactive=True)
155
+ image_in = gr.Image(type="pil", label="Upload image", sources=["upload", "webcam"], show_label=True)
156
 
157
  with gr.Row():
158
  gallery = gr.Gallery(label="Annotated results", columns=2, height="auto")
159
+ legends_out = gr.JSON(label="Runtime (ms) or error messages")
160
 
161
+ run_btn = gr.Button("Run Inference", variant="primary")
162
+ run_btn.click(
163
+ fn=compare_models,
164
+ inputs=[model_select, image_in, threshold_slider, custom_checkpoint],
165
+ outputs=[gallery, legends_out],
166
+ )
167
 
168
  return demo
169
 
 
170
  if __name__ == "__main__":
171
  build_demo().launch()