Update app.py
app.py CHANGED
This update turns the evaluator from a ZIP/server-path tool into a Roboflow-aware one. What the diff changes:

- **Roboflow batch mode.** The UI gains a Roboflow **API key** field and accepts a `.txt` of Universe dataset URLs; each dataset is downloaded in YOLOv8 format through the `roboflow` SDK (new `RF_RE` URL pattern and `download_rf_dataset` helper, with downloads cached under an `rf_datasets` temp directory). The previous UI only offered a dataset ZIP upload or a server path, plus an optional YAML and trained weights for model-assisted checks.
- **Duplicate detection simplified.** The optional fastdup branch (`FASTDUP_AVAILABLE`, `use_fastdup`) is removed; duplicates are now detected with `imagehash` average hashes alone, grouped in a `defaultdict`.
- **Renames.** The quality checks take a `qc_` prefix: `check_integrity` → `qc_integrity`, `image_quality` → `qc_image_quality`, `model_qa` → `qc_model_qa`, and the class-balance and duplicate checks become `qc_class_balance` and `qc_duplicates`. The box-IoU helper is now `_rel_iou`, and `DEFAULT_WEIGHTS` is now `DEFAULT_W`.
- **Restructuring.** Report and dataframe assembly move into a new `run_quality` helper; a new `evaluate` entry point dispatches between batch URLs, an uploaded ZIP, and a server path, then joins the per-dataset reports.
- Smaller changes: `parse_label_file` returns an empty list for missing label files, `gather_dataset` falls back to the first `*.yaml` under the dataset root, and `re` / `defaultdict` imports are added.
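For reference, this is the shape of `.txt` the batch mode expects: one Universe dataset URL per line. The workspace and project names below are hypothetical; the pattern is the same `RF_RE` regex used in the listing that follows.

```python
import re

# Same pattern as RF_RE in the new app.py below.
RF_RE = re.compile(r"https://universe\.roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)")

# Hypothetical contents of the uploaded .txt – one dataset URL per line.
url_list = """\
https://universe.roboflow.com/acme-vision/forklift-safety/dataset/2
https://universe.roboflow.com/acme-vision/pallet-count/dataset/5
"""

for line in url_list.splitlines():
    ws, proj, ver = RF_RE.match(line).groups()
    print(ws, proj, ver)
# acme-vision forklift-safety 2
# acme-vision pallet-count 5
```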
The new `app.py`, reassembled from the added side of the diff (unchanged context that the diff view collapsed is marked in comments):

````python
# app.py – Roboflow‑aware YOLOv8 Dataset Quality Evaluator for Hugging Face Spaces
#
# ▸ Prompts for a Roboflow **API key** and a `.txt` list of Universe dataset URLs (one per line)
# ▸ Downloads each dataset automatically in YOLOv8 format to a temp directory
# ▸ Runs a battery of quality checks:
#     – integrity / corruption
#     – class‑balance stats
#     – blur / brightness image‑quality flags
#     – exact / near‑duplicate detection
#     – optional model‑assisted label QA (needs a YOLO .pt weights file)
# ▸ Still supports manual ZIP / server‑path evaluation
# ▸ Outputs a Markdown report + class‑distribution dataframe
#
# Hugging Face Spaces picks up `app.py` automatically. Dependencies go in `requirements.txt`.
# Spaces injects the port as $PORT – we pass it to demo.launch().

from __future__ import annotations

import imghdr
import json
import os
import re
import shutil
import tempfile
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path

# The diff collapsed lines 29–34 as unchanged context; restored here from usage
# below (`import yaml` is confirmed by a hunk header).
from typing import Dict, List, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import yaml

from PIL import Image
from tqdm import tqdm

# --------------------------------------------------------------------------- #
# Optional heavy deps – present locally, but fine‑grained to keep Spaces slim  #
# --------------------------------------------------------------------------- #
try:
    import cv2  # type: ignore
except ImportError:
    cv2 = None

try:
    import imagehash  # type: ignore
except ImportError:
    imagehash = None

try:
    from ultralytics import YOLO  # type: ignore
except ImportError:
    YOLO = None  # noqa: N806

try:
    from roboflow import Roboflow  # type: ignore
except ImportError:
    Roboflow = None  # type: ignore

# --------------------------------------------------------------------------- #
TMP_ROOT = Path(tempfile.gettempdir()) / "rf_datasets"
TMP_ROOT.mkdir(parents=True, exist_ok=True)


@dataclass
class DuplicateGroup:
    hash_val: str
    paths: List[Path]


# --------------------------------------------------------------------------- #
# Generic helpers                                                              #
# --------------------------------------------------------------------------- #
def load_yaml(path: Path) -> Dict:
    with path.open(encoding="utf-8") as f:
        return yaml.safe_load(f)


def parse_label_file(path: Path) -> List[Tuple[int, float, float, float, float]]:
    out: List[Tuple[int, float, float, float, float]] = []
    if not path.exists():
        return out
    with path.open(encoding="utf-8") as f:
        for ln in f:
            parts = ln.strip().split()
            if len(parts) == 5:
                cid, *coords = parts
                out.append((int(cid), *map(float, coords)))
    return out


def guess_image_dirs(root: Path) -> List[Path]:
    ...  # body unchanged in this commit (lines 94–102 collapsed in the diff)


def gather_dataset(root: Path, yaml_path: Path | None = None):
    if yaml_path is None:
        yamls = list(root.glob("*.yaml"))
        if not yamls:
            raise FileNotFoundError("Dataset YAML not found")
        yaml_path = yamls[0]

    meta = load_yaml(yaml_path)
    img_dirs = guess_image_dirs(root)
    if not img_dirs:
        raise FileNotFoundError("images/ directory hierarchy missing")

    imgs = [p for d in img_dirs for p in d.rglob("*.*") if imghdr.what(p) is not None]
    lbls = [p.parent.parent / "labels" / f"{p.stem}.txt" for p in imgs]
    return imgs, lbls, meta


# --------------------------------------------------------------------------- #
# Quality‑check stages                                                         #
# --------------------------------------------------------------------------- #
def _is_corrupt(path: Path) -> bool:
    try:
        with Image.open(path) as im:
            im.verify()
        return False
    except Exception:
        return True


def qc_integrity(imgs: List[Path], lbls: List[Path]) -> Dict:
    miss_lbl = [i for i, l in zip(imgs, lbls) if not l.exists()]
    # A label is orphaned if no image with the same stem exists under images/.
    miss_img = [
        l for l in lbls
        if l.exists() and not any((l.parent.parent / "images").glob(f"{l.stem}.*"))
    ]

    corrupt: List[Path] = []
    with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as ex:
        fut = {ex.submit(_is_corrupt, p): p for p in imgs}
        for f in tqdm(as_completed(fut), total=len(fut), desc="integrity", leave=False):
            if f.result():
                corrupt.append(fut[f])

    score = 100 - (len(miss_lbl) + len(miss_img) + len(corrupt)) / max(len(imgs), 1) * 100
    return {
        # remaining name/score/details fields (lines 146–152) collapsed in the diff
    }


def qc_class_balance(lbls: List[Path]) -> Dict:
    cls_counts = Counter()
    boxes_per_img = []
    for l in lbls:
        bs = parse_label_file(l)
        boxes_per_img.append(len(bs))
        cls_counts.update(b[0] for b in bs)

    if not cls_counts:
        return {"name": "Class balance", "score": 0, "details": "No labels"}
    bal = min(cls_counts.values()) / max(cls_counts.values()) * 100
    return {
        "name": "Class balance",
        "score": bal,
        "details": {
            "class_counts": dict(cls_counts),
            "boxes_per_image": {
                "min": int(np.min(boxes_per_img)),
                "max": int(np.max(boxes_per_img)),
                "mean": float(np.mean(boxes_per_img)),
            },
        },
    }


def qc_image_quality(imgs: List[Path], blur_thr: float = 100.0) -> Dict:
    if cv2 is None:
        return {"name": "Image quality", "score": 100, "details": "cv2 not installed"}
    blurry, dark, bright = [], [], []
    for p in tqdm(imgs, desc="img‑quality", leave=False):
        im = cv2.imread(str(p))
        if im is None:
            continue
        gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
        lap = cv2.Laplacian(gray, cv2.CV_64F).var()
        br = np.mean(gray)
        if lap < blur_thr:
            blurry.append(p)
        if br < 25:
            dark.append(p)
        if br > 230:
            bright.append(p)

    bad = len(set(blurry + dark + bright))
    score = 100 - bad / max(len(imgs), 1) * 100
    return {
        # remaining name/score/details fields (lines 202–208) collapsed in the diff
    }


def qc_duplicates(imgs: List[Path]) -> Dict:
    if imagehash is None:
        return {"name": "Duplicates", "score": 100, "details": "imagehash not installed"}

    hashes: Dict[str, List[Path]] = defaultdict(list)
    for p in tqdm(imgs, desc="hashing", leave=False):
        h = str(imagehash.average_hash(Image.open(p)))
        hashes[h].append(p)

    groups = [g for g in hashes.values() if len(g) > 1]
    dup = sum(len(g) - 1 for g in groups)
    score = 100 - dup / max(len(imgs), 1) * 100
    return {
        "name": "Duplicates",
        "score": score,
        "details": {"groups": [[str(p) for p in g] for g in groups]},
    }


def _rel_iou(b1, b2):
    x1, y1, w1, h1 = b1
    x2, y2, w2, h2 = b2
    xa1, ya1, xa2, ya2 = x1 - w1 / 2, y1 - h1 / 2, x1 + w1 / 2, y1 + h1 / 2
    xb1, yb1, xb2, yb2 = x2 - w2 / 2, y2 - h2 / 2, x2 + w2 / 2, y2 + h2 / 2
    ix1, iy1, ix2, iy2 = max(xa1, xb1), max(ya1, yb1), min(xa2, xb2), min(ya2, yb2)
    iw, ih = max(ix2 - ix1, 0), max(iy2 - iy1, 0)
    inter = iw * ih
    union = w1 * h1 + w2 * h2 - inter
    return inter / union if union else 0


def qc_model_qa(imgs: List[Path], weights: str | None, lbls: List[Path], iou_thr: float = 0.5) -> Dict:
    if weights is None or YOLO is None:
        return {"name": "Model QA", "score": 100, "details": "weights or YOLO unavailable"}

    model = YOLO(weights)
    ious, mism = [], []
    for p in tqdm(imgs, desc="model‑QA", leave=False):
        gtb = parse_label_file(p.parent.parent / "labels" / f"{p.stem}.txt")
        if not gtb:
            continue
        res = model.predict(p, verbose=False)[0]
        for cls, x, y, w, h in gtb:
            best = 0.0
            # xywhn: normalised coordinates, the same space as the label files
            for b, c in zip(res.boxes.xywhn, res.boxes.cls):
                if int(c) != cls:
                    continue
                best = max(best, _rel_iou((x, y, w, h), tuple(b.tolist())))
            ious.append(best)
            if best < iou_thr:
                mism.append(p)

    miou = float(np.mean(ious)) if ious else 1.0
    return {
        "name": "Model QA",
        # remaining score/details fields (lines 267–268) collapsed in the diff
    }


# --------------------------------------------------------------------------- #
DEFAULT_W = {
    "Integrity": 0.30,
    "Class balance": 0.15,
    "Image quality": 0.15,
    "Duplicates": 0.10,
    "Model QA": 0.30,
}


def aggregate(scores):
    return sum(DEFAULT_W.get(r["name"], 0) * r["score"] for r in scores)


# --------------------------------------------------------------------------- #
# Roboflow helpers                                                             #
# --------------------------------------------------------------------------- #
RF_RE = re.compile(r"https://universe\.roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)")


def download_rf_dataset(url: str, rf_api: "Roboflow", dest: Path) -> Path:
    m = RF_RE.match(url.strip())
    if not m:
        raise ValueError(f"Bad RF URL: {url}")

    ws, proj, ver = m.groups()
    ds_dir = dest / f"{ws}_{proj}_v{ver}"
    if ds_dir.exists():
        return ds_dir

    project = rf_api.workspace(ws).project(proj)
    project.version(int(ver)).download("yolov8", location=str(ds_dir))
    return ds_dir


# --------------------------------------------------------------------------- #
# Main evaluation logic                                                        #
# --------------------------------------------------------------------------- #
def run_quality(root: Path, yaml_override: Path | None, weights: Path | None):
    imgs, lbls, meta = gather_dataset(root, yaml_override)
    res = [
        qc_integrity(imgs, lbls),
        qc_class_balance(lbls),
        qc_image_quality(imgs),
        qc_duplicates(imgs),
        qc_model_qa(imgs, str(weights) if weights else None, lbls),
    ]
    final = aggregate(res)
    # markdown
    md = [f"## **{meta.get('name', root.name)}** — Score {final:.1f}/100"]
    for r in res:
        md.append(f"### {r['name']} {r['score']:.1f}")
        md.append("<details><summary>details</summary>\n\n```json")
        md.append(json.dumps(r["details"], indent=2))
        md.append("```\n</details>\n")
    md_str = "\n".join(md)

    details = res[1]["details"]  # class-balance details is a plain string when no labels exist
    cls_counts = details.get("class_counts", {}) if isinstance(details, dict) else {}
    df = pd.DataFrame.from_dict(cls_counts, orient="index", columns=["count"])
    df.index.name = "class"
    return md_str, df


# --------------------------------------------------------------------------- #
# Gradio interface                                                             #
# --------------------------------------------------------------------------- #
def evaluate(
    api_key: str,
    url_txt: gr.File | None,
    zip_file: gr.File | None,
    server_path: str,
    yaml_file: gr.File | None,
    weights: gr.File | None,
):
    if not any([url_txt, zip_file, server_path]):
        return "Upload a .txt of URLs or dataset ZIP/path", pd.DataFrame()

    reports, dfs = [], []

    # ---- Roboflow batch mode ----
    if url_txt:
        if Roboflow is None:
            return "`roboflow` not installed", pd.DataFrame()
        if not api_key:
            return "Enter Roboflow API key", pd.DataFrame()

        rf = Roboflow(api_key=api_key.strip())
        txt_lines = Path(url_txt.name).read_text().splitlines()
        for line in txt_lines:
            if not line.strip():
                continue
            try:
                ds_root = download_rf_dataset(line, rf, TMP_ROOT)
                md, df = run_quality(ds_root, None, Path(weights.name) if weights else None)
                reports.append(md)
                dfs.append(df)
            except Exception as e:
                reports.append(f"### {line}\n\n⚠️ {e}")

    # ---- Manual ZIP ----
    if zip_file:
        tmp_dir = Path(tempfile.mkdtemp())
        shutil.unpack_archive(zip_file.name, tmp_dir)
        md, df = run_quality(tmp_dir, Path(yaml_file.name) if yaml_file else None, Path(weights.name) if weights else None)
        reports.append(md)
        dfs.append(df)
        shutil.rmtree(tmp_dir, ignore_errors=True)

    # ---- Manual path ----
    if server_path:
        md, df = run_quality(Path(server_path), Path(yaml_file.name) if yaml_file else None, Path(weights.name) if weights else None)
        reports.append(md)
        dfs.append(df)

    summary_md = "\n\n---\n\n".join(reports)
    combined_df = pd.concat(dfs).groupby(level=0).sum() if dfs else pd.DataFrame()
    return summary_md, combined_df


with gr.Blocks(title="YOLO Dataset Quality Evaluator") as demo:
    gr.Markdown(
        """
# YOLOv8 Dataset Quality Evaluator

### Roboflow batch
1. Paste your **Roboflow API key**
2. Upload a **.txt** file – one `https://universe.roboflow.com/.../dataset/x` per line

### Manual
* Upload a dataset **ZIP** or type a dataset **path** on the server
* Optionally supply a custom **data.yaml** and/or a **YOLO .pt** weights file for model‑assisted QA
"""
    )

    with gr.Row():
        api_in = gr.Textbox(label="Roboflow API key", type="password", placeholder="rf_XXXXXXXXXXXXXXXX")
        url_txt_in = gr.File(label=".txt of RF dataset URLs", file_types=[".txt"])

    with gr.Row():
        zip_in = gr.File(label="Dataset ZIP")
        path_in = gr.Textbox(label="Path on server", placeholder="/data/my_dataset")

    with gr.Row():
        yaml_in = gr.File(label="Custom YAML", file_types=[".yaml"])
        weights_in = gr.File(label="YOLO weights (.pt)")

    run_btn = gr.Button("Evaluate")
    out_md = gr.Markdown()
    out_df = gr.Dataframe()

    run_btn.click(
        evaluate,
        inputs=[api_in, url_txt_in, zip_in, path_in, yaml_in, weights_in],
        outputs=[out_md, out_df],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
````
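A quick check of the YOLO label format `parse_label_file` consumes: one class id plus four normalised xywh floats per line. The function here is copied from the listing above, minus the type annotations; the label file and its contents are made up for the demo.

```python
import tempfile
from pathlib import Path

def parse_label_file(path):
    # Copied from app.py above.
    out = []
    if not path.exists():
        return out
    with path.open(encoding="utf-8") as f:
        for ln in f:
            parts = ln.strip().split()
            if len(parts) == 5:
                cid, *coords = parts
                out.append((int(cid), *map(float, coords)))
    return out

lbl = Path(tempfile.mkdtemp()) / "img_001.txt"
lbl.write_text("0 0.50 0.50 0.20 0.30\n1 0.25 0.75 0.10 0.10\nmalformed line\n")
print(parse_label_file(lbl))
# [(0, 0.5, 0.5, 0.2, 0.3), (1, 0.25, 0.75, 0.1, 0.1)] – the malformed row is skipped
```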
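The duplicate check reduces every image to an 8×8 average hash, so byte-identical (and visually near-identical) images collapse onto the same dictionary key. A minimal sketch of that grouping on synthetic images, assuming Pillow, NumPy, and imagehash are installed:

```python
from collections import defaultdict

import imagehash
import numpy as np
from PIL import Image

a = np.zeros((64, 64), dtype=np.uint8); a[:, 32:] = 255  # left black / right white
b = a.copy()                                             # exact duplicate of a
c = np.zeros((64, 64), dtype=np.uint8); c[32:, :] = 255  # top black / bottom white

hashes = defaultdict(list)
for name, arr in [("a", a), ("b", b), ("c", c)]:
    hashes[str(imagehash.average_hash(Image.fromarray(arr)))].append(name)

groups = [g for g in hashes.values() if len(g) > 1]
print(groups)  # [['a', 'b']] – only the identical pair forms a duplicate group
```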
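And a worked example of `_rel_iou`, which compares boxes in normalised centre-xywh form: a 0.2×0.2 box centred inside a 0.4×0.4 box intersects in an area of 0.04, giving IoU 0.04 / 0.16 = 0.25. The function is copied from the listing above; the boxes are illustrative.

```python
def _rel_iou(b1, b2):
    # Copied from app.py above.
    x1, y1, w1, h1 = b1
    x2, y2, w2, h2 = b2
    xa1, ya1, xa2, ya2 = x1 - w1 / 2, y1 - h1 / 2, x1 + w1 / 2, y1 + h1 / 2
    xb1, yb1, xb2, yb2 = x2 - w2 / 2, y2 - h2 / 2, x2 + w2 / 2, y2 + h2 / 2
    ix1, iy1, ix2, iy2 = max(xa1, xb1), max(ya1, yb1), min(xa2, xb2), min(ya2, yb2)
    iw, ih = max(ix2 - ix1, 0), max(iy2 - iy1, 0)
    inter = iw * ih
    union = w1 * h1 + w2 * h2 - inter
    return inter / union if union else 0

print(round(_rel_iou((0.5, 0.5, 0.4, 0.4), (0.5, 0.5, 0.2, 0.2)), 3))  # 0.25 – smaller box fully contained
print(round(_rel_iou((0.2, 0.2, 0.2, 0.2), (0.8, 0.8, 0.2, 0.2)), 3))  # 0.0 – disjoint boxes
```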
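Finally, the weighted aggregation: with the new `DEFAULT_W` weights, per-check scores of 90 / 50 / 80 / 100 / 100 combine to 0.30·90 + 0.15·50 + 0.15·80 + 0.10·100 + 0.30·100 = 86.5. The report dicts below are hypothetical, shaped like the `qc_*` returns:

```python
DEFAULT_W = {  # copied from app.py above
    "Integrity": 0.30,
    "Class balance": 0.15,
    "Image quality": 0.15,
    "Duplicates": 0.10,
    "Model QA": 0.30,
}

results = [  # hypothetical per-check reports
    {"name": "Integrity", "score": 90},
    {"name": "Class balance", "score": 50},
    {"name": "Image quality", "score": 80},
    {"name": "Duplicates", "score": 100},
    {"name": "Model QA", "score": 100},
]

final = sum(DEFAULT_W.get(r["name"], 0) * r["score"] for r in results)
print(round(final, 1))  # 86.5
```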
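Since the header notes that dependencies live in `requirements.txt`, here is one plausible set inferred from the imports above (versions unpinned; fastdup is no longer needed after this commit):

```text
gradio
numpy
pandas
pyyaml
pillow
tqdm
opencv-python-headless
imagehash
ultralytics
roboflow
```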