wuhp committed on
Commit
63eb207
·
verified ·
1 Parent(s): 3c0560e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -0
app.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio app to compare object‑detection models:
3
+ • Ultralytics YOLOv12 (n, s, m, l, x)
4
+ • Ultralytics YOLOv11 (n, s, m, l, x)
5
+ • Roboflow RF‑DETR (Base, Large)
6
+ • Custom fine‑tuned checkpoints for either framework
7
+ Requires Python ≥3.9 plus:
8
+ pip install gradio ultralytics rfdetr supervision pillow numpy torch torchvision
9
+ If you need ONNX export for RF‑DETR, also: pip install rfdetr[onnxexport]
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import time
15
+ from pathlib import Path
16
+ from typing import List, Tuple
17
+
18
+ import numpy as np
19
+ from PIL import Image
20
+ import gradio as gr
21
+ import supervision as sv
22
+ from ultralytics import YOLO
23
+ from rfdetr import RFDETRBase, RFDETRLarge
24
+ from rfdetr.util.coco_classes import COCO_CLASSES
25
+
26
+ # -----------------------------------------------------------------------------
27
+ # Model registry & lazy loader
28
+ # -----------------------------------------------------------------------------
29
+
30
# UI label -> weight file name handed to ``ultralytics.YOLO``.
# Ultralytics publishes these generations without the "v" ("yolo11n.pt",
# "yolo12n.pt"); the "yolov11*.pt" / "yolov12*.pt" spellings are not release
# assets, so they cannot be resolved or auto-downloaded.
YOLO_MODEL_MAP = {
    # YOLOv12 sizes
    "YOLOv12‑n": "yolo12n.pt",
    "YOLOv12‑s": "yolo12s.pt",
    "YOLOv12‑m": "yolo12m.pt",
    "YOLOv12‑l": "yolo12l.pt",
    "YOLOv12‑x": "yolo12x.pt",
    # YOLOv11 sizes
    "YOLOv11‑n": "yolo11n.pt",
    "YOLOv11‑s": "yolo11s.pt",
    "YOLOv11‑m": "yolo11m.pt",
    "YOLOv11‑l": "yolo11l.pt",
    "YOLOv11‑x": "yolo11x.pt",
}

# UI label -> size tag; the loader maps "base"/"large" to the RF-DETR class.
RFDETR_MODEL_MAP = {
    "RF‑DETR‑Base (29M)": "base",  # handled explicitly
    "RF‑DETR‑Large (128M)": "large",
}

# Every choice offered in the UI: stock models first, custom checkpoints last.
ALL_MODELS = [
    *YOLO_MODEL_MAP,
    *RFDETR_MODEL_MAP,
    "Custom YOLO (.pt/.pth)",
    "Custom RF‑DETR (.pth)",
]

# Cache of already-instantiated models, filled lazily by the loader.
_loaded = {}
56
+
57
def load_model(choice: str, custom_path: str | None = None):
    """Return the model for *choice*, instantiating and caching it on first use.

    Args:
        choice: A label from ``YOLO_MODEL_MAP`` / ``RFDETR_MODEL_MAP`` or one of
            the "Custom ..." entries.
        custom_path: Checkpoint path. Required for the "Custom ..." choices and
            ignored for the stock models.

    Returns:
        The loaded ``YOLO``, ``RFDETRBase`` or ``RFDETRLarge`` instance.

    Raises:
        ValueError: If a custom choice is given without a checkpoint path, or
            if *choice* is not a supported label.
    """
    is_custom_yolo = choice.startswith("Custom YOLO")
    is_custom_rfdetr = choice.startswith("Custom RF‑DETR")

    if is_custom_yolo or is_custom_rfdetr:
        # The path comes from a free-text box; tolerate stray whitespace.
        custom_path = custom_path.strip() if custom_path else ""
        if not custom_path:
            if is_custom_yolo:
                raise ValueError("Please provide a path to your YOLO checkpoint.")
            raise ValueError("Please provide a path to your RF‑DETR checkpoint.")
        # Key custom models by path as well, otherwise switching to another
        # checkpoint would keep returning the first one that was loaded.
        cache_key = (choice, custom_path)
    else:
        cache_key = choice

    if cache_key in _loaded:
        return _loaded[cache_key]

    if choice in YOLO_MODEL_MAP:
        mdl = YOLO(YOLO_MODEL_MAP[choice])
    elif choice in RFDETR_MODEL_MAP:
        mdl = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
    elif is_custom_yolo:
        mdl = YOLO(custom_path)
    elif is_custom_rfdetr:
        # NOTE(review): custom RF-DETR weights are always loaded into the Base
        # architecture; a Large fine-tune would need RFDETRLarge — confirm.
        mdl = RFDETRBase(pretrain_weights=custom_path)
    else:
        raise ValueError(f"Unsupported model choice: {choice}")

    _loaded[cache_key] = mdl
    return mdl
80
+
81
+ # -----------------------------------------------------------------------------
82
+ # Inference helpers
83
+ # -----------------------------------------------------------------------------
84
+
85
# Shared annotators, created once and reused for every inference call.
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator()


def _class_name(label_source, class_id) -> str:
    """Return the name for *class_id*, or the id itself when it is not known."""
    try:
        return str(label_source[int(class_id)])
    except (KeyError, IndexError, TypeError):
        # Custom checkpoints may predict ids missing from the name table.
        return str(class_id)


def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[Image.Image, float]:
    """Run one model on *image* and draw its detections.

    Args:
        model: An RF-DETR model or an Ultralytics ``YOLO`` model.
        image: Input image; it is copied before drawing, so it is not modified.
        threshold: Minimum confidence for a detection to be kept.

    Returns:
        ``(annotated_image, runtime_seconds)`` where the runtime covers the
        prediction call (and, for YOLO, the conversion to ``sv.Detections``).
    """
    start = time.perf_counter()

    # RF‑DETR already returns sv.Detections
    if isinstance(model, (RFDETRBase, RFDETRLarge)):
        detections = model.predict(image, threshold=threshold)
        label_source = COCO_CLASSES
    else:
        # Ultralytics YOLO inference: returns list of Results.
        # Pass the threshold as ``conf``; without it YOLO silently used its own
        # default and the UI slider had no effect on these models.
        result = model.predict(image, conf=threshold, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(result)
        label_source = model.names  # class id -> class name
    runtime = time.perf_counter() - start

    labels = [
        f"{_class_name(label_source, cid)} {conf:.2f}"
        for cid, conf in zip(detections.class_id, detections.confidence)
    ]
    annotated = box_annotator.annotate(image.copy(), detections)
    annotated = label_annotator.annotate(annotated, detections, labels)
    return annotated, runtime
106
+
107
+ # -----------------------------------------------------------------------------
108
+ # Gradio UI logic
109
+ # -----------------------------------------------------------------------------
110
+
111
def compare_models(models: List[str], img: Image.Image, threshold: float, custom_path: str | None):
    """Run every selected model on the uploaded image.

    Args:
        models: Labels of the models to run, in display order.
        img: The uploaded image (``None`` when nothing was uploaded).
        threshold: Confidence threshold forwarded to each model.
        custom_path: Checkpoint path used by the "Custom ..." choices.

    Returns:
        ``(results, legends)``: the annotated images and, in the same order,
        one "<model> – <runtime> ms" string per model.

    Raises:
        gr.Error: If no image was uploaded or a model cannot be set up.
    """
    if img is None:
        # Without this guard the callback died with an AttributeError on
        # ``img.mode`` and the user only saw a generic failure.
        raise gr.Error("Please upload an image before running inference.")
    if img.mode != "RGB":
        img = img.convert("RGB")

    results = []
    legends = []
    for m in models or []:
        try:
            model_obj = load_model(m, custom_path)
        except ValueError as err:
            # Surface configuration problems (e.g. missing checkpoint path)
            # as a readable message in the UI.
            raise gr.Error(str(err)) from err
        annotated, t = run_single_inference(model_obj, img, threshold)
        results.append(annotated)
        legends.append(f"{m} – {t*1000:.1f} ms")
    return results, legends
122
+
123
+ # -----------------------------------------------------------------------------
124
+ # Launch Gradio Interface
125
+ # -----------------------------------------------------------------------------
126
+
127
def build_demo():
    """Assemble the Gradio Blocks UI and wire the run button to ``compare_models``.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` on it to serve the demo.
    """
    # NOTE(review): the text this was recovered from had lost its indentation,
    # so which inputs sit inside the first Row is a best guess — confirm.
    with gr.Blocks(title="CV Model Comparison") as demo:
        gr.Markdown(
            "# 🔍 Compare Object‑Detection Models\n"
            "Upload an image and select one or more models to see their predictions side‑by‑side."
        )

        # Inputs: model choice, threshold, optional checkpoint path, image.
        with gr.Row():
            model_picker = gr.CheckboxGroup(
                choices=ALL_MODELS,
                value=["YOLOv12‑n"],
                label="Select models",
            )
            conf_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.5,
                step=0.05,
                label="Confidence threshold",
            )
            checkpoint_box = gr.Textbox(label="Path to custom checkpoint (if selected)")
            image_input = gr.Image(type="pil", label="Upload image")

        # Outputs: one annotated image per model plus the timing strings.
        with gr.Row():
            result_gallery = gr.Gallery(label="Annotated results", columns=2, height="auto")

        timing_json = gr.JSON(label="Runtime (ms)")

        run_button = gr.Button("Run Inference")
        run_button.click(
            compare_models,
            inputs=[model_picker, image_input, conf_slider, checkpoint_box],
            outputs=[result_gallery, timing_json],
        )

    return demo
146
+
147
+ # Execute when running directly
148
if __name__ == "__main__":
    # Build the UI and start the Gradio server only when run as a script.
    app = build_demo()
    app.launch()