wuhp commited on
Commit
6620c2f
Β·
verified Β·
1 Parent(s): 22b976e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +237 -31
app.py CHANGED
@@ -1,11 +1,15 @@
1
  """
2
- app.py – Roboflow‑aware YOLOv8 Dataset Quality Evaluator (v3)
3
- ─────────────────────────────────────────────────────────────
4
- ChangelogΒ (2025‑04‑17Β b)
5
- β€’Β **Cleanlab** integration β†’ extra *label‑issue* metric (skips gracefully if lib missing).
6
- β€’Β New **BBoxΒ validity** check: flags coords outsideΒ [0,β€―1].
7
- β€’Β Weight table updated (IntegrityΒ 25β€―%, ModelΒ 20β€―%, CleanlabΒ 10β€―%, etc.).
8
- β€’Β Minor: switched to cached NumPy reader for labels, clarified envΒ vars.
 
 
 
 
9
  """
10
 
11
  from __future__ import annotations
@@ -29,7 +33,7 @@ import yaml
29
  from PIL import Image
30
  from tqdm import tqdm
31
 
32
- # ───────────────────────────── Optional heavy deps (fail‑soft) ──
33
  try:
34
  import cv2 # type: ignore
35
  except ImportError:
@@ -45,11 +49,6 @@ try:
45
  except ImportError:
46
  fastdup = None
47
 
48
- try:
49
- from cleanlab.object_detection import find_label_issues # type: ignore
50
- except (ImportError, AttributeError):
51
- find_label_issues = None # type: ignore
52
-
53
  try:
54
  from ultralytics import YOLO # type: ignore
55
  except ImportError:
@@ -64,17 +63,16 @@ except ImportError:
64
  TMP_ROOT = Path(tempfile.gettempdir()) / "rf_datasets"
65
  TMP_ROOT.mkdir(parents=True, exist_ok=True)
66
 
 
67
  CPU_COUNT = int(os.getenv("QC_CPU", max(1, (os.cpu_count() or 4) // 2)))
68
  BATCH = int(os.getenv("QC_BATCH", 16))
69
 
70
  DEFAULT_W = {
71
- "Integrity": 0.25,
72
  "Class balance": 0.15,
73
  "Image quality": 0.15,
74
  "Duplicates": 0.10,
75
- "BBox validity": 0.05,
76
- "Model QA": 0.20,
77
- "Cleanlab QA": 0.10,
78
  }
79
 
80
  @dataclass
@@ -88,19 +86,17 @@ def load_yaml(path: Path) -> Dict:
88
  with path.open(encoding="utf-8") as f:
89
  return yaml.safe_load(f)
90
 
91
- _label_cache: dict[Path, np.ndarray] = {}
92
 
93
- def load_labels_np(path: Path) -> np.ndarray:
94
- if path in _label_cache:
95
- return _label_cache[path]
96
  try:
97
  arr = np.loadtxt(path, dtype=float)
98
  if arr.ndim == 1:
99
  arr = arr.reshape(1, -1)
 
100
  except Exception:
101
- arr = np.empty((0, 5))
102
- _label_cache[path] = arr
103
- return arr
104
 
105
 
106
  def guess_image_dirs(root: Path) -> List[Path]:
@@ -168,9 +164,9 @@ def qc_class_balance(lbls: List[Path]):
168
  cls_counts = Counter()
169
  boxes_per_img = []
170
  for l in lbls:
171
- arr = load_labels_np(l) if l else np.empty((0, 5))
172
- boxes_per_img.append(len(arr))
173
- cls_counts.update(arr[:, 0].astype(int) if arr.size else [])
174
 
175
  if not cls_counts:
176
  return {"name": "Class balance", "score": 0, "details": "No labels"}
@@ -204,7 +200,10 @@ def qc_image_quality(imgs: List[Path], blur_thr: float = 100.0):
204
  if cv2 is None:
205
  return {"name": "Image quality", "score": 100, "details": "cv2 not installed"}
206
 
207
- blurry, dark, bright = [], [], []
 
 
 
208
  with ProcessPoolExecutor(max_workers=CPU_COUNT) as ex:
209
  for p, is_blur, is_dark, is_bright in tqdm(
210
  ex.map(lambda x: _quality_stat(x, blur_thr), imgs),
@@ -234,6 +233,7 @@ def qc_image_quality(imgs: List[Path], blur_thr: float = 100.0):
234
  # Duplicate images ---------------------------------------------
235
 
236
  def qc_duplicates(imgs: List[Path]):
 
237
  if fastdup is not None and len(imgs) > 50:
238
  try:
239
  fd = fastdup.create(input_dir=str(Path(imgs[0]).parent.parent), work_dir=str(TMP_ROOT / "fastdup"))
@@ -241,9 +241,13 @@ def qc_duplicates(imgs: List[Path]):
241
  clusters = fd.get_clusters()
242
  dup = sum(len(c) - 1 for c in clusters)
243
  score = 100 - dup / max(len(imgs), 1) * 100
244
- return {"name": "Duplicates", "score": score, "details": {"groups": clusters[:50]}}
 
 
 
 
245
  except Exception:
246
- pass
247
 
248
  if imagehash is None:
249
  return {"name": "Duplicates", "score": 100, "details": "skipped (deps)"}
@@ -252,4 +256,206 @@ def qc_duplicates(imgs: List[Path]):
252
  return str(imagehash.average_hash(Image.open(p)))
253
 
254
  hashes: Dict[str, List[Path]] = defaultdict(list)
255
- with ProcessPoolExecutor(max_workers=CPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ app.py – Roboflow‑aware YOLOv8 Dataset Quality Evaluator (v2)
3
+
4
+ Changelog (2025‑04‑17)
5
+ ──────────────────────
6
+ β€’ **CPU‑bound loops parallelised** with `concurrent.futures.ProcessPoolExecutor`.
7
+ β€’ **Batch inference** in `qc_model_qa()` (GPU util ↑, latency ↓).
8
+ β€’ Optional **fastdup** path for duplicate detection (β‰ˆβ€―10Γ— faster on large sets).
9
+ β€’ Faster NumPy‑based `parse_label_file()`.
10
+ β€’ Small refactors β†’ clearer separation of stages & fewer globals.
11
+ β€’ Graceful degradation if heavy deps unavailable (cv2, imagehash, fastdup).
12
+ β€’ Tunable `CPU_COUNT` + env‑var guard for HF Spaces quota.
13
  """
14
 
15
  from __future__ import annotations
 
33
  from PIL import Image
34
  from tqdm import tqdm
35
 
36
+ # ───────────────────────────────────────── Heavy optional deps ──
37
  try:
38
  import cv2 # type: ignore
39
  except ImportError:
 
49
  except ImportError:
50
  fastdup = None
51
 
 
 
 
 
 
52
  try:
53
  from ultralytics import YOLO # type: ignore
54
  except ImportError:
 
63
  TMP_ROOT = Path(tempfile.gettempdir()) / "rf_datasets"
64
  TMP_ROOT.mkdir(parents=True, exist_ok=True)
65
 
66
+ # Limit CPU workers on HF Spaces (feel free to raise locally)
67
  CPU_COUNT = int(os.getenv("QC_CPU", max(1, (os.cpu_count() or 4) // 2)))
68
  BATCH = int(os.getenv("QC_BATCH", 16))
69
 
70
  DEFAULT_W = {
71
+ "Integrity": 0.30,
72
  "Class balance": 0.15,
73
  "Image quality": 0.15,
74
  "Duplicates": 0.10,
75
+ "Model QA": 0.30,
 
 
76
  }
77
 
78
  @dataclass
 
86
  with path.open(encoding="utf-8") as f:
87
  return yaml.safe_load(f)
88
 
 
89
 
90
+ def parse_label_file(path: Path) -> list[tuple[int, float, float, float, float]]:
91
+ if not path.exists() or path.stat().st_size == 0:
92
+ return []
93
  try:
94
  arr = np.loadtxt(path, dtype=float)
95
  if arr.ndim == 1:
96
  arr = arr.reshape(1, -1)
97
+ return [tuple(row) for row in arr]
98
  except Exception:
99
+ return []
 
 
100
 
101
 
102
  def guess_image_dirs(root: Path) -> List[Path]:
 
164
  cls_counts = Counter()
165
  boxes_per_img = []
166
  for l in lbls:
167
+ bs = parse_label_file(l) if l else []
168
+ boxes_per_img.append(len(bs))
169
+ cls_counts.update(b[0] for b in bs)
170
 
171
  if not cls_counts:
172
  return {"name": "Class balance", "score": 0, "details": "No labels"}
 
200
  if cv2 is None:
201
  return {"name": "Image quality", "score": 100, "details": "cv2 not installed"}
202
 
203
+ blurry: list[Path] = []
204
+ dark: list[Path] = []
205
+ bright: list[Path] = []
206
+
207
  with ProcessPoolExecutor(max_workers=CPU_COUNT) as ex:
208
  for p, is_blur, is_dark, is_bright in tqdm(
209
  ex.map(lambda x: _quality_stat(x, blur_thr), imgs),
 
233
  # Duplicate images ---------------------------------------------
234
 
235
  def qc_duplicates(imgs: List[Path]):
236
+ # Fast path – use fastdup if installed & enough images
237
  if fastdup is not None and len(imgs) > 50:
238
  try:
239
  fd = fastdup.create(input_dir=str(Path(imgs[0]).parent.parent), work_dir=str(TMP_ROOT / "fastdup"))
 
241
  clusters = fd.get_clusters()
242
  dup = sum(len(c) - 1 for c in clusters)
243
  score = 100 - dup / max(len(imgs), 1) * 100
244
+ return {
245
+ "name": "Duplicates",
246
+ "score": score,
247
+ "details": {"groups": clusters[:50]},
248
+ }
249
  except Exception:
250
+ pass # fallback to hash
251
 
252
  if imagehash is None:
253
  return {"name": "Duplicates", "score": 100, "details": "skipped (deps)"}
 
256
  return str(imagehash.average_hash(Image.open(p)))
257
 
258
  hashes: Dict[str, List[Path]] = defaultdict(list)
259
+ with ProcessPoolExecutor(max_workers=CPU_COUNT) as ex:
260
+ for h, p in tqdm(
261
+ zip(ex.map(_hash, imgs), imgs),
262
+ total=len(imgs),
263
+ desc="hashing",
264
+ leave=False,
265
+ ):
266
+ hashes[h].append(p)
267
+
268
+ groups = [g for g in hashes.values() if len(g) > 1]
269
+ dup = sum(len(g) - 1 for g in groups)
270
+ score = 100 - dup / max(len(imgs), 1) * 100
271
+ return {
272
+ "name": "Duplicates",
273
+ "score": score,
274
+ "details": {"groups": [[str(p) for p in g] for g in groups[:50]]},
275
+ }
276
+
277
+ # Model‑assisted QA --------------------------------------------
278
+
279
+ def _rel_iou(b1, b2):
280
+ x1, y1, w1, h1 = b1
281
+ x2, y2, w2, h2 = b2
282
+ xa1, ya1, xa2, ya2 = x1 - w1 / 2, y1 - h1 / 2, x1 + w1 / 2, y1 + h1 / 2
283
+ xb1, yb1, xb2, yb2 = x2 - w2 / 2, y2 - h2 / 2, x2 + w2 / 2, y2 + h2 / 2
284
+ ix1, iy1, ix2, iy2 = max(xa1, xb1), max(ya1, yb1), min(xa2, xb2), min(ya2, yb2)
285
+ inter = max(ix2 - ix1, 0) * max(iy2 - iy1, 0)
286
+ union = w1 * h1 + w2 * h2 - inter
287
+ return inter / union if union else 0.0
288
+
289
+
290
+ def qc_model_qa(imgs: List[Path], weights: str | None, lbls: List[Path], iou_thr: float = 0.5):
291
+ if weights is None or YOLO is None:
292
+ return {"name": "Model QA", "score": 100, "details": "skipped (no weights)"}
293
+
294
+ model = YOLO(weights)
295
+ ious, mism = [], []
296
+
297
+ for i in range(0, len(imgs), BATCH):
298
+ batch_paths = imgs[i : i + BATCH]
299
+ results = model.predict(batch_paths, verbose=False)
300
+ for p, res in zip(batch_paths, results):
301
+ gtb = parse_label_file(p.parent.parent / "labels" / f"{p.stem}.txt")
302
+ if not gtb:
303
+ continue
304
+ for cls, x, y, w, h in gtb:
305
+ best = 0.0
306
+ for b, c in zip(res.boxes.xywh.cpu().numpy(), res.boxes.cls.cpu().numpy()):
307
+ if int(c) != cls:
308
+ continue
309
+ best = max(best, _rel_iou((x, y, w, h), tuple(b)))
310
+ ious.append(best)
311
+ if best < iou_thr:
312
+ mism.append(str(p))
313
+
314
+ miou = float(np.mean(ious)) if ious else 1.0
315
+ return {
316
+ "name": "Model QA",
317
+ "score": miou * 100,
318
+ "details": {"mean_iou": miou, "mismatched_images": mism[:50]},
319
+ }
320
+
321
+ # Aggregate -----------------------------------------------------
322
+
323
+ def aggregate(scores):
324
+ return sum(DEFAULT_W.get(r["name"], 0) * r["score"] for r in scores)
325
+
326
+ # ───────────────────────────────────────── Roboflow helpers ────
327
+ RF_RE = re.compile(r"https://universe\.roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)")
328
+
329
+ def download_rf_dataset(url: str, rf_api: "Roboflow", dest: Path) -> Path:
330
+ m = RF_RE.match(url.strip())
331
+ if not m:
332
+ raise ValueError(f"Bad RF URL: {url}")
333
+
334
+ ws, proj, ver = m.groups()
335
+ ds_dir = dest / f"{ws}_{proj}_v{ver}"
336
+ if ds_dir.exists():
337
+ return ds_dir
338
+
339
+ project = rf_api.workspace(ws).project(proj)
340
+ project.version(int(ver)).download("yolov8", location=str(ds_dir))
341
+ return ds_dir
342
+
343
+ # ───────────────────────────────────────── Main logic ──────────
344
+
345
+ def run_quality(root: Path, yaml_override: Path | None, weights: Path | None):
346
+ imgs, lbls, meta = gather_dataset(root, yaml_override)
347
+ res = [
348
+ qc_integrity(imgs, lbls),
349
+ qc_class_balance(lbls),
350
+ qc_image_quality(imgs),
351
+ qc_duplicates(imgs),
352
+ qc_model_qa(imgs, str(weights) if weights else None, lbls),
353
+ ]
354
+ final = aggregate(res)
355
+
356
+ md = [f"## **{meta.get('name', root.name)}**Β β€”Β ScoreΒ {final:.1f}/100"]
357
+ for r in res:
358
+ md.append(f"### {r['name']}Β Β {r['score']:.1f}")
359
+ md.append("<details><summary>details</summary>\n\n```json")
360
+ md.append(json.dumps(r["details"], indent=2))
361
+ md.append("```\n</details>\n")
362
+ md_str = "\n".join(md)
363
+
364
+ cls_counts = res[1]["details"].get("class_counts", {}) # type: ignore[index]
365
+ df = pd.DataFrame.from_dict(cls_counts, orient="index", columns=["count"])
366
+ df.index.name = "class"
367
+ return md_str, df
368
+
369
+ # ────────────────────────────────────��──── Gradio UI ───────────
370
+
371
+ def evaluate(
372
+ api_key: str,
373
+ url_txt: gr.File | None,
374
+ zip_file: gr.File | None,
375
+ server_path: str,
376
+ yaml_file: gr.File | None,
377
+ weights: gr.File | None,
378
+ ):
379
+ if not any([url_txt, zip_file, server_path]):
380
+ return "Upload a .txt of URLs or dataset ZIP/path", pd.DataFrame()
381
+
382
+ reports, dfs = [], []
383
+
384
+ # Roboflow batch ------------------------------------------
385
+ if url_txt:
386
+ if Roboflow is None:
387
+ return "`roboflow` not installed", pd.DataFrame()
388
+ if not api_key:
389
+ return "Enter Roboflow API key", pd.DataFrame()
390
+
391
+ rf = Roboflow(api_key=api_key.strip())
392
+ for line in Path(url_txt.name).read_text().splitlines():
393
+ if not line.strip():
394
+ continue
395
+ try:
396
+ ds_root = download_rf_dataset(line, rf, TMP_ROOT)
397
+ md, df = run_quality(ds_root, None, Path(weights.name) if weights else None)
398
+ reports.append(md)
399
+ dfs.append(df)
400
+ except Exception as e:
401
+ reports.append(f"### {line}\n\n⚠️ {e}")
402
+
403
+ # Manual ZIP ----------------------------------------------
404
+ if zip_file:
405
+ tmp_dir = Path(tempfile.mkdtemp())
406
+ shutil.unpack_archive(zip_file.name, tmp_dir)
407
+ md, df = run_quality(tmp_dir, Path(yaml_file.name) if yaml_file else None, Path(weights.name) if weights else None)
408
+ reports.append(md)
409
+ dfs.append(df)
410
+ shutil.rmtree(tmp_dir, ignore_errors=True)
411
+
412
+ # Manual path ---------------------------------------------
413
+ if server_path:
414
+ md, df = run_quality(Path(server_path), Path(yaml_file.name) if yaml_file else None, Path(weights.name) if weights else None)
415
+ reports.append(md)
416
+ dfs.append(df)
417
+
418
+ summary_md = "\n\n---\n\n".join(reports)
419
+ combined_df = pd.concat(dfs).groupby(level=0).sum() if dfs else pd.DataFrame()
420
+ return summary_md, combined_df
421
+
422
+ # ───────────────────────────────────────── Launch ────────────
423
+ with gr.Blocks(title="YOLO Dataset Quality Evaluator") as demo:
424
+ gr.Markdown(
425
+ """
426
+ # YOLOv8 Dataset Quality Evaluator
427
+
428
+ ### Roboflow batch
429
+ 1. Paste your **Roboflow API key**
430
+ 2. Upload a **.txt** file – one `https://universe.roboflow.com/.../dataset/x` per line
431
+
432
+ ### Manual
433
+ * Upload a dataset **ZIP** or type a dataset **path** on the server
434
+ * Optionally supply a custom **data.yaml** and/or a **YOLOΒ .pt** weights file for model‑assisted QA
435
+ """
436
+ )
437
+
438
+ with gr.Row():
439
+ api_in = gr.Textbox(label="Roboflow API key", type="password", placeholder="rf_XXXXXXXXXXXXXXXX")
440
+ url_txt_in = gr.File(label=".txt of RF dataset URLs", file_types=[".txt"])
441
+
442
+ with gr.Row():
443
+ zip_in = gr.File(label="Dataset ZIP")
444
+ path_in = gr.Textbox(label="Path on server", placeholder="/data/my_dataset")
445
+
446
+ with gr.Row():
447
+ yaml_in = gr.File(label="Custom YAML", file_types=[".yaml"])
448
+ weights_in = gr.File(label="YOLO weights (.pt)")
449
+
450
+ run_btn = gr.Button("Evaluate")
451
+ out_md = gr.Markdown()
452
+ out_df = gr.Dataframe()
453
+
454
+ run_btn.click(
455
+ evaluate,
456
+ inputs=[api_in, url_txt_in, zip_in, path_in, yaml_in, weights_in],
457
+ outputs=[out_md, out_df],
458
+ )
459
+
460
+ if __name__ == "__main__":
461
+ demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))